"platforms/vscode:/vscode.git/clone" did not exist on "3e8a62ba3f8dc21c76f851ca8ec024139e5a2683"
ExpressionUtilities.cpp 55.9 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2019 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
#include "openmm/common/ExpressionUtilities.h"
28
29
30
31
32
33
34
35
#include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include "lepton/Operation.h"

using namespace OpenMM;
using namespace Lepton;
using namespace std;

36
/**
 * Create an ExpressionUtilities object that generates source code for the given ComputeContext.
 *
 * The context is stored by reference and is used later for number formatting (intToString,
 * doubleToString) while code is generated. The members fp1, fp2, fp3 and periodicDistance are
 * constructed from the integers 1, 2, 3 and 6. Presumably these are the argument counts of built-in
 * functions; periodicdistance() reads six children in processExpression. Confirm against the header.
 *
 * @param context   the ComputeContext for which code will be generated
 */
ExpressionUtilities::ExpressionUtilities(ComputeContext& context) :
        context(context),
        fp1(1),
        fp2(2),
        fp3(3),
        periodicDistance(6) {
}

39
/**
 * Generate source code that evaluates a set of expressions whose variables are identified by name.
 *
 * Each variable name is wrapped in an Operation::Variable tree node. The work is then forwarded to
 * the overload that takes (ExpressionTreeNode, code) pairs, so every occurrence of that variable
 * is replaced by the associated code string.
 *
 * @param expressions    the expressions to generate code for. Each key is written verbatim
 *                       immediately before the name of the temporary holding that expression's
 *                       value, followed by ";\n".
 * @param variables      maps each variable name to the source code that should be used for it
 * @param functions      tabulated functions that may be referenced by the expressions
 * @param functionNames  for each entry of functions: (name used in expressions, name of the
 *                       coefficient array in the generated code)
 * @param prefix         prefix for the names of generated temporary variables
 * @param tempType       type used to declare generated temporary variables
 * @return the generated source code
 */
string ExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    vector<pair<ExpressionTreeNode, string> > namedVariables;
    namedVariables.reserve(variables.size());
    for (const auto& variable : variables) {
        // The node is given a newly allocated Variable operation, exactly as Lepton's API expects.
        namedVariables.emplace_back(ExpressionTreeNode(new Operation::Variable(variable.first)), variable.second);
    }
    return createExpressions(expressions, namedVariables, functions, functionNames, prefix, tempType);
}

47
/**
 * Generate source code that evaluates a set of expressions.
 *
 * Every subexpression is assigned to a temporary variable. A temporary is emitted only once, even
 * when the same subexpression appears in several of the expressions. After the code for each
 * expression has been emitted, one extra line is written: the expression's key, then the name of
 * the temporary holding its value, then ";\n".
 *
 * @param expressions    the expressions to generate code for, keyed by the text written before each result
 * @param variables      (tree node, source code) pairs. Any node matching one of these is
 *                       replaced by the given code instead of being computed.
 * @param functions      tabulated functions that may be referenced by the expressions
 * @param functionNames  for each entry of functions: (name used in expressions, name of the
 *                       coefficient array in the generated code)
 * @param prefix         prefix for the names of generated temporary variables
 * @param tempType       type used to declare generated temporary variables
 * @return the generated source code
 */
string ExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    // All expressions are passed along so that, when a custom function is encountered, related calls
    // to it elsewhere (for example its derivatives) can be found and computed in the same block.
    vector<ParsedExpression> everyExpression;
    everyExpression.reserve(expressions.size());
    for (const auto& entry : expressions)
        everyExpression.push_back(entry.second);

    // Nodes that already have a name in the generated code, seeded with the caller's variables.
    // processExpression() appends to this as it defines new temporaries.
    vector<pair<ExpressionTreeNode, string> > knownNodes(variables);
    const vector<vector<double> > tableParameters = computeFunctionParameters(functions);

    stringstream code;
    for (const auto& entry : expressions) {
        const ExpressionTreeNode& root = entry.second.getRootNode();
        processExpression(code, root, knownNodes, functions, functionNames, prefix, tableParameters, everyExpression, tempType);
        code << entry.first << getTempName(root, knownNodes) << ";\n";
    }
    return code.str();
}

62
void ExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
63
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
peastman's avatar
peastman committed
64
        const vector<ParsedExpression>& allExpressions, const string& tempType) {
65
66
67
68
    for (int i = 0; i < (int) temps.size(); i++)
        if (temps[i].first == node)
            return;
    for (int i = 0; i < (int) node.getChildren().size(); i++)
peastman's avatar
peastman committed
69
        processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
70
    string name = prefix+context.intToString(temps.size());
71
    bool hasRecordedNode = false;
72
    bool isVecType = (tempType[tempType.size()-1] == '3');
73

74
75
76
    out << tempType << " " << name << " = ";
    switch (node.getOperation().getId()) {
        case Operation::CONSTANT:
77
78
79
80
81
82
        {
            string value = context.doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
            if (isVecType)
                out << "make_" << tempType << "(" << value << ")";
            else
                out << value;
83
            break;
84
        }
85
86
87
88
        case Operation::VARIABLE:
            throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
        case Operation::CUSTOM:
        {
89
90
91
92
            if (isVecType)
                out << "make_" << tempType << "(0);\n";
            else
                out << "0;\n";
93
94
95
96
97
98
            temps.push_back(make_pair(node, name));
            hasRecordedNode = true;

            // If both the value and derivative of the function are needed, it's faster to calculate them both
            // at once, so check to see if both are needed.

99
            vector<const ExpressionTreeNode*> nodes;
100
            nodes.push_back(&node);
101
            for (int j = 0; j < (int) allExpressions.size(); j++)
102
                findRelatedCustomFunctions(node, allExpressions[j].getRootNode(), nodes);
103
104
105
            vector<string> nodeNames;
            nodeNames.push_back(name);
            for (int j = 1; j < (int) nodes.size(); j++) {
106
                string name2 = prefix+context.intToString(temps.size());
107
                out << tempType << " " << name2 << " = 0.0f;\n";
108
109
                nodeNames.push_back(name2);
                temps.push_back(make_pair(*nodes[j], name2));
110
111
            }
            out << "{\n";
112
113
114
            if (node.getOperation().getName() == "periodicdistance") {
                // This is the periodicdistance() function.

115
                out << tempType << "3 periodicDistance_delta = make_real3(";
116
117
118
119
                for (int i = 0; i < 3; i++) {
                    if (i > 0)
                        out << ", ";
                    out << getTempName(node.getChildren()[i], temps) << "-" << getTempName(node.getChildren()[i+3], temps);
120
                }
121
                out << ");\n";
122
                out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
123
124
                out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n";
                out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n";
125
126
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
127
128
129
130
131
132
133
                    int argIndex = -1;
                    for (int k = 0; k < 6; k++) {
                        if (derivOrder[k] > 0) {
                            if (derivOrder[k] > 1 || argIndex != -1)
                                throw OpenMMException("Unsupported derivative of periodicdistance"); // Should be impossible for this to happen.
                            argIndex = k;
                        }
134
                    }
135
                    if (argIndex == -1)
136
                        out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
137
                    else if (argIndex == 0)
138
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
139
                    else if (argIndex == 1)
140
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
141
                    else if (argIndex == 2)
142
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
143
                    else if (argIndex == 3)
144
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
145
                    else if (argIndex == 4)
146
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
147
                    else if (argIndex == 5)
148
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
149
150
                }
            }
151
152
153
154
155
156
157
158
            else if (node.getOperation().getName() == "dot") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    string child1 = getTempName(node.getChildren()[0], temps);
                    string child2 = getTempName(node.getChildren()[1], temps);
                    if (derivOrder[0] == 0 && derivOrder[1] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(dot(" << child1 << ", " << child2 << "));\n";
                    else
159
                        throw OpenMMException("Unsupported derivative order for dot()");
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
                }
            }
            else if (node.getOperation().getName() == "cross") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    string child1 = getTempName(node.getChildren()[0], temps);
                    string child2 = getTempName(node.getChildren()[1], temps);
                    if (derivOrder[0] == 0 && derivOrder[1] == 0)
                        out << nodeNames[j] << " = cross(" << child1 << ", " << child2 << ");\n";
                    else
                        throw OpenMMException("Unsupported derivative order for cross()");
                }
            }
            else if (node.getOperation().getName() == "vector") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                        out << nodeNames[j] << ".x = " << getTempName(node.getChildren()[0], temps) << ".x;\n";
                        out << nodeNames[j] << ".y = " << getTempName(node.getChildren()[1], temps) << ".y;\n";
                        out << nodeNames[j] << ".z = " << getTempName(node.getChildren()[2], temps) << ".z;\n";
                    }
                    else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0)
                        out << nodeNames[j] << ".x = 1;\n";
                    else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0)
                        out << nodeNames[j] << ".y = 1;\n";
                    else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1)
                        out << nodeNames[j] << ".z = 1;\n";
                }
            }
            else if (node.getOperation().getName() == "_x") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".x);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _x()");
                }
            }
            else if (node.getOperation().getName() == "_y") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".y);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _y()");
                }
            }
            else if (node.getOperation().getName() == "_z") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".z);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _z()");
                }
            }
216
217
            else {
                // This is a tabulated function.
218

219
220
221
222
223
224
225
226
227
228
                int i;
                for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
                    ;
                if (i == functionNames.size())
                    throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
                vector<string> paramsFloat, paramsInt;
                for (int j = 0; j < (int) functionParams[i].size(); j++) {
                    paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
                    paramsInt.push_back(context.intToString((int) functionParams[i][j]));
                }
229
230
231
232
233
                vector<string> suffixes;
                if (isVecType) {
                    suffixes.push_back(".x");
                    suffixes.push_back(".y");
                    suffixes.push_back(".z");
234
                }
235
236
                else {
                    suffixes.push_back("");
237
                }
238
239
240
                for (auto& suffix : suffixes) {
                    out << "{\n";
                    if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
241
                        int periodic = functionParams[i][4];
242
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
243
244
245
246
247
248
249
250
                        if (periodic) {
                            out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[5]<< ";\n";
                            out << "x = (x - floor(x))*" << paramsFloat[6] << ";\n";
                        }
                        else {
                            out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
                            out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
                        }
251
252
                        out << "int index = (int) (floor(x));\n";
                        out << "index = min(index, (int) " << paramsInt[3] << ");\n";
253
254
255
256
257
258
259
260
261
                        out << "float4 coeff = " << functionNames[i].second << "[index];\n";
                        out << "real b = x-index;\n";
                        out << "real a = 1.0f-b;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0)
                                out << nodeNames[j] << suffix << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
                            else
                                out << nodeNames[j] << suffix << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
262
                        }
263
264
265
                        if (!periodic)
                            out << "}\n";
                      }
266
                    else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
267
                        int periodic = functionParams[i][8];
268
269
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
270
271
272
273
274
275
276
277
278
279
280
                        if (periodic) {
                            out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[9] << ";\n";
                            out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[10] << ";\n";
                            out << "x = (x - floor(x))*" << paramsFloat[0] << ";\n";
                            out << "y = (y - floor(y))*" << paramsFloat[1] << ";\n";
                        }
                        else {
                            out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
                            out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
                            out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
                        }
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
                        out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
                        out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
                        out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
                        out << "float4 c[4];\n";
                        for (int j = 0; j < 4; j++)
                            out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
                        out << "real da = x-s;\n";
                        out << "real db = y-t;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0) {
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n";
                            }
                            else if (derivOrder[0] == 1 && derivOrder[1] == 0) {
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[6] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[7] << ";\n";
                            }
                            else
                                throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
313
                        }
314
315
                        if (!periodic)
                            out << "}\n";
316
317
                    }
                    else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
318
                        int periodic = functionParams[i][12];
319
320
321
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
                        out << "real z = " << getTempName(node.getChildren()[2], temps) << suffix << ";\n";
322
323
324
325
326
327
328
329
330
331
332
333
334
335
                        if (periodic) {
                            out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[13] << ";\n";
                            out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[14] << ";\n";
                            out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[15] << ";\n";
                            out << "x = (x - floor(x))*" << paramsFloat[0] << ";\n";
                            out << "y = (y - floor(y))*" << paramsFloat[1] << ";\n";
                            out << "z = (z - floor(z))*" << paramsFloat[2] << ";\n";
                        }
                        else {
                            out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
                            out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
                            out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
                            out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
                        }
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
                        out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
                        out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
                        out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n";
                        out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
                        out << "float4 c[16];\n";
                        for (int j = 0; j < 16; j++)
                            out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
                        out << "real da = x-s;\n";
                        out << "real db = y-t;\n";
                        out << "real dc = z-u;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "real value[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
                                    }
                                out << nodeNames[j] << suffix << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n";
                            }
                            else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "real derivx[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[9] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) {
                                const string suffixes[] = {".x", ".y", ".z", ".w"};
                                out << "real derivy[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = 4*m;
                                        string suffix = suffixes[k];
                                        out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[10] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) {
                                out << "real derivz[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[11] << ";\n";
                            }
                            else
                                throw OpenMMException("Unsupported derivative order for Continuous3DFunction");
391
                        }
392
393
                        if (!periodic)
                            out << "}\n";
394
                    }
395
396
397
398
399
400
401
402
403
404
                    else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0) {
                                out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                                out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
                                out << "int index = (int) floor(x+0.5f);\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                                out << "}\n";
                            }
405
                        }
406
                    }
407
408
409
410
411
412
413
414
415
416
417
418
                    else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0) {
                                out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << suffix << "+0.5f);\n";
                                out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << suffix << "+0.5f);\n";
                                out << "int xsize = (int) " << paramsInt[0] << ";\n";
                                out << "int ysize = (int) " << paramsInt[1] << ";\n";
                                out << "int index = x+y*xsize;\n";
                                out << "if (index >= 0 && index < xsize*ysize)\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                            }
419
                        }
420
                    }
421
422
423
424
425
426
427
428
429
430
431
432
433
434
                    else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << suffix << "+0.5f);\n";
                                out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << suffix << "+0.5f);\n";
                                out << "int z = (int) floor(" << getTempName(node.getChildren()[2], temps) << suffix << "+0.5f);\n";
                                out << "int xsize = (int) " << paramsInt[0] << ";\n";
                                out << "int ysize = (int) " << paramsInt[1] << ";\n";
                                out << "int zsize = (int) " << paramsInt[2] << ";\n";
                                out << "int index = x+(y+z*ysize)*xsize;\n";
                                out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                            }
435
                        }
436
                    }
437
                    out << "}\n";
peastman's avatar
peastman committed
438
439
                }
            }
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
            out << "}";
            break;
        }
        case Operation::ADD:
            out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::SUBTRACT:
            out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::MULTIPLY:
            out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::DIVIDE:
        {
            bool haveReciprocal = false;
455
456
457
458
459
460
461
462
463
464
465
466
467
            if (node.getChildren()[1].getOperation().getId() == Operation::RECIPROCAL) {
                for (int i = 0; i < (int) temps.size(); i++)
                    if (temps[i].first == node.getChildren()[1].getChildren()[1]) {
                        haveReciprocal = true;
                        out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
                    }
            }
            if (!haveReciprocal)
                for (int i = 0; i < (int) temps.size(); i++)
                    if (temps[i].first.getOperation().getId() == Operation::RECIPROCAL && temps[i].first.getChildren()[0] == node.getChildren()[1]) {
                        haveReciprocal = true;
                        out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
                    }
468
469
470
471
472
            if (!haveReciprocal)
                out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
            break;
        }
        case Operation::POWER:
473
            out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
474
475
476
477
478
            break;
        case Operation::NEGATE:
            out << "-" << getTempName(node.getChildren()[0], temps);
            break;
        case Operation::SQRT:
479
            callFunction(out, "sqrtf", "sqrt", getTempName(node.getChildren()[0], temps), tempType);
480
481
            break;
        case Operation::EXP:
482
            callFunction(out, "expf", "exp", getTempName(node.getChildren()[0], temps), tempType);
483
484
            break;
        case Operation::LOG:
485
            callFunction(out, "logf", "log", getTempName(node.getChildren()[0], temps), tempType);
486
487
            break;
        case Operation::SIN:
488
            callFunction(out, "sinf", "sin", getTempName(node.getChildren()[0], temps), tempType);
489
490
            break;
        case Operation::COS:
491
            callFunction(out, "cosf", "cos", getTempName(node.getChildren()[0], temps), tempType);
492
493
            break;
        case Operation::SEC:
494
495
            out << "1/";
            callFunction(out, "cosf", "cos", getTempName(node.getChildren()[0], temps), tempType);
496
497
            break;
        case Operation::CSC:
498
499
            out << "1/";
            callFunction(out, "sinf", "sin", getTempName(node.getChildren()[0], temps), tempType);
500
501
            break;
        case Operation::TAN:
502
            callFunction(out, "tanf", "tan", getTempName(node.getChildren()[0], temps), tempType);
503
504
            break;
        case Operation::COT:
505
506
            out << "1/";
            callFunction(out, "tanf", "tan", getTempName(node.getChildren()[0], temps), tempType);
507
508
            break;
        case Operation::ASIN:
509
            callFunction(out, "asinf", "asin", getTempName(node.getChildren()[0], temps), tempType);
510
511
            break;
        case Operation::ACOS:
512
            callFunction(out, "acosf", "acos", getTempName(node.getChildren()[0], temps), tempType);
513
514
            break;
        case Operation::ATAN:
515
            callFunction(out, "atanf", "atan", getTempName(node.getChildren()[0], temps), tempType);
516
            break;
517
        case Operation::ATAN2:
Peter Eastman's avatar
Peter Eastman committed
518
            callFunction2(out, "atan2f", "atan2", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
519
            break;
520
        case Operation::SINH:
521
            callFunction(out, "sinh", "sinh", getTempName(node.getChildren()[0], temps), tempType);
522
523
            break;
        case Operation::COSH:
524
            callFunction(out, "cosh", "cosh", getTempName(node.getChildren()[0], temps), tempType);
525
526
            break;
        case Operation::TANH:
527
            callFunction(out, "tanh", "tanh", getTempName(node.getChildren()[0], temps), tempType);
528
529
            break;
        case Operation::ERF:
530
            callFunction(out, "erf", "erf", getTempName(node.getChildren()[0], temps), tempType);
531
532
            break;
        case Operation::ERFC:
533
            callFunction(out, "erfc", "erfc", getTempName(node.getChildren()[0], temps), tempType);
534
535
            break;
        case Operation::STEP:
536
537
538
539
540
541
542
543
544
545
546
547
548
        {
            string compareVal = getTempName(node.getChildren()[0], temps);
            if (isVecType) {
                out << "make_" << tempType << "(0);\n";
                out << "{\n";
                out << tempType<<" tempCompareValue = " << compareVal << ";\n";
                out << name << ".x = (tempCompareValue.x >= 0 ? 1 : 0);\n";
                out << name << ".y = (tempCompareValue.y >= 0 ? 1 : 0);\n";
                out << name << ".z = (tempCompareValue.z >= 0 ? 1 : 0);\n";
                out << "}\n";
            }
            else
                out << compareVal << " >= 0 ? 1 : 0";
549
            break;
550
        }
551
        case Operation::DELTA:
552
553
554
555
556
557
558
559
560
561
562
563
564
        {
            string compareVal = getTempName(node.getChildren()[0], temps);
            if (isVecType) {
                out << "make_" << tempType << "(0);\n";
                out << "{\n";
                out << tempType<<" tempCompareValue = " << compareVal << ";\n";
                out << name << ".x = (tempCompareValue.x == 0 ? 1 : 0);\n";
                out << name << ".y = (tempCompareValue.y == 0 ? 1 : 0);\n";
                out << name << ".z = (tempCompareValue.z == 0 ? 1 : 0);\n";
                out << "}\n";
            }
            else
                out << compareVal << " == 0 ? 1 : 0";
565
            break;
566
        }
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
        case Operation::SQUARE:
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg;
            break;
        }
        case Operation::CUBE:
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg << "*" << arg;
            break;
        }
        case Operation::RECIPROCAL:
            out << "RECIP(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ADD_CONSTANT:
583
584
585
586
587
588
589
590
591
592
            if (isVecType) {
                string val = context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue());
                string arg = getTempName(node.getChildren()[0], temps);
                out << "make_" << tempType << "(";
                out << val << "+" << arg << ".x, ";
                out << val << "+" << arg << ".y, ";
                out << val << "+" << arg << ".z)";
            }
            else
                out << context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
593
594
            break;
        case Operation::MULTIPLY_CONSTANT:
595
            out << context.doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
            break;
        case Operation::POWER_CONSTANT:
        {
            double exponent = dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue();
            if (exponent == 0.0)
                out << "1.0f";
            else if (exponent == (int) exponent) {
                out << "0.0f;\n";
                temps.push_back(make_pair(node, name));
                hasRecordedNode = true;

                // If multiple integral powers of the same base are needed, it's faster to calculate all of them
                // at once, so check to see if others are also needed.

                map<int, const ExpressionTreeNode*> powers;
                powers[(int) exponent] = &node;
                for (int j = 0; j < (int) allExpressions.size(); j++)
                    findRelatedPowers(node, allExpressions[j].getRootNode(), powers);
                vector<int> exponents;
                vector<string> names;
                vector<bool> hasAssigned(powers.size(), false);
                exponents.push_back((int) fabs(exponent));
                names.push_back(name);
peastman's avatar
peastman committed
619
620
621
                for (auto& power : powers) {
                    if (power.first != exponent) {
                        exponents.push_back(power.first >= 0 ? power.first : -power.first);
622
                        string name2 = prefix+context.intToString(temps.size());
623
                        names.push_back(name2);
peastman's avatar
peastman committed
624
                        temps.push_back(make_pair(*power.second, name2));
625
626
627
628
                        out << tempType << " " << name2 << " = 0.0f;\n";
                    }
                }
                out << "{\n";
629
                out << "real multiplier = " << (exponent < 0.0 ? "RECIP(" : "(") << getTempName(node.getChildren()[0], temps) << ");\n";
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
                bool done = false;
                while (!done) {
                    done = true;
                    for (int i = 0; i < (int) exponents.size(); i++) {
                        if (exponents[i]%2 == 1) {
                            if (!hasAssigned[i])
                                out << names[i] << " = multiplier;\n";
                            else
                                out << names[i] << " *= multiplier;\n";
                            hasAssigned[i] = true;
                        }
                        exponents[i] >>= 1;
                        if (exponents[i] != 0)
                            done = false;
                    }
                    if (!done)
                        out << "multiplier *= multiplier;\n";
                }
                out << "}";
            }
            else
651
                out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << context.doubleToString(exponent) << ")";
652
653
654
            break;
        }
        case Operation::MIN:
655
            callFunction2(out, "min", "min", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
656
657
            break;
        case Operation::MAX:
658
            callFunction2(out, "max", "max", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
659
660
            break;
        case Operation::ABS:
661
            callFunction(out, "fabs", "fabs", getTempName(node.getChildren()[0], temps), tempType);
662
            break;
663
        case Operation::FLOOR:
664
            callFunction(out, "floor", "floor", getTempName(node.getChildren()[0], temps), tempType);
665
666
            break;
        case Operation::CEIL:
667
            callFunction(out, "ceil", "ceil", getTempName(node.getChildren()[0], temps), tempType);
668
            break;
669
        case Operation::SELECT:
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
        {
            string compareVal = getTempName(node.getChildren()[0], temps);
            string val1 = getTempName(node.getChildren()[1], temps);
            string val2 = getTempName(node.getChildren()[2], temps);
            if (isVecType) {
                out << "make_" << tempType << "(0);\n";
                out << "{\n";
                out << tempType<<" tempCompareValue = " << compareVal << ";\n";
                out << name << ".x = (tempCompareValue.x != 0 ? " << val1 << ".x : " << val2 << ".x);\n";
                out << name << ".y = (tempCompareValue.y != 0 ? " << val1 << ".y : " << val2 << ".y);\n";
                out << name << ".z = (tempCompareValue.z != 0 ? " << val1 << ".z : " << val2 << ".z);\n";
                out << "}\n";
            }
            else
                out << "(" << compareVal << " != 0 ? " << val1 << " : " << val2 << ")";
685
            break;
686
        }
687
688
689
690
691
692
693
694
        default:
            throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
    }
    out << ";\n";
    if (!hasRecordedNode)
        temps.push_back(make_pair(node, name));
}

695
string ExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
    // Look up the generated variable name for a node whose value has already been emitted.
    // Entries are scanned in order, so the earliest matching temporary is the one returned.
    for (const pair<ExpressionTreeNode, string>& temp : temps) {
        if (temp.first == node)
            return temp.second;
    }

    // Code generation requested a node that was never recorded: this is a logic error.
    stringstream message;
    message << "Internal error: No temporary variable for expression node: " << node;
    throw OpenMMException(message.str());
}

704
void ExpressionUtilities::findRelatedCustomFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
705
            vector<const Lepton::ExpressionTreeNode*>& nodes) {
706
707
    if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
        // Make sure the arguments are identical.
708

709
710
711
        for (int i = 0; i < (int) node.getChildren().size(); i++)
            if (node.getChildren()[i] != searchNode.getChildren()[i])
                return;
712

713
        // See if we already have an identical node.
714

715
716
717
        for (int i = 0; i < (int) nodes.size(); i++)
            if (*nodes[i] == searchNode)
                return;
718

719
        // Add the node.
720

721
        nodes.push_back(&searchNode);
722
    }
723
724
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
725
            findRelatedCustomFunctions(node, searchNode.getChildren()[i], nodes);
726
727
}

728
void ExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
    if (searchNode.getOperation().getId() == Operation::POWER_CONSTANT && node.getChildren()[0] == searchNode.getChildren()[0]) {
        double realPower = dynamic_cast<const Operation::PowerConstant*>(&searchNode.getOperation())->getValue();
        int power = (int) realPower;
        if (power != realPower)
            return; // We are only interested in integer powers.
        if (powers.find(power) != powers.end())
            return; // This power is already in the map.
        if (powers.begin()->first*power < 0)
            return; // All powers must have the same sign.
        powers[power] = &searchNode;
    }
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
            findRelatedPowers(node, searchNode.getChildren()[i], powers);
}

745
vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
    // Continuous functions are converted to spline coefficients and reported with width 4.
    // Discrete functions are copied value-for-value and reported with width 1.

    if (const Continuous1DFunction* cont1 = dynamic_cast<const Continuous1DFunction*>(&function)) {
        vector<double> values;
        double min, max;
        cont1->getFunctionParameters(values, min, max);
        int numValues = values.size();
        int numIntervals = numValues-1;

        // Evenly spaced knots spanning [min, max]. createSpline fills derivs with one value per
        // knot (presumably second derivatives; confirm against SplineFitter).
        vector<double> x(numValues), derivs;
        for (int i = 0; i < numValues; i++)
            x[i] = min+i*(max-min)/numIntervals;
        SplineFitter::createSpline(x, values, cont1->getPeriodic(), derivs);

        // Four floats per interval: the two endpoint values, then both endpoint derivs divided by 6.
        vector<float> coeff(4*numIntervals);
        for (int interval = 0; interval < numIntervals; interval++) {
            int base = 4*interval;
            coeff[base] = (float) values[interval];
            coeff[base+1] = (float) values[interval+1];
            coeff[base+2] = (float) (derivs[interval]/6.0);
            coeff[base+3] = (float) (derivs[interval+1]/6.0);
        }
        width = 4;
        return coeff;
    }

    if (const Continuous2DFunction* cont2 = dynamic_cast<const Continuous2DFunction*>(&function)) {
        vector<double> values;
        int xsize, ysize;
        double xmin, xmax, ymin, ymax;
        cont2->getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
        vector<double> x(xsize), y(ysize);
        for (int i = 0; i < xsize; i++)
            x[i] = xmin+i*(xmax-xmin)/(xsize-1);
        for (int i = 0; i < ysize; i++)
            y[i] = ymin+i*(ymax-ymin)/(ysize-1);
        vector<vector<double> > cells;
        SplineFitter::create2DSpline(x, y, values, cont2->getPeriodic(), cells);

        // 16 coefficients per cell, flattened in the order the cells were produced.
        vector<float> coeff(16*cells.size());
        int next = 0;
        for (const vector<double>& cell : cells)
            for (int j = 0; j < 16; j++)
                coeff[next++] = (float) cell[j];
        width = 4;
        return coeff;
    }

    if (const Continuous3DFunction* cont3 = dynamic_cast<const Continuous3DFunction*>(&function)) {
        vector<double> values;
        int xsize, ysize, zsize;
        double xmin, xmax, ymin, ymax, zmin, zmax;
        cont3->getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
        vector<double> x(xsize), y(ysize), z(zsize);
        for (int i = 0; i < xsize; i++)
            x[i] = xmin+i*(xmax-xmin)/(xsize-1);
        for (int i = 0; i < ysize; i++)
            y[i] = ymin+i*(ymax-ymin)/(ysize-1);
        for (int i = 0; i < zsize; i++)
            z[i] = zmin+i*(zmax-zmin)/(zsize-1);
        vector<vector<double> > cells;
        SplineFitter::create3DSpline(x, y, z, values, cont3->getPeriodic(), cells);

        // 64 coefficients per cell, flattened in the order the cells were produced.
        vector<float> coeff(64*cells.size());
        int next = 0;
        for (const vector<double>& cell : cells)
            for (int j = 0; j < 64; j++)
                coeff[next++] = (float) cell[j];
        width = 4;
        return coeff;
    }

    // Discrete functions: only the tabulated values are needed. The grid dimensions reported
    // by the 2D/3D variants are not used here.
    vector<double> tabulated;
    if (const Discrete1DFunction* disc1 = dynamic_cast<const Discrete1DFunction*>(&function))
        disc1->getFunctionParameters(tabulated);
    else if (const Discrete2DFunction* disc2 = dynamic_cast<const Discrete2DFunction*>(&function)) {
        int xsize, ysize;
        disc2->getFunctionParameters(xsize, ysize, tabulated);
    }
    else if (const Discrete3DFunction* disc3 = dynamic_cast<const Discrete3DFunction*>(&function)) {
        int xsize, ysize, zsize;
        disc3->getFunctionParameters(xsize, ysize, zsize, tabulated);
    }
    else
        throw OpenMMException("computeFunctionCoefficients: Unknown function type");
    width = 1;
    return vector<float>(tabulated.begin(), tabulated.end());
}

863
vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
    // Build, for each tabulated function in order, the list of scalar parameters describing it.
    // Entries are consumed by position, so the push order below must not change.

    // Layout shared by the 2D and 3D continuous functions: per-axis (size-1), per-axis
    // (min, max), per-axis (size-1)/(max-min), the periodic flag, then per-axis 1/(max-min).
    auto appendGridParameters = [](vector<double>& p, int dims, const int* sizes, const double* lo, const double* hi, bool periodic) {
        for (int d = 0; d < dims; d++)
            p.push_back(sizes[d]-1);
        for (int d = 0; d < dims; d++) {
            p.push_back(lo[d]);
            p.push_back(hi[d]);
        }
        for (int d = 0; d < dims; d++)
            p.push_back((sizes[d]-1)/(hi[d]-lo[d]));
        p.push_back((int) periodic);
        for (int d = 0; d < dims; d++)
            p.push_back(1.0/(hi[d]-lo[d]));
    };

    vector<vector<double> > params;
    params.reserve(functions.size());
    for (const TabulatedFunction* function : functions) {
        params.push_back(vector<double>());
        vector<double>& p = params.back();
        if (const Continuous1DFunction* cont1 = dynamic_cast<const Continuous1DFunction*>(function)) {
            vector<double> values;
            double min, max;
            cont1->getFunctionParameters(values, min, max);
            int periodic = (int) cont1->getPeriodic();
            p.push_back(min);
            p.push_back(max);
            p.push_back((values.size()-1)/(max-min));
            p.push_back(values.size()-2);
            p.push_back(periodic);
            p.push_back(1.0/(max-min));
            p.push_back(values.size()-1);
        }
        else if (const Continuous2DFunction* cont2 = dynamic_cast<const Continuous2DFunction*>(function)) {
            vector<double> values;
            int sizes[2];
            double lo[2], hi[2];
            cont2->getFunctionParameters(sizes[0], sizes[1], values, lo[0], hi[0], lo[1], hi[1]);
            appendGridParameters(p, 2, sizes, lo, hi, cont2->getPeriodic());
        }
        else if (const Continuous3DFunction* cont3 = dynamic_cast<const Continuous3DFunction*>(function)) {
            vector<double> values;
            int sizes[3];
            double lo[3], hi[3];
            cont3->getFunctionParameters(sizes[0], sizes[1], sizes[2], values, lo[0], hi[0], lo[1], hi[1], lo[2], hi[2]);
            appendGridParameters(p, 3, sizes, lo, hi, cont3->getPeriodic());
        }
        else if (const Discrete1DFunction* disc1 = dynamic_cast<const Discrete1DFunction*>(function)) {
            vector<double> values;
            disc1->getFunctionParameters(values);
            p.push_back(values.size());
        }
        else if (const Discrete2DFunction* disc2 = dynamic_cast<const Discrete2DFunction*>(function)) {
            int xsize, ysize;
            vector<double> values;
            disc2->getFunctionParameters(xsize, ysize, values);
            p.push_back(xsize);
            p.push_back(ysize);
        }
        else if (const Discrete3DFunction* disc3 = dynamic_cast<const Discrete3DFunction*>(function)) {
            int xsize, ysize, zsize;
            vector<double> values;
            disc3->getFunctionParameters(xsize, ysize, zsize, values);
            p.push_back(xsize);
            p.push_back(ysize);
            p.push_back(zsize);
        }
        else
            throw OpenMMException("computeFunctionParameters: Unknown function type");
    }
    return params;
}
951

952
Lepton::CustomFunction* ExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
    // Placeholders are shared by dimensionality: a 1D function gets fp1, 2D gets fp2 and 3D
    // gets fp3, whether the function is continuous or discrete.
    const TabulatedFunction* fn = &function;
    if (dynamic_cast<const Continuous1DFunction*>(fn) != NULL || dynamic_cast<const Discrete1DFunction*>(fn) != NULL)
        return &fp1;
    if (dynamic_cast<const Continuous2DFunction*>(fn) != NULL || dynamic_cast<const Discrete2DFunction*>(fn) != NULL)
        return &fp2;
    if (dynamic_cast<const Continuous3DFunction*>(fn) != NULL || dynamic_cast<const Discrete3DFunction*>(fn) != NULL)
        return &fp3;
    throw OpenMMException("getFunctionPlaceholder: Unknown function type");
}
967

968
Lepton::CustomFunction* ExpressionUtilities::getPeriodicDistancePlaceholder() {
    // Every call hands back a pointer to the same periodicDistance placeholder object.
    Lepton::CustomFunction* placeholder = &periodicDistance;
    return placeholder;
}
971

972
void ExpressionUtilities::callFunction(stringstream& out, string singleFn, string doubleFn, const string& arg, const string& tempType) {
    // Emit a call to a one-argument math function. A type name starting with 'd' selects the
    // double-precision name. A type name ending in '3' is a vector type, so the function is
    // applied to each of .x/.y/.z and the results are packed with make_<tempType>.
    const string& fn = (tempType[0] == 'd' ? doubleFn : singleFn);
    if (tempType[tempType.size()-1] != '3') {
        out << fn << "(" << arg << ")";
        return;
    }
    out << "make_" << tempType << "(";
    for (int component = 0; component < 3; component++) {
        if (component > 0)
            out << ", ";
        out << fn << "(" << arg << "." << "xyz"[component] << ")";
    }
    out << ")";
}
981

982
void ExpressionUtilities::callFunction2(stringstream& out, string singleFn, string doubleFn, const string& arg1, const string& arg2, const string& tempType) {
    // Emit a call to a two-argument math function. A type name starting with 'd' selects the
    // double-precision name. For scalar types both arguments are cast to tempType. For vector
    // types (name ending in '3') the function is applied componentwise and packed with
    // make_<tempType>.
    const string& fn = (tempType[0] == 'd' ? doubleFn : singleFn);
    if (tempType[tempType.size()-1] != '3') {
        out << fn << "((" << tempType << ") " << arg1 << ", (" << tempType << ") " << arg2 << ")";
        return;
    }
    out << "make_" << tempType << "(";
    for (int component = 0; component < 3; component++) {
        const char axis = "xyz"[component];
        if (component > 0)
            out << ", ";
        out << fn << "(" << arg1 << "." << axis << ", " << arg2 << "." << axis << ")";
    }
    out << ")";
}