/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2019 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
#include "openmm/common/ExpressionUtilities.h"
28
29
30
31
32
33
34
35
#include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include "lepton/Operation.h"

using namespace OpenMM;
using namespace Lepton;
using namespace std;

36
/**
 * Create an ExpressionUtilities object tied to a ComputeContext.
 *
 * The context reference is stored and later used to format numbers
 * (intToString / doubleToString) while generating code.
 *
 * NOTE(review): fp1, fp2, fp3 and periodicDistance are presumably placeholder
 * functions whose constructor argument is their argument count (1, 2, 3 and 6);
 * their declarations are not visible here -- confirm against the header.
 *
 * @param context    the context for which expression code will be generated
 */
ExpressionUtilities::ExpressionUtilities(ComputeContext& context) : context(context), fp1(1), fp2(2), fp3(3), periodicDistance(6) {
}

39
/**
 * Generate source code that evaluates a set of expressions, with the variables
 * given as a name -> string map.
 *
 * Each map entry is converted into a (Variable node, string) pair and the work
 * is delegated to the overload that takes a vector of such pairs.
 *
 * @param expressions    the expressions to generate code for, forwarded unchanged
 * @param variables      maps each variable name to the string associated with it
 *                       (presumably the code used to refer to that variable --
 *                       confirm against getTempName())
 * @param functions      tabulated functions that may appear in the expressions
 * @param functionNames  forwarded unchanged to the other overload
 * @param prefix         prefix used when naming temporary variables
 * @param tempType       type name used when declaring temporary variables
 * @return the generated source code
 */
string ExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    vector<pair<ExpressionTreeNode, string> > variableNodes;
    variableNodes.reserve(variables.size());
    for (const auto& variable : variables) {
        // The new Variable operation is handed straight to the ExpressionTreeNode constructor.
        variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(variable.first)), variable.second));
    }
    return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
}

47
/**
 * Generate source code that evaluates a set of expressions, with the variables
 * given as (expression node, string) pairs.
 *
 * For every entry of `expressions`, code for the expression tree is written by
 * processExpression(), followed by one line made of the entry's key, the name
 * returned by getTempName() for the root node, and ";\n". The key is emitted
 * verbatim, so it must already contain whatever should precede the value
 * (presumably an assignment prefix -- confirm against callers).
 *
 * @param expressions    maps the text to emit before each result to the expression producing it
 * @param variables      nodes that are already available, with the string associated with each;
 *                       they seed the list of known temporaries
 * @param functions      tabulated functions that may appear in the expressions
 * @param functionNames  forwarded to processExpression()
 * @param prefix         prefix used when naming temporary variables
 * @param tempType       type name used when declaring temporary variables
 * @return the generated source code
 */
string ExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    stringstream out;

    // processExpression() receives the full set of expressions in addition to the node it is working on.
    vector<ParsedExpression> allExpressions;
    allExpressions.reserve(expressions.size());
    for (const auto& expression : expressions)
        allExpressions.push_back(expression.second);

    // Start from the caller's variables; processExpression() appends to this list as it goes.
    vector<pair<ExpressionTreeNode, string> > temps = variables;
    vector<vector<double> > functionParams = computeFunctionParameters(functions);

    for (const auto& expression : expressions) {
        const ExpressionTreeNode& root = expression.second.getRootNode();
        processExpression(out, root, temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
        out << expression.first << getTempName(root, temps) << ";\n";
    }
    return out.str();
}

62
void ExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
63
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
peastman's avatar
peastman committed
64
        const vector<ParsedExpression>& allExpressions, const string& tempType) {
65
66
67
68
    for (int i = 0; i < (int) temps.size(); i++)
        if (temps[i].first == node)
            return;
    for (int i = 0; i < (int) node.getChildren().size(); i++)
peastman's avatar
peastman committed
69
        processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
70
    string name = prefix+context.intToString(temps.size());
71
    bool hasRecordedNode = false;
72
    bool isVecType = (tempType[tempType.size()-1] == '3');
73

74
75
76
    out << tempType << " " << name << " = ";
    switch (node.getOperation().getId()) {
        case Operation::CONSTANT:
77
78
79
80
81
82
        {
            string value = context.doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
            if (isVecType)
                out << "make_" << tempType << "(" << value << ")";
            else
                out << value;
83
            break;
84
        }
85
86
87
88
        case Operation::VARIABLE:
            throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
        case Operation::CUSTOM:
        {
89
90
91
92
            if (isVecType)
                out << "make_" << tempType << "(0);\n";
            else
                out << "0;\n";
93
94
95
96
97
98
            temps.push_back(make_pair(node, name));
            hasRecordedNode = true;

            // If both the value and derivative of the function are needed, it's faster to calculate them both
            // at once, so check to see if both are needed.

99
            vector<const ExpressionTreeNode*> nodes;
100
            nodes.push_back(&node);
101
            for (int j = 0; j < (int) allExpressions.size(); j++)
102
                findRelatedCustomFunctions(node, allExpressions[j].getRootNode(), nodes);
103
104
105
            vector<string> nodeNames;
            nodeNames.push_back(name);
            for (int j = 1; j < (int) nodes.size(); j++) {
106
                string name2 = prefix+context.intToString(temps.size());
107
                out << tempType << " " << name2 << " = 0.0f;\n";
108
109
                nodeNames.push_back(name2);
                temps.push_back(make_pair(*nodes[j], name2));
110
111
            }
            out << "{\n";
112
113
114
            if (node.getOperation().getName() == "periodicdistance") {
                // This is the periodicdistance() function.

115
                out << tempType << "3 periodicDistance_delta = make_real3(";
116
117
118
119
                for (int i = 0; i < 3; i++) {
                    if (i > 0)
                        out << ", ";
                    out << getTempName(node.getChildren()[i], temps) << "-" << getTempName(node.getChildren()[i+3], temps);
120
                }
121
                out << ");\n";
122
                out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
123
124
                out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n";
                out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n";
125
126
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
127
128
129
130
131
132
133
                    int argIndex = -1;
                    for (int k = 0; k < 6; k++) {
                        if (derivOrder[k] > 0) {
                            if (derivOrder[k] > 1 || argIndex != -1)
                                throw OpenMMException("Unsupported derivative of periodicdistance"); // Should be impossible for this to happen.
                            argIndex = k;
                        }
134
                    }
135
                    if (argIndex == -1)
136
                        out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
137
                    else if (argIndex == 0)
138
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
139
                    else if (argIndex == 1)
140
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
141
                    else if (argIndex == 2)
142
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
143
                    else if (argIndex == 3)
144
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
145
                    else if (argIndex == 4)
146
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
147
                    else if (argIndex == 5)
148
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
149
150
                }
            }
151
152
153
154
155
156
157
158
            else if (node.getOperation().getName() == "dot") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    string child1 = getTempName(node.getChildren()[0], temps);
                    string child2 = getTempName(node.getChildren()[1], temps);
                    if (derivOrder[0] == 0 && derivOrder[1] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(dot(" << child1 << ", " << child2 << "));\n";
                    else
159
                        throw OpenMMException("Unsupported derivative order for dot()");
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
                }
            }
            else if (node.getOperation().getName() == "cross") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    string child1 = getTempName(node.getChildren()[0], temps);
                    string child2 = getTempName(node.getChildren()[1], temps);
                    if (derivOrder[0] == 0 && derivOrder[1] == 0)
                        out << nodeNames[j] << " = cross(" << child1 << ", " << child2 << ");\n";
                    else
                        throw OpenMMException("Unsupported derivative order for cross()");
                }
            }
            else if (node.getOperation().getName() == "vector") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                        out << nodeNames[j] << ".x = " << getTempName(node.getChildren()[0], temps) << ".x;\n";
                        out << nodeNames[j] << ".y = " << getTempName(node.getChildren()[1], temps) << ".y;\n";
                        out << nodeNames[j] << ".z = " << getTempName(node.getChildren()[2], temps) << ".z;\n";
                    }
                    else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0)
                        out << nodeNames[j] << ".x = 1;\n";
                    else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0)
                        out << nodeNames[j] << ".y = 1;\n";
                    else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1)
                        out << nodeNames[j] << ".z = 1;\n";
                }
            }
            else if (node.getOperation().getName() == "_x") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".x);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _x()");
                }
            }
            else if (node.getOperation().getName() == "_y") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".y);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _y()");
                }
            }
            else if (node.getOperation().getName() == "_z") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".z);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _z()");
                }
            }
216
217
            else {
                // This is a tabulated function.
218

219
220
221
222
223
224
225
226
227
228
                int i;
                for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
                    ;
                if (i == functionNames.size())
                    throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
                vector<string> paramsFloat, paramsInt;
                for (int j = 0; j < (int) functionParams[i].size(); j++) {
                    paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
                    paramsInt.push_back(context.intToString((int) functionParams[i][j]));
                }
229
230
231
232
233
                vector<string> suffixes;
                if (isVecType) {
                    suffixes.push_back(".x");
                    suffixes.push_back(".y");
                    suffixes.push_back(".z");
234
                }
235
236
                else {
                    suffixes.push_back("");
237
                }
238
239
240
                for (auto& suffix : suffixes) {
                    out << "{\n";
                    if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
241
242
243
244
245
246
247
248
249
250
251
252
253
254
                        int periodic = functionParams[i][4];
                        if (periodic) {
                            out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                            out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[5]<< ";\n";
                            out << "x = (x - floor(x))*" << paramsFloat[6] << ";\n";
                            out << "int index = (int) (floor(x));\n";
                        }
                        else {
                            out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                            out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
                            out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
                            out << "int index = (int) (floor(x));\n";
                            out << "index = min(index, (int) " << paramsInt[3] << ");\n";
                        }
255
256
257
258
259
260
261
262
263
                        out << "float4 coeff = " << functionNames[i].second << "[index];\n";
                        out << "real b = x-index;\n";
                        out << "real a = 1.0f-b;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0)
                                out << nodeNames[j] << suffix << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
                            else
                                out << nodeNames[j] << suffix << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
264
                        }
265
266
267
                        if (!periodic)
                            out << "}\n";
                      }
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
                    else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
                        out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
                        out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
                        out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
                        out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
                        out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
                        out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
                        out << "float4 c[4];\n";
                        for (int j = 0; j < 4; j++)
                            out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
                        out << "real da = x-s;\n";
                        out << "real db = y-t;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0) {
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n";
                            }
                            else if (derivOrder[0] == 1 && derivOrder[1] == 0) {
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[6] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[7] << ";\n";
                            }
                            else
                                throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
306
                        }
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
                        out << "}\n";
                    }
                    else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
                        out << "real z = " << getTempName(node.getChildren()[2], temps) << suffix << ";\n";
                        out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
                        out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
                        out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
                        out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
                        out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
                        out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
                        out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n";
                        out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
                        out << "float4 c[16];\n";
                        for (int j = 0; j < 16; j++)
                            out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
                        out << "real da = x-s;\n";
                        out << "real db = y-t;\n";
                        out << "real dc = z-u;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "real value[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
                                    }
                                out << nodeNames[j] << suffix << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n";
                            }
                            else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "real derivx[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[9] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) {
                                const string suffixes[] = {".x", ".y", ".z", ".w"};
                                out << "real derivy[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = 4*m;
                                        string suffix = suffixes[k];
                                        out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[10] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) {
                                out << "real derivz[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[11] << ";\n";
                            }
                            else
                                throw OpenMMException("Unsupported derivative order for Continuous3DFunction");
372
                        }
373
                        out << "}\n";
374
                    }
375
376
377
378
379
380
381
382
383
384
                    else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0) {
                                out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                                out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
                                out << "int index = (int) floor(x+0.5f);\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                                out << "}\n";
                            }
385
                        }
386
                    }
387
388
389
390
391
392
393
394
395
396
397
398
                    else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0) {
                                out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << suffix << "+0.5f);\n";
                                out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << suffix << "+0.5f);\n";
                                out << "int xsize = (int) " << paramsInt[0] << ";\n";
                                out << "int ysize = (int) " << paramsInt[1] << ";\n";
                                out << "int index = x+y*xsize;\n";
                                out << "if (index >= 0 && index < xsize*ysize)\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                            }
399
                        }
400
                    }
401
402
403
404
405
406
407
408
409
410
411
412
413
414
                    else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << suffix << "+0.5f);\n";
                                out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << suffix << "+0.5f);\n";
                                out << "int z = (int) floor(" << getTempName(node.getChildren()[2], temps) << suffix << "+0.5f);\n";
                                out << "int xsize = (int) " << paramsInt[0] << ";\n";
                                out << "int ysize = (int) " << paramsInt[1] << ";\n";
                                out << "int zsize = (int) " << paramsInt[2] << ";\n";
                                out << "int index = x+(y+z*ysize)*xsize;\n";
                                out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                            }
415
                        }
416
                    }
417
                    out << "}\n";
peastman's avatar
peastman committed
418
419
                }
            }
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
            out << "}";
            break;
        }
        case Operation::ADD:
            out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::SUBTRACT:
            out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::MULTIPLY:
            out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::DIVIDE:
        {
            bool haveReciprocal = false;
435
436
437
438
439
440
441
442
443
444
445
446
447
            if (node.getChildren()[1].getOperation().getId() == Operation::RECIPROCAL) {
                for (int i = 0; i < (int) temps.size(); i++)
                    if (temps[i].first == node.getChildren()[1].getChildren()[1]) {
                        haveReciprocal = true;
                        out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
                    }
            }
            if (!haveReciprocal)
                for (int i = 0; i < (int) temps.size(); i++)
                    if (temps[i].first.getOperation().getId() == Operation::RECIPROCAL && temps[i].first.getChildren()[0] == node.getChildren()[1]) {
                        haveReciprocal = true;
                        out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
                    }
448
449
450
451
452
            if (!haveReciprocal)
                out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
            break;
        }
        case Operation::POWER:
453
            out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
454
455
456
457
458
            break;
        case Operation::NEGATE:
            out << "-" << getTempName(node.getChildren()[0], temps);
            break;
        case Operation::SQRT:
459
            callFunction(out, "sqrtf", "sqrt", getTempName(node.getChildren()[0], temps), tempType);
460
461
            break;
        case Operation::EXP:
462
            callFunction(out, "expf", "exp", getTempName(node.getChildren()[0], temps), tempType);
463
464
            break;
        case Operation::LOG:
465
            callFunction(out, "logf", "log", getTempName(node.getChildren()[0], temps), tempType);
466
467
            break;
        case Operation::SIN:
468
            callFunction(out, "sinf", "sin", getTempName(node.getChildren()[0], temps), tempType);
469
470
            break;
        case Operation::COS:
471
            callFunction(out, "cosf", "cos", getTempName(node.getChildren()[0], temps), tempType);
472
473
            break;
        case Operation::SEC:
474
475
            out << "1/";
            callFunction(out, "cosf", "cos", getTempName(node.getChildren()[0], temps), tempType);
476
477
            break;
        case Operation::CSC:
478
479
            out << "1/";
            callFunction(out, "sinf", "sin", getTempName(node.getChildren()[0], temps), tempType);
480
481
            break;
        case Operation::TAN:
482
            callFunction(out, "tanf", "tan", getTempName(node.getChildren()[0], temps), tempType);
483
484
            break;
        case Operation::COT:
485
486
            out << "1/";
            callFunction(out, "tanf", "tan", getTempName(node.getChildren()[0], temps), tempType);
487
488
            break;
        case Operation::ASIN:
489
            callFunction(out, "asinf", "asin", getTempName(node.getChildren()[0], temps), tempType);
490
491
            break;
        case Operation::ACOS:
492
            callFunction(out, "acosf", "acos", getTempName(node.getChildren()[0], temps), tempType);
493
494
            break;
        case Operation::ATAN:
495
            callFunction(out, "atanf", "atan", getTempName(node.getChildren()[0], temps), tempType);
496
            break;
497
        case Operation::ATAN2:
Peter Eastman's avatar
Peter Eastman committed
498
            callFunction2(out, "atan2f", "atan2", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
499
            break;
500
        case Operation::SINH:
501
            callFunction(out, "sinh", "sinh", getTempName(node.getChildren()[0], temps), tempType);
502
503
            break;
        case Operation::COSH:
504
            callFunction(out, "cosh", "cosh", getTempName(node.getChildren()[0], temps), tempType);
505
506
            break;
        case Operation::TANH:
507
            callFunction(out, "tanh", "tanh", getTempName(node.getChildren()[0], temps), tempType);
508
509
            break;
        case Operation::ERF:
510
            callFunction(out, "erf", "erf", getTempName(node.getChildren()[0], temps), tempType);
511
512
            break;
        case Operation::ERFC:
513
            callFunction(out, "erfc", "erfc", getTempName(node.getChildren()[0], temps), tempType);
514
515
            break;
        case Operation::STEP:
516
517
518
519
520
521
522
523
524
525
526
527
528
        {
            string compareVal = getTempName(node.getChildren()[0], temps);
            if (isVecType) {
                out << "make_" << tempType << "(0);\n";
                out << "{\n";
                out << tempType<<" tempCompareValue = " << compareVal << ";\n";
                out << name << ".x = (tempCompareValue.x >= 0 ? 1 : 0);\n";
                out << name << ".y = (tempCompareValue.y >= 0 ? 1 : 0);\n";
                out << name << ".z = (tempCompareValue.z >= 0 ? 1 : 0);\n";
                out << "}\n";
            }
            else
                out << compareVal << " >= 0 ? 1 : 0";
529
            break;
530
        }
531
        case Operation::DELTA:
532
533
534
535
536
537
538
539
540
541
542
543
544
        {
            string compareVal = getTempName(node.getChildren()[0], temps);
            if (isVecType) {
                out << "make_" << tempType << "(0);\n";
                out << "{\n";
                out << tempType<<" tempCompareValue = " << compareVal << ";\n";
                out << name << ".x = (tempCompareValue.x == 0 ? 1 : 0);\n";
                out << name << ".y = (tempCompareValue.y == 0 ? 1 : 0);\n";
                out << name << ".z = (tempCompareValue.z == 0 ? 1 : 0);\n";
                out << "}\n";
            }
            else
                out << compareVal << " == 0 ? 1 : 0";
545
            break;
546
        }
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
        case Operation::SQUARE:
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg;
            break;
        }
        case Operation::CUBE:
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg << "*" << arg;
            break;
        }
        case Operation::RECIPROCAL:
            out << "RECIP(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ADD_CONSTANT:
563
564
565
566
567
568
569
570
571
572
            if (isVecType) {
                string val = context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue());
                string arg = getTempName(node.getChildren()[0], temps);
                out << "make_" << tempType << "(";
                out << val << "+" << arg << ".x, ";
                out << val << "+" << arg << ".y, ";
                out << val << "+" << arg << ".z)";
            }
            else
                out << context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
573
574
            break;
        case Operation::MULTIPLY_CONSTANT:
575
            out << context.doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
            break;
        case Operation::POWER_CONSTANT:
        {
            double exponent = dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue();
            if (exponent == 0.0)
                out << "1.0f";
            else if (exponent == (int) exponent) {
                out << "0.0f;\n";
                temps.push_back(make_pair(node, name));
                hasRecordedNode = true;

                // If multiple integral powers of the same base are needed, it's faster to calculate all of them
                // at once, so check to see if others are also needed.

                map<int, const ExpressionTreeNode*> powers;
                powers[(int) exponent] = &node;
                for (int j = 0; j < (int) allExpressions.size(); j++)
                    findRelatedPowers(node, allExpressions[j].getRootNode(), powers);
                vector<int> exponents;
                vector<string> names;
                vector<bool> hasAssigned(powers.size(), false);
                exponents.push_back((int) fabs(exponent));
                names.push_back(name);
peastman's avatar
peastman committed
599
600
601
                for (auto& power : powers) {
                    if (power.first != exponent) {
                        exponents.push_back(power.first >= 0 ? power.first : -power.first);
602
                        string name2 = prefix+context.intToString(temps.size());
603
                        names.push_back(name2);
peastman's avatar
peastman committed
604
                        temps.push_back(make_pair(*power.second, name2));
605
606
607
608
                        out << tempType << " " << name2 << " = 0.0f;\n";
                    }
                }
                out << "{\n";
609
                out << "real multiplier = " << (exponent < 0.0 ? "RECIP(" : "(") << getTempName(node.getChildren()[0], temps) << ");\n";
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
                bool done = false;
                while (!done) {
                    done = true;
                    for (int i = 0; i < (int) exponents.size(); i++) {
                        if (exponents[i]%2 == 1) {
                            if (!hasAssigned[i])
                                out << names[i] << " = multiplier;\n";
                            else
                                out << names[i] << " *= multiplier;\n";
                            hasAssigned[i] = true;
                        }
                        exponents[i] >>= 1;
                        if (exponents[i] != 0)
                            done = false;
                    }
                    if (!done)
                        out << "multiplier *= multiplier;\n";
                }
                out << "}";
            }
            else
631
                out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << context.doubleToString(exponent) << ")";
632
633
634
            break;
        }
        case Operation::MIN:
635
            callFunction2(out, "min", "min", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
636
637
            break;
        case Operation::MAX:
638
            callFunction2(out, "max", "max", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
639
640
            break;
        case Operation::ABS:
641
            callFunction(out, "fabs", "fabs", getTempName(node.getChildren()[0], temps), tempType);
642
            break;
643
        case Operation::FLOOR:
644
            callFunction(out, "floor", "floor", getTempName(node.getChildren()[0], temps), tempType);
645
646
            break;
        case Operation::CEIL:
647
            callFunction(out, "ceil", "ceil", getTempName(node.getChildren()[0], temps), tempType);
648
            break;
649
        case Operation::SELECT:
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
        {
            string compareVal = getTempName(node.getChildren()[0], temps);
            string val1 = getTempName(node.getChildren()[1], temps);
            string val2 = getTempName(node.getChildren()[2], temps);
            if (isVecType) {
                out << "make_" << tempType << "(0);\n";
                out << "{\n";
                out << tempType<<" tempCompareValue = " << compareVal << ";\n";
                out << name << ".x = (tempCompareValue.x != 0 ? " << val1 << ".x : " << val2 << ".x);\n";
                out << name << ".y = (tempCompareValue.y != 0 ? " << val1 << ".y : " << val2 << ".y);\n";
                out << name << ".z = (tempCompareValue.z != 0 ? " << val1 << ".z : " << val2 << ".z);\n";
                out << "}\n";
            }
            else
                out << "(" << compareVal << " != 0 ? " << val1 << " : " << val2 << ")";
665
            break;
666
        }
667
668
669
670
671
672
673
674
        default:
            throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
    }
    out << ";\n";
    if (!hasRecordedNode)
        temps.push_back(make_pair(node, name));
}

675
/**
 * Look up the name of the temporary variable recorded for an expression node.
 *
 * @param node   the expression node whose variable name is wanted
 * @param temps  the (node, variable name) pairs recorded so far
 * @return the name stored with the first entry whose node compares equal to node
 * @throws OpenMMException if no entry in temps matches node
 */
string ExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
    for (const auto& entry : temps) {
        if (entry.first == node)
            return entry.second;
    }

    // Nothing matched: report the offending node in the error message.
    stringstream message;
    message << "Internal error: No temporary variable for expression node: " << node;
    throw OpenMMException(message.str());
}

684
void ExpressionUtilities::findRelatedCustomFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
685
            vector<const Lepton::ExpressionTreeNode*>& nodes) {
686
687
    if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
        // Make sure the arguments are identical.
688

689
690
691
        for (int i = 0; i < (int) node.getChildren().size(); i++)
            if (node.getChildren()[i] != searchNode.getChildren()[i])
                return;
692

693
        // See if we already have an identical node.
694

695
696
697
        for (int i = 0; i < (int) nodes.size(); i++)
            if (*nodes[i] == searchNode)
                return;
698

699
        // Add the node.
700

701
        nodes.push_back(&searchNode);
702
    }
703
704
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
705
            findRelatedCustomFunctions(node, searchNode.getChildren()[i], nodes);
706
707
}

708
void ExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
    if (searchNode.getOperation().getId() == Operation::POWER_CONSTANT && node.getChildren()[0] == searchNode.getChildren()[0]) {
        double realPower = dynamic_cast<const Operation::PowerConstant*>(&searchNode.getOperation())->getValue();
        int power = (int) realPower;
        if (power != realPower)
            return; // We are only interested in integer powers.
        if (powers.find(power) != powers.end())
            return; // This power is already in the map.
        if (powers.begin()->first*power < 0)
            return; // All powers must have the same sign.
        powers[power] = &searchNode;
    }
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
            findRelatedPowers(node, searchNode.getChildren()[i], powers);
}

725
/**
 * Build the single precision table that represents a tabulated function.
 *
 * For the continuous function types the table holds spline coefficients produced with
 * SplineFitter; for the discrete function types it holds the tabulated values themselves.
 *
 * @param function  the tabulated function to process
 * @param width     on exit, set to 4 for continuous functions and to 1 for discrete ones
 * @return the table of values, converted to float
 * @throws OpenMMException if function is not one of the six supported types
 */
vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
    // Evenly spaced grid of count points running from lower to upper inclusive.
    auto makeGrid = [](int count, double lower, double upper) -> vector<double> {
        vector<double> grid(count);
        for (int i = 0; i < count; i++)
            grid[i] = lower+i*(upper-lower)/(count-1);
        return grid;
    };

    // Concatenate the first perEntry coefficients of every entry into one float array.
    auto flatten = [](const vector<vector<double> >& coefficients, int perEntry) -> vector<float> {
        vector<float> flat(perEntry*coefficients.size());
        size_t next = 0;
        for (const vector<double>& entry : coefficients)
            for (int j = 0; j < perEntry; j++)
                flat[next++] = (float) entry[j];
        return flat;
    };

    // Element by element conversion from double to float.
    auto toFloats = [](const vector<double>& source) -> vector<float> {
        const int count = source.size();
        vector<float> result(count);
        for (int i = 0; i < count; i++)
            result[i] = (float) source[i];
        return result;
    };

    if (const Continuous1DFunction* fn = dynamic_cast<const Continuous1DFunction*>(&function)) {
        // Compute the spline coefficients.

        vector<double> values;
        double lower, upper;
        fn->getFunctionParameters(values, lower, upper);
        const bool periodic = fn->getPeriodic();
        const int numValues = values.size();
        vector<double> x = makeGrid(numValues, lower, upper);
        vector<double> derivs;
        if (periodic)
            SplineFitter::createPeriodicSpline(x, values, derivs);
        else
            SplineFitter::createNaturalSpline(x, values, derivs);

        // Each interval stores the values at its two ends followed by the two
        // corresponding spline derivatives divided by 6.

        const int numIntervals = numValues-1;
        vector<float> coefficients(4*numIntervals);
        for (int i = 0; i < numIntervals; i++) {
            const int base = 4*i;
            coefficients[base] = (float) values[i];
            coefficients[base+1] = (float) values[i+1];
            coefficients[base+2] = (float) (derivs[i]/6.0);
            coefficients[base+3] = (float) (derivs[i+1]/6.0);
        }
        width = 4;
        return coefficients;
    }
    if (const Continuous2DFunction* fn = dynamic_cast<const Continuous2DFunction*>(&function)) {
        // Compute the spline coefficients.

        vector<double> values;
        int xsize, ysize;
        double xmin, xmax, ymin, ymax;
        fn->getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
        vector<double> x = makeGrid(xsize, xmin, xmax);
        vector<double> y = makeGrid(ysize, ymin, ymax);
        vector<vector<double> > c;
        SplineFitter::create2DNaturalSpline(x, y, values, c);
        width = 4;
        return flatten(c, 16);
    }
    if (const Continuous3DFunction* fn = dynamic_cast<const Continuous3DFunction*>(&function)) {
        // Compute the spline coefficients.

        vector<double> values;
        int xsize, ysize, zsize;
        double xmin, xmax, ymin, ymax, zmin, zmax;
        fn->getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
        vector<double> x = makeGrid(xsize, xmin, xmax);
        vector<double> y = makeGrid(ysize, ymin, ymax);
        vector<double> z = makeGrid(zsize, zmin, zmax);
        vector<vector<double> > c;
        SplineFitter::create3DNaturalSpline(x, y, z, values, c);
        width = 4;
        return flatten(c, 64);
    }
    if (const Discrete1DFunction* fn = dynamic_cast<const Discrete1DFunction*>(&function)) {
        // Record the tabulated values.

        vector<double> values;
        fn->getFunctionParameters(values);
        width = 1;
        return toFloats(values);
    }
    if (const Discrete2DFunction* fn = dynamic_cast<const Discrete2DFunction*>(&function)) {
        // Record the tabulated values.

        int xsize, ysize;
        vector<double> values;
        fn->getFunctionParameters(xsize, ysize, values);
        width = 1;
        return toFloats(values);
    }
    if (const Discrete3DFunction* fn = dynamic_cast<const Discrete3DFunction*>(&function)) {
        // Record the tabulated values.

        int xsize, ysize, zsize;
        vector<double> values;
        fn->getFunctionParameters(xsize, ysize, zsize, values);
        width = 1;
        return toFloats(values);
    }
    throw OpenMMException("computeFunctionCoefficients: Unknown function type");
}

845
/**
 * Compute the list of numeric parameters that describes each tabulated function.
 *
 * The parameters recorded for each type, in order, are:
 *  - Continuous1DFunction: min, max, (n-1)/(max-min), n-2, periodic flag, 1/(max-min), n-1,
 *    where n is the number of tabulated values
 *  - Continuous2DFunction: xsize-1, ysize-1, xmin, xmax, ymin, ymax, and (size-1)/(max-min)
 *    for x and y
 *  - Continuous3DFunction: xsize-1, ysize-1, zsize-1, xmin, xmax, ymin, ymax, zmin, zmax, and
 *    (size-1)/(max-min) for x, y and z
 *  - Discrete1DFunction: the number of values
 *  - Discrete2DFunction: xsize, ysize
 *  - Discrete3DFunction: xsize, ysize, zsize
 *
 * @param functions  the tabulated functions to process
 * @return one parameter list per function, in the same order as functions
 * @throws OpenMMException if a function is not one of the six supported types
 */
vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
    vector<vector<double> > params(functions.size());
    for (size_t index = 0; index < functions.size(); index++) {
        const TabulatedFunction* function = functions[index];
        vector<double>& p = params[index];
        if (const Continuous1DFunction* fn = dynamic_cast<const Continuous1DFunction*>(function)) {
            vector<double> values;
            double min, max;
            fn->getFunctionParameters(values, min, max);
            p.push_back(min);
            p.push_back(max);
            p.push_back((values.size()-1)/(max-min));
            p.push_back(values.size()-2);
            p.push_back((int) fn->getPeriodic());
            p.push_back(1.0/(max-min));
            p.push_back(values.size()-1);
            continue;
        }
        if (const Continuous2DFunction* fn = dynamic_cast<const Continuous2DFunction*>(function)) {
            vector<double> values;
            int xsize, ysize;
            double xmin, xmax, ymin, ymax;
            fn->getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
            p.push_back(xsize-1);
            p.push_back(ysize-1);
            p.push_back(xmin);
            p.push_back(xmax);
            p.push_back(ymin);
            p.push_back(ymax);
            p.push_back((xsize-1)/(xmax-xmin));
            p.push_back((ysize-1)/(ymax-ymin));
            continue;
        }
        if (const Continuous3DFunction* fn = dynamic_cast<const Continuous3DFunction*>(function)) {
            vector<double> values;
            int xsize, ysize, zsize;
            double xmin, xmax, ymin, ymax, zmin, zmax;
            fn->getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
            p.push_back(xsize-1);
            p.push_back(ysize-1);
            p.push_back(zsize-1);
            p.push_back(xmin);
            p.push_back(xmax);
            p.push_back(ymin);
            p.push_back(ymax);
            p.push_back(zmin);
            p.push_back(zmax);
            p.push_back((xsize-1)/(xmax-xmin));
            p.push_back((ysize-1)/(ymax-ymin));
            p.push_back((zsize-1)/(zmax-zmin));
            continue;
        }
        if (const Discrete1DFunction* fn = dynamic_cast<const Discrete1DFunction*>(function)) {
            vector<double> values;
            fn->getFunctionParameters(values);
            p.push_back(values.size());
            continue;
        }
        if (const Discrete2DFunction* fn = dynamic_cast<const Discrete2DFunction*>(function)) {
            int xsize, ysize;
            vector<double> values;
            fn->getFunctionParameters(xsize, ysize, values);
            p.push_back(xsize);
            p.push_back(ysize);
            continue;
        }
        if (const Discrete3DFunction* fn = dynamic_cast<const Discrete3DFunction*>(function)) {
            int xsize, ysize, zsize;
            vector<double> values;
            fn->getFunctionParameters(xsize, ysize, zsize, values);
            p.push_back(xsize);
            p.push_back(ysize);
            p.push_back(zsize);
            continue;
        }
        throw OpenMMException("computeFunctionParameters: Unknown function type");
    }
    return params;
}
924

925
/**
 * Get the placeholder custom function to use for a tabulated function.
 *
 * @param function  the tabulated function a placeholder is needed for
 * @return &fp1 for the 1D function types, &fp2 for the 2D types and &fp3 for the 3D types
 * @throws OpenMMException if function is not one of the six supported types
 */
Lepton::CustomFunction* ExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
    const TabulatedFunction* fn = &function;

    // Continuous types are tested first, then the discrete ones.

    if (dynamic_cast<const Continuous1DFunction*>(fn))
        return &fp1;
    if (dynamic_cast<const Continuous2DFunction*>(fn))
        return &fp2;
    if (dynamic_cast<const Continuous3DFunction*>(fn))
        return &fp3;
    if (dynamic_cast<const Discrete1DFunction*>(fn))
        return &fp1;
    if (dynamic_cast<const Discrete2DFunction*>(fn))
        return &fp2;
    if (dynamic_cast<const Discrete3DFunction*>(fn))
        return &fp3;
    throw OpenMMException("getFunctionPlaceholder: Unknown function type");
}
940

941
/**
 * Get the placeholder custom function stored in the periodicDistance member.
 *
 * @return a pointer to the periodicDistance member of this object
 */
Lepton::CustomFunction* ExpressionUtilities::getPeriodicDistancePlaceholder() {
    Lepton::CustomFunction* placeholder = &periodicDistance;
    return placeholder;
}
944

945
/**
 * Write the source code for a call to a one argument math function.
 *
 * The precision and shape of the call are decided from the type name alone: a name starting
 * with 'd' selects doubleFn, anything else selects singleFn, and a name ending in '3' makes
 * the call component-wise, producing make_<tempType>(fn(arg.x), fn(arg.y), fn(arg.z)).
 * An empty type name is treated as a single precision scalar.
 *
 * @param out       the stream the generated code is appended to
 * @param singleFn  the function name to use when the type is not double precision
 * @param doubleFn  the function name to use when the type is double precision
 * @param arg       the expression passed as the argument
 * @param tempType  the name of the type the expression is being computed in
 */
void ExpressionUtilities::callFunction(stringstream& out, string singleFn, string doubleFn, const string& arg, const string& tempType) {
    // Guard against an empty type name: indexing size()-1 on it would read out of bounds.
    const bool hasType = !tempType.empty();
    const bool isDouble = (hasType && tempType.front() == 'd');
    const bool isVector = (hasType && tempType.back() == '3');
    const string& fn = (isDouble ? doubleFn : singleFn);
    if (isVector)
        out<<"make_"<<tempType<<"("<<fn<<"("<<arg<<".x), "<<fn<<"("<<arg<<".y), "<<fn<<"("<<arg<<".z))";
    else
        out<<fn<<"("<<arg<<")";
}
954

955
/**
 * Write the source code for a call to a two argument math function.
 *
 * The precision and shape of the call are decided from the type name alone: a name starting
 * with 'd' selects doubleFn, anything else selects singleFn, and a name ending in '3' makes
 * the call component-wise over .x, .y and .z, wrapped in make_<tempType>(...).  In the scalar
 * case both arguments are cast to tempType in the generated code.  An empty type name is
 * treated as a single precision scalar.
 *
 * @param out       the stream the generated code is appended to
 * @param singleFn  the function name to use when the type is not double precision
 * @param doubleFn  the function name to use when the type is double precision
 * @param arg1      the expression passed as the first argument
 * @param arg2      the expression passed as the second argument
 * @param tempType  the name of the type the expression is being computed in
 */
void ExpressionUtilities::callFunction2(stringstream& out, string singleFn, string doubleFn, const string& arg1, const string& arg2, const string& tempType) {
    // Guard against an empty type name: indexing size()-1 on it would read out of bounds.
    const bool hasType = !tempType.empty();
    const bool isDouble = (hasType && tempType.front() == 'd');
    const bool isVector = (hasType && tempType.back() == '3');
    const string& fn = (isDouble ? doubleFn : singleFn);
    if (isVector) {
        out<<"make_"<<tempType<<"(";
        out<<fn<<"("<<arg1<<".x, "<<arg2<<".x), ";
        out<<fn<<"("<<arg1<<".y, "<<arg2<<".y), ";
        out<<fn<<"("<<arg1<<".z, "<<arg2<<".z))";
    }
    else
        out<<fn<<"(("<<tempType<<") "<<arg1<<", ("<<tempType<<") "<<arg2<<")";
}