ExpressionUtilities.cpp 56.6 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2019 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

27
#include "openmm/common/ExpressionUtilities.h"
28
29
30
31
32
33
34
35
#include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include "lepton/Operation.h"

using namespace OpenMM;
using namespace Lepton;
using namespace std;

36
/**
 * Construct an ExpressionUtilities bound to a compute context.
 *
 * The context reference is stored in the `context` member; the code generation
 * routines in this file use it to convert numbers to strings.
 *
 * @param context   the ComputeContext this object generates code for
 */
ExpressionUtilities::ExpressionUtilities(ComputeContext& context)
        : context(context),
          // NOTE(review): the integer arguments below are presumably the argument
          // counts of the built-in placeholder functions -- confirm against the
          // member declarations in ExpressionUtilities.h.
          fp1(1),
          fp2(2),
          fp3(3),
          periodicDistance(6) {
}

39
/**
 * Convenience overload of createExpressions() that accepts the known variables
 * as a name -> value map.
 *
 * Each map entry is converted into a (Variable expression node, value string)
 * pair, and the resulting list is forwarded unchanged, together with all other
 * arguments, to the overload that takes a vector of such pairs.
 *
 * @param expressions    the expressions to generate code for
 * @param variables      maps each variable name that may appear in the expressions to the
 *                       text to use for it in the generated code
 * @param functions      the tabulated functions, forwarded as-is
 * @param functionNames  the names associated with the tabulated functions, forwarded as-is
 * @param prefix         prefix for the names of generated temporaries, forwarded as-is
 * @param tempType       type name to declare temporaries with, forwarded as-is
 * @return the string produced by the vector-based overload
 */
string ExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    vector<pair<ExpressionTreeNode, string> > variableNodes;
    for (const auto& entry : variables) {
        const string& variableName = entry.first;
        const string& variableValue = entry.second;
        // The ExpressionTreeNode is handed the freshly allocated Variable operation.
        variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(variableName)), variableValue));
    }
    return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
}

47
/**
 * Generate source code that evaluates a set of expressions.
 *
 * The expressions are visited in the iteration order of the map. For each one,
 * processExpression() is called on its root node so that the code defining any
 * temporaries it needs is appended to the output, and then a final line of the form
 *
 *     <map key><name of the temporary holding the root node>;
 *
 * is appended. The list of known temporaries starts as a copy of `variables` and is
 * shared across all expressions, so it grows as processExpression() records new nodes.
 *
 * @param expressions    the expressions to generate code for; each key is written verbatim
 *                       immediately before the name of the temporary holding the result
 * @param variables      (expression node, text) pairs that are already available
 * @param functions      the tabulated functions that may appear in the expressions
 * @param functionNames  the names associated with the tabulated functions
 * @param prefix         prefix for the names of generated temporaries
 * @param tempType       type name to declare temporaries with
 * @return the generated source code
 */
string ExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    stringstream out;

    // Collect every expression up front; processExpression() receives the full list
    // in addition to the single node it is currently working on.
    vector<ParsedExpression> allExpressions;
    for (const auto& entry : expressions)
        allExpressions.push_back(entry.second);

    vector<pair<ExpressionTreeNode, string> > temps = variables;
    vector<vector<double> > functionParams = computeFunctionParameters(functions);

    for (const auto& entry : expressions) {
        const string& target = entry.first;
        const ExpressionTreeNode& root = entry.second.getRootNode();
        processExpression(out, root, temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
        out << target << getTempName(root, temps) << ";\n";
    }
    return out.str();
}

62
void ExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
63
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
peastman's avatar
peastman committed
64
        const vector<ParsedExpression>& allExpressions, const string& tempType) {
65
66
67
68
    for (int i = 0; i < (int) temps.size(); i++)
        if (temps[i].first == node)
            return;
    for (int i = 0; i < (int) node.getChildren().size(); i++)
peastman's avatar
peastman committed
69
        processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
70
    string name = prefix+context.intToString(temps.size());
71
    bool hasRecordedNode = false;
72
    bool isVecType = (tempType[tempType.size()-1] == '3');
73

74
75
76
    out << tempType << " " << name << " = ";
    switch (node.getOperation().getId()) {
        case Operation::CONSTANT:
77
78
79
80
81
82
        {
            string value = context.doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
            if (isVecType)
                out << "make_" << tempType << "(" << value << ")";
            else
                out << value;
83
            break;
84
        }
85
86
87
88
        case Operation::VARIABLE:
            throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
        case Operation::CUSTOM:
        {
89
90
91
92
            if (isVecType)
                out << "make_" << tempType << "(0);\n";
            else
                out << "0;\n";
93
94
95
96
97
98
            temps.push_back(make_pair(node, name));
            hasRecordedNode = true;

            // If both the value and derivative of the function are needed, it's faster to calculate them both
            // at once, so check to see if both are needed.

99
            vector<const ExpressionTreeNode*> nodes;
100
            nodes.push_back(&node);
101
            for (int j = 0; j < (int) allExpressions.size(); j++)
102
                findRelatedCustomFunctions(node, allExpressions[j].getRootNode(), nodes);
103
104
105
            vector<string> nodeNames;
            nodeNames.push_back(name);
            for (int j = 1; j < (int) nodes.size(); j++) {
106
                string name2 = prefix+context.intToString(temps.size());
107
                out << tempType << " " << name2 << " = 0.0f;\n";
108
109
                nodeNames.push_back(name2);
                temps.push_back(make_pair(*nodes[j], name2));
110
111
            }
            out << "{\n";
112
113
114
            if (node.getOperation().getName() == "periodicdistance") {
                // This is the periodicdistance() function.

115
                out << tempType << "3 periodicDistance_delta = make_real3(";
116
117
118
119
                for (int i = 0; i < 3; i++) {
                    if (i > 0)
                        out << ", ";
                    out << getTempName(node.getChildren()[i], temps) << "-" << getTempName(node.getChildren()[i+3], temps);
120
                }
121
                out << ");\n";
122
                out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
123
124
                out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n";
                out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n";
125
126
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
127
128
129
130
131
132
133
                    int argIndex = -1;
                    for (int k = 0; k < 6; k++) {
                        if (derivOrder[k] > 0) {
                            if (derivOrder[k] > 1 || argIndex != -1)
                                throw OpenMMException("Unsupported derivative of periodicdistance"); // Should be impossible for this to happen.
                            argIndex = k;
                        }
134
                    }
135
                    if (argIndex == -1)
136
                        out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
137
                    else if (argIndex == 0)
138
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
139
                    else if (argIndex == 1)
140
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
141
                    else if (argIndex == 2)
142
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
143
                    else if (argIndex == 3)
144
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
145
                    else if (argIndex == 4)
146
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
147
                    else if (argIndex == 5)
148
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
149
150
                }
            }
151
152
153
154
155
156
157
158
            else if (node.getOperation().getName() == "dot") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    string child1 = getTempName(node.getChildren()[0], temps);
                    string child2 = getTempName(node.getChildren()[1], temps);
                    if (derivOrder[0] == 0 && derivOrder[1] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(dot(" << child1 << ", " << child2 << "));\n";
                    else
159
                        throw OpenMMException("Unsupported derivative order for dot()");
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
                }
            }
            else if (node.getOperation().getName() == "cross") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    string child1 = getTempName(node.getChildren()[0], temps);
                    string child2 = getTempName(node.getChildren()[1], temps);
                    if (derivOrder[0] == 0 && derivOrder[1] == 0)
                        out << nodeNames[j] << " = cross(" << child1 << ", " << child2 << ");\n";
                    else
                        throw OpenMMException("Unsupported derivative order for cross()");
                }
            }
            else if (node.getOperation().getName() == "vector") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                        out << nodeNames[j] << ".x = " << getTempName(node.getChildren()[0], temps) << ".x;\n";
                        out << nodeNames[j] << ".y = " << getTempName(node.getChildren()[1], temps) << ".y;\n";
                        out << nodeNames[j] << ".z = " << getTempName(node.getChildren()[2], temps) << ".z;\n";
                    }
                    else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0)
                        out << nodeNames[j] << ".x = 1;\n";
                    else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0)
                        out << nodeNames[j] << ".y = 1;\n";
                    else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1)
                        out << nodeNames[j] << ".z = 1;\n";
                }
            }
            else if (node.getOperation().getName() == "_x") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".x);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _x()");
                }
            }
            else if (node.getOperation().getName() == "_y") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".y);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _y()");
                }
            }
            else if (node.getOperation().getName() == "_z") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".z);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _z()");
                }
            }
216
217
            else {
                // This is a tabulated function.
218

219
220
221
222
223
224
225
226
227
228
                int i;
                for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
                    ;
                if (i == functionNames.size())
                    throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
                vector<string> paramsFloat, paramsInt;
                for (int j = 0; j < (int) functionParams[i].size(); j++) {
                    paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
                    paramsInt.push_back(context.intToString((int) functionParams[i][j]));
                }
229
230
231
232
233
                vector<string> suffixes;
                if (isVecType) {
                    suffixes.push_back(".x");
                    suffixes.push_back(".y");
                    suffixes.push_back(".z");
234
                }
235
236
                else {
                    suffixes.push_back("");
237
                }
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
                for (auto& suffix : suffixes) {
                    out << "{\n";
                    if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
                        out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
                        out << "int index = (int) (floor(x));\n";
                        out << "index = min(index, (int) " << paramsInt[3] << ");\n";
                        out << "float4 coeff = " << functionNames[i].second << "[index];\n";
                        out << "real b = x-index;\n";
                        out << "real a = 1.0f-b;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0)
                                out << nodeNames[j] << suffix << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
                            else
                                out << nodeNames[j] << suffix << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
255
                        }
256
257
                        out << "}\n";
                    }
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
                    else if (dynamic_cast<const ContinuousPeriodic1DFunction*>(functions[i]) != NULL) {
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
                        out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
                        out << "int index = (int) (floor(x));\n";
                        out << "index = min(index, (int) " << paramsInt[3] << ");\n";
                        out << "float4 coeff = " << functionNames[i].second << "[index];\n";
                        out << "real b = x-index;\n";
                        out << "real a = 1.0f-b;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0)
                                out << nodeNames[j] << suffix << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
                            else
                                out << nodeNames[j] << suffix << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
                        }
                        out << "}\n";
                    }
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
                    else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
                        out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
                        out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
                        out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
                        out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
                        out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
                        out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
                        out << "float4 c[4];\n";
                        for (int j = 0; j < 4; j++)
                            out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
                        out << "real da = x-s;\n";
                        out << "real db = y-t;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0) {
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n";
                            }
                            else if (derivOrder[0] == 1 && derivOrder[1] == 0) {
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[6] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[7] << ";\n";
                            }
                            else
                                throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
314
                        }
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
                        out << "}\n";
                    }
                    else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
                        out << "real z = " << getTempName(node.getChildren()[2], temps) << suffix << ";\n";
                        out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
                        out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
                        out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
                        out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
                        out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
                        out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
                        out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n";
                        out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
                        out << "float4 c[16];\n";
                        for (int j = 0; j < 16; j++)
                            out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
                        out << "real da = x-s;\n";
                        out << "real db = y-t;\n";
                        out << "real dc = z-u;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "real value[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
                                    }
                                out << nodeNames[j] << suffix << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n";
                            }
                            else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "real derivx[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[9] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) {
                                const string suffixes[] = {".x", ".y", ".z", ".w"};
                                out << "real derivy[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = 4*m;
                                        string suffix = suffixes[k];
                                        out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[10] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) {
                                out << "real derivz[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[11] << ";\n";
                            }
                            else
                                throw OpenMMException("Unsupported derivative order for Continuous3DFunction");
380
                        }
381
                        out << "}\n";
382
                    }
383
384
385
386
387
388
389
390
391
392
                    else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0) {
                                out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                                out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
                                out << "int index = (int) floor(x+0.5f);\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                                out << "}\n";
                            }
393
                        }
394
                    }
395
396
397
398
399
400
401
402
403
404
405
406
                    else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0) {
                                out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << suffix << "+0.5f);\n";
                                out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << suffix << "+0.5f);\n";
                                out << "int xsize = (int) " << paramsInt[0] << ";\n";
                                out << "int ysize = (int) " << paramsInt[1] << ";\n";
                                out << "int index = x+y*xsize;\n";
                                out << "if (index >= 0 && index < xsize*ysize)\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                            }
407
                        }
408
                    }
409
410
411
412
413
414
415
416
417
418
419
420
421
422
                    else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << suffix << "+0.5f);\n";
                                out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << suffix << "+0.5f);\n";
                                out << "int z = (int) floor(" << getTempName(node.getChildren()[2], temps) << suffix << "+0.5f);\n";
                                out << "int xsize = (int) " << paramsInt[0] << ";\n";
                                out << "int ysize = (int) " << paramsInt[1] << ";\n";
                                out << "int zsize = (int) " << paramsInt[2] << ";\n";
                                out << "int index = x+(y+z*ysize)*xsize;\n";
                                out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                            }
423
                        }
424
                    }
425
                    out << "}\n";
peastman's avatar
peastman committed
426
427
                }
            }
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
            out << "}";
            break;
        }
        case Operation::ADD:
            out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::SUBTRACT:
            out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::MULTIPLY:
            out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::DIVIDE:
        {
            bool haveReciprocal = false;
443
444
445
446
447
448
449
450
451
452
453
454
455
            if (node.getChildren()[1].getOperation().getId() == Operation::RECIPROCAL) {
                for (int i = 0; i < (int) temps.size(); i++)
                    if (temps[i].first == node.getChildren()[1].getChildren()[1]) {
                        haveReciprocal = true;
                        out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
                    }
            }
            if (!haveReciprocal)
                for (int i = 0; i < (int) temps.size(); i++)
                    if (temps[i].first.getOperation().getId() == Operation::RECIPROCAL && temps[i].first.getChildren()[0] == node.getChildren()[1]) {
                        haveReciprocal = true;
                        out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
                    }
456
457
458
459
460
            if (!haveReciprocal)
                out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
            break;
        }
        case Operation::POWER:
461
            out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
462
463
464
465
466
            break;
        case Operation::NEGATE:
            out << "-" << getTempName(node.getChildren()[0], temps);
            break;
        case Operation::SQRT:
467
            callFunction(out, "sqrtf", "sqrt", getTempName(node.getChildren()[0], temps), tempType);
468
469
            break;
        case Operation::EXP:
470
            callFunction(out, "expf", "exp", getTempName(node.getChildren()[0], temps), tempType);
471
472
            break;
        case Operation::LOG:
473
            callFunction(out, "logf", "log", getTempName(node.getChildren()[0], temps), tempType);
474
475
            break;
        case Operation::SIN:
476
            callFunction(out, "sinf", "sin", getTempName(node.getChildren()[0], temps), tempType);
477
478
            break;
        case Operation::COS:
479
            callFunction(out, "cosf", "cos", getTempName(node.getChildren()[0], temps), tempType);
480
481
            break;
        case Operation::SEC:
482
483
            out << "1/";
            callFunction(out, "cosf", "cos", getTempName(node.getChildren()[0], temps), tempType);
484
485
            break;
        case Operation::CSC:
486
487
            out << "1/";
            callFunction(out, "sinf", "sin", getTempName(node.getChildren()[0], temps), tempType);
488
489
            break;
        case Operation::TAN:
490
            callFunction(out, "tanf", "tan", getTempName(node.getChildren()[0], temps), tempType);
491
492
            break;
        case Operation::COT:
493
494
            out << "1/";
            callFunction(out, "tanf", "tan", getTempName(node.getChildren()[0], temps), tempType);
495
496
            break;
        case Operation::ASIN:
497
            callFunction(out, "asinf", "asin", getTempName(node.getChildren()[0], temps), tempType);
498
499
            break;
        case Operation::ACOS:
500
            callFunction(out, "acosf", "acos", getTempName(node.getChildren()[0], temps), tempType);
501
502
            break;
        case Operation::ATAN:
503
            callFunction(out, "atanf", "atan", getTempName(node.getChildren()[0], temps), tempType);
504
            break;
505
        case Operation::ATAN2:
Peter Eastman's avatar
Peter Eastman committed
506
            callFunction2(out, "atan2f", "atan2", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
507
            break;
508
        case Operation::SINH:
509
            callFunction(out, "sinh", "sinh", getTempName(node.getChildren()[0], temps), tempType);
510
511
            break;
        case Operation::COSH:
512
            callFunction(out, "cosh", "cosh", getTempName(node.getChildren()[0], temps), tempType);
513
514
            break;
        case Operation::TANH:
515
            callFunction(out, "tanh", "tanh", getTempName(node.getChildren()[0], temps), tempType);
516
517
            break;
        case Operation::ERF:
518
            callFunction(out, "erf", "erf", getTempName(node.getChildren()[0], temps), tempType);
519
520
            break;
        case Operation::ERFC:
521
            callFunction(out, "erfc", "erfc", getTempName(node.getChildren()[0], temps), tempType);
522
523
            break;
        case Operation::STEP:
524
525
526
527
528
529
530
531
532
533
534
535
536
        {
            string compareVal = getTempName(node.getChildren()[0], temps);
            if (isVecType) {
                out << "make_" << tempType << "(0);\n";
                out << "{\n";
                out << tempType<<" tempCompareValue = " << compareVal << ";\n";
                out << name << ".x = (tempCompareValue.x >= 0 ? 1 : 0);\n";
                out << name << ".y = (tempCompareValue.y >= 0 ? 1 : 0);\n";
                out << name << ".z = (tempCompareValue.z >= 0 ? 1 : 0);\n";
                out << "}\n";
            }
            else
                out << compareVal << " >= 0 ? 1 : 0";
537
            break;
538
        }
539
        case Operation::DELTA:
540
541
542
543
544
545
546
547
548
549
550
551
552
        {
            string compareVal = getTempName(node.getChildren()[0], temps);
            if (isVecType) {
                out << "make_" << tempType << "(0);\n";
                out << "{\n";
                out << tempType<<" tempCompareValue = " << compareVal << ";\n";
                out << name << ".x = (tempCompareValue.x == 0 ? 1 : 0);\n";
                out << name << ".y = (tempCompareValue.y == 0 ? 1 : 0);\n";
                out << name << ".z = (tempCompareValue.z == 0 ? 1 : 0);\n";
                out << "}\n";
            }
            else
                out << compareVal << " == 0 ? 1 : 0";
553
            break;
554
        }
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
        case Operation::SQUARE:
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg;
            break;
        }
        case Operation::CUBE:
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg << "*" << arg;
            break;
        }
        case Operation::RECIPROCAL:
            out << "RECIP(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ADD_CONSTANT:
571
572
573
574
575
576
577
578
579
580
            if (isVecType) {
                string val = context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue());
                string arg = getTempName(node.getChildren()[0], temps);
                out << "make_" << tempType << "(";
                out << val << "+" << arg << ".x, ";
                out << val << "+" << arg << ".y, ";
                out << val << "+" << arg << ".z)";
            }
            else
                out << context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
581
582
            break;
        case Operation::MULTIPLY_CONSTANT:
583
            out << context.doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
            break;
        case Operation::POWER_CONSTANT:
        {
            double exponent = dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue();
            if (exponent == 0.0)
                out << "1.0f";
            else if (exponent == (int) exponent) {
                out << "0.0f;\n";
                temps.push_back(make_pair(node, name));
                hasRecordedNode = true;

                // If multiple integral powers of the same base are needed, it's faster to calculate all of them
                // at once, so check to see if others are also needed.

                map<int, const ExpressionTreeNode*> powers;
                powers[(int) exponent] = &node;
                for (int j = 0; j < (int) allExpressions.size(); j++)
                    findRelatedPowers(node, allExpressions[j].getRootNode(), powers);
                vector<int> exponents;
                vector<string> names;
                vector<bool> hasAssigned(powers.size(), false);
                exponents.push_back((int) fabs(exponent));
                names.push_back(name);
peastman's avatar
peastman committed
607
608
609
                for (auto& power : powers) {
                    if (power.first != exponent) {
                        exponents.push_back(power.first >= 0 ? power.first : -power.first);
610
                        string name2 = prefix+context.intToString(temps.size());
611
                        names.push_back(name2);
peastman's avatar
peastman committed
612
                        temps.push_back(make_pair(*power.second, name2));
613
614
615
616
                        out << tempType << " " << name2 << " = 0.0f;\n";
                    }
                }
                out << "{\n";
617
                out << "real multiplier = " << (exponent < 0.0 ? "RECIP(" : "(") << getTempName(node.getChildren()[0], temps) << ");\n";
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
                bool done = false;
                while (!done) {
                    done = true;
                    for (int i = 0; i < (int) exponents.size(); i++) {
                        if (exponents[i]%2 == 1) {
                            if (!hasAssigned[i])
                                out << names[i] << " = multiplier;\n";
                            else
                                out << names[i] << " *= multiplier;\n";
                            hasAssigned[i] = true;
                        }
                        exponents[i] >>= 1;
                        if (exponents[i] != 0)
                            done = false;
                    }
                    if (!done)
                        out << "multiplier *= multiplier;\n";
                }
                out << "}";
            }
            else
639
                out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << context.doubleToString(exponent) << ")";
640
641
642
            break;
        }
        case Operation::MIN:
643
            callFunction2(out, "min", "min", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
644
645
            break;
        case Operation::MAX:
646
            callFunction2(out, "max", "max", getTempName(node.getChildren()[0], temps), getTempName(node.getChildren()[1], temps), tempType);
647
648
            break;
        case Operation::ABS:
649
            callFunction(out, "fabs", "fabs", getTempName(node.getChildren()[0], temps), tempType);
650
            break;
651
        case Operation::FLOOR:
652
            callFunction(out, "floor", "floor", getTempName(node.getChildren()[0], temps), tempType);
653
654
            break;
        case Operation::CEIL:
655
            callFunction(out, "ceil", "ceil", getTempName(node.getChildren()[0], temps), tempType);
656
            break;
657
        case Operation::SELECT:
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
        {
            string compareVal = getTempName(node.getChildren()[0], temps);
            string val1 = getTempName(node.getChildren()[1], temps);
            string val2 = getTempName(node.getChildren()[2], temps);
            if (isVecType) {
                out << "make_" << tempType << "(0);\n";
                out << "{\n";
                out << tempType<<" tempCompareValue = " << compareVal << ";\n";
                out << name << ".x = (tempCompareValue.x != 0 ? " << val1 << ".x : " << val2 << ".x);\n";
                out << name << ".y = (tempCompareValue.y != 0 ? " << val1 << ".y : " << val2 << ".y);\n";
                out << name << ".z = (tempCompareValue.z != 0 ? " << val1 << ".z : " << val2 << ".z);\n";
                out << "}\n";
            }
            else
                out << "(" << compareVal << " != 0 ? " << val1 << " : " << val2 << ")";
673
            break;
674
        }
675
676
677
678
679
680
681
682
        default:
            throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
    }
    out << ";\n";
    if (!hasRecordedNode)
        temps.push_back(make_pair(node, name));
}

683
/**
 * Look up the name of the temporary variable that has been recorded for an expression node.
 *
 * @param node   the expression node to look up
 * @param temps  the (node, variable name) pairs recorded so far
 * @return the name stored with the first entry of temps whose node compares equal to node
 * @throws OpenMMException if no entry of temps matches node
 */
string ExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
    for (const auto& entry : temps) {
        if (entry.first == node)
            return entry.second;
    }

    // Nothing matched: report the offending node in the error message.
    stringstream message;
    message << "Internal error: No temporary variable for expression node: " << node;
    throw OpenMMException(message.str());
}

692
void ExpressionUtilities::findRelatedCustomFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
693
            vector<const Lepton::ExpressionTreeNode*>& nodes) {
694
695
    if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
        // Make sure the arguments are identical.
696

697
698
699
        for (int i = 0; i < (int) node.getChildren().size(); i++)
            if (node.getChildren()[i] != searchNode.getChildren()[i])
                return;
700

701
        // See if we already have an identical node.
702

703
704
705
        for (int i = 0; i < (int) nodes.size(); i++)
            if (*nodes[i] == searchNode)
                return;
706

707
        // Add the node.
708

709
        nodes.push_back(&searchNode);
710
    }
711
712
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
713
            findRelatedCustomFunctions(node, searchNode.getChildren()[i], nodes);
714
715
}

716
void ExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
    if (searchNode.getOperation().getId() == Operation::POWER_CONSTANT && node.getChildren()[0] == searchNode.getChildren()[0]) {
        double realPower = dynamic_cast<const Operation::PowerConstant*>(&searchNode.getOperation())->getValue();
        int power = (int) realPower;
        if (power != realPower)
            return; // We are only interested in integer powers.
        if (powers.find(power) != powers.end())
            return; // This power is already in the map.
        if (powers.begin()->first*power < 0)
            return; // All powers must have the same sign.
        powers[power] = &searchNode;
    }
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
            findRelatedPowers(node, searchNode.getChildren()[i], powers);
}

733
/**
 * Build the single precision coefficient table for a tabulated function.
 *
 * The layout of the returned table depends on the concrete type of the function:
 *
 *  - Continuous1DFunction / ContinuousPeriodic1DFunction: four floats per interval between
 *    consecutive tabulated values: values[i], values[i+1], derivs[i]/6, derivs[i+1]/6, where
 *    derivs is produced by SplineFitter.  width is set to 4.
 *  - Continuous2DFunction: the 16 coefficients of each entry produced by
 *    SplineFitter::create2DNaturalSpline, stored consecutively.  width is set to 4.
 *  - Continuous3DFunction: the 64 coefficients of each entry produced by
 *    SplineFitter::create3DNaturalSpline, stored consecutively.  width is set to 4.
 *  - Discrete1DFunction / Discrete2DFunction / Discrete3DFunction: the tabulated values
 *    converted to float.  width is set to 1.
 *
 * @param function  the tabulated function to process
 * @param width     on exit, 4 for the continuous function types and 1 for the discrete ones
 * @return the coefficient table
 * @throws OpenMMException if function is not one of the types listed above
 */
vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
    // Evenly spaced sample positions: count points running from lo to hi.

    auto uniformGrid = [](double lo, double hi, int count) -> vector<double> {
        vector<double> grid(count);
        for (int k = 0; k < count; k++)
            grid[k] = lo+k*(hi-lo)/(count-1);
        return grid;
    };

    // Pack the values and spline derivatives of a 1D function, four floats per interval.

    auto packSpline1D = [](const vector<double>& values, const vector<double>& derivs) -> vector<float> {
        const int numValues = values.size();
        vector<float> packed(4*(numValues-1));
        for (int k = 0; k < numValues-1; k++) {
            packed[4*k] = (float) values[k];
            packed[4*k+1] = (float) values[k+1];
            packed[4*k+2] = (float) (derivs[k]/6.0);
            packed[4*k+3] = (float) (derivs[k+1]/6.0);
        }
        return packed;
    };

    // Concatenate the first "stride" coefficients of every entry of a multidimensional spline.

    auto flattenCoefficients = [](const vector<vector<double> >& coefficients, int stride) -> vector<float> {
        vector<float> packed(stride*coefficients.size());
        for (int entry = 0; entry < (int) coefficients.size(); entry++) {
            for (int k = 0; k < stride; k++)
                packed[stride*entry+k] = (float) coefficients[entry][k];
        }
        return packed;
    };

    // Convert a table of values to single precision.

    auto toFloat = [](const vector<double>& values) -> vector<float> {
        const int numValues = values.size();
        vector<float> converted(numValues);
        for (int k = 0; k < numValues; k++)
            converted[k] = (float) values[k];
        return converted;
    };

    if (const Continuous1DFunction* fn = dynamic_cast<const Continuous1DFunction*>(&function)) {
        // Compute the spline coefficients, using a periodic spline if the function requests it.

        vector<double> values;
        double lower, upper;
        fn->getFunctionParameters(values, lower, upper);
        bool periodic;
        fn->getPeriodicityStatus(periodic);
        vector<double> x = uniformGrid(lower, upper, (int) values.size());
        vector<double> derivs;
        if (periodic)
            SplineFitter::createPeriodicSpline(x, values, derivs);
        else
            SplineFitter::createNaturalSpline(x, values, derivs);
        vector<float> f = packSpline1D(values, derivs);
        width = 4;
        return f;
    }
    if (const ContinuousPeriodic1DFunction* fn = dynamic_cast<const ContinuousPeriodic1DFunction*>(&function)) {
        // Compute the spline coefficients.

        vector<double> values;
        double lower, upper;
        fn->getFunctionParameters(values, lower, upper);
        vector<double> x = uniformGrid(lower, upper, (int) values.size());
        vector<double> derivs;
        SplineFitter::createPeriodicSpline(x, values, derivs);
        vector<float> f = packSpline1D(values, derivs);
        width = 4;
        return f;
    }
    if (const Continuous2DFunction* fn = dynamic_cast<const Continuous2DFunction*>(&function)) {
        // Compute the spline coefficients.

        vector<double> values;
        int xsize, ysize;
        double xmin, xmax, ymin, ymax;
        fn->getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
        vector<double> x = uniformGrid(xmin, xmax, xsize);
        vector<double> y = uniformGrid(ymin, ymax, ysize);
        vector<vector<double> > c;
        SplineFitter::create2DNaturalSpline(x, y, values, c);
        vector<float> f = flattenCoefficients(c, 16);
        width = 4;
        return f;
    }
    if (const Continuous3DFunction* fn = dynamic_cast<const Continuous3DFunction*>(&function)) {
        // Compute the spline coefficients.

        vector<double> values;
        int xsize, ysize, zsize;
        double xmin, xmax, ymin, ymax, zmin, zmax;
        fn->getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
        vector<double> x = uniformGrid(xmin, xmax, xsize);
        vector<double> y = uniformGrid(ymin, ymax, ysize);
        vector<double> z = uniformGrid(zmin, zmax, zsize);
        vector<vector<double> > c;
        SplineFitter::create3DNaturalSpline(x, y, z, values, c);
        vector<float> f = flattenCoefficients(c, 64);
        width = 4;
        return f;
    }
    if (const Discrete1DFunction* fn = dynamic_cast<const Discrete1DFunction*>(&function)) {
        // Record the tabulated values.

        vector<double> values;
        fn->getFunctionParameters(values);
        vector<float> f = toFloat(values);
        width = 1;
        return f;
    }
    if (const Discrete2DFunction* fn = dynamic_cast<const Discrete2DFunction*>(&function)) {
        // Record the tabulated values.

        int xsize, ysize;
        vector<double> values;
        fn->getFunctionParameters(xsize, ysize, values);
        vector<float> f = toFloat(values);
        width = 1;
        return f;
    }
    if (const Discrete3DFunction* fn = dynamic_cast<const Discrete3DFunction*>(&function)) {
        // Record the tabulated values.

        int xsize, ysize, zsize;
        vector<double> values;
        fn->getFunctionParameters(xsize, ysize, zsize, values);
        vector<float> f = toFloat(values);
        width = 1;
        return f;
    }
    throw OpenMMException("computeFunctionCoefficients: Unknown function type");
}

875
/**
 * Compute the list of scalar parameters describing each tabulated function.
 *
 * The parameters stored for function i depend on its concrete type:
 *
 *  - Continuous1DFunction / ContinuousPeriodic1DFunction:
 *    min, max, (numValues-1)/(max-min), numValues-2
 *  - Continuous2DFunction:
 *    xsize-1, ysize-1, xmin, xmax, ymin, ymax, (xsize-1)/(xmax-xmin), (ysize-1)/(ymax-ymin)
 *  - Continuous3DFunction:
 *    xsize-1, ysize-1, zsize-1, xmin, xmax, ymin, ymax, zmin, zmax,
 *    (xsize-1)/(xmax-xmin), (ysize-1)/(ymax-ymin), (zsize-1)/(zmax-zmin)
 *  - Discrete1DFunction: numValues
 *  - Discrete2DFunction: xsize, ysize
 *  - Discrete3DFunction: xsize, ysize, zsize
 *
 * @param functions  the tabulated functions to process
 * @return one parameter list per function, in the same order as functions
 * @throws OpenMMException if any function is not one of the types listed above
 */
vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
    vector<vector<double> > params(functions.size());
    for (int index = 0; index < (int) functions.size(); index++) {
        const TabulatedFunction* function = functions[index];
        vector<double>& p = params[index];

        if (const Continuous1DFunction* fn = dynamic_cast<const Continuous1DFunction*>(function)) {
            vector<double> values;
            double lower, upper;
            fn->getFunctionParameters(values, lower, upper);
            p.push_back(lower);
            p.push_back(upper);
            p.push_back((values.size()-1)/(upper-lower));
            p.push_back(values.size()-2);
            continue;
        }
        if (const ContinuousPeriodic1DFunction* fn = dynamic_cast<const ContinuousPeriodic1DFunction*>(function)) {
            vector<double> values;
            double lower, upper;
            fn->getFunctionParameters(values, lower, upper);
            p.push_back(lower);
            p.push_back(upper);
            p.push_back((values.size()-1)/(upper-lower));
            p.push_back(values.size()-2);
            continue;
        }
        if (const Continuous2DFunction* fn = dynamic_cast<const Continuous2DFunction*>(function)) {
            vector<double> values;
            int xsize, ysize;
            double xmin, xmax, ymin, ymax;
            fn->getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
            p.push_back(xsize-1);
            p.push_back(ysize-1);
            p.push_back(xmin);
            p.push_back(xmax);
            p.push_back(ymin);
            p.push_back(ymax);
            p.push_back((xsize-1)/(xmax-xmin));
            p.push_back((ysize-1)/(ymax-ymin));
            continue;
        }
        if (const Continuous3DFunction* fn = dynamic_cast<const Continuous3DFunction*>(function)) {
            vector<double> values;
            int xsize, ysize, zsize;
            double xmin, xmax, ymin, ymax, zmin, zmax;
            fn->getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
            p.push_back(xsize-1);
            p.push_back(ysize-1);
            p.push_back(zsize-1);
            p.push_back(xmin);
            p.push_back(xmax);
            p.push_back(ymin);
            p.push_back(ymax);
            p.push_back(zmin);
            p.push_back(zmax);
            p.push_back((xsize-1)/(xmax-xmin));
            p.push_back((ysize-1)/(ymax-ymin));
            p.push_back((zsize-1)/(zmax-zmin));
            continue;
        }
        if (const Discrete1DFunction* fn = dynamic_cast<const Discrete1DFunction*>(function)) {
            vector<double> values;
            fn->getFunctionParameters(values);
            p.push_back(values.size());
            continue;
        }
        if (const Discrete2DFunction* fn = dynamic_cast<const Discrete2DFunction*>(function)) {
            int xsize, ysize;
            vector<double> values;
            fn->getFunctionParameters(xsize, ysize, values);
            p.push_back(xsize);
            p.push_back(ysize);
            continue;
        }
        if (const Discrete3DFunction* fn = dynamic_cast<const Discrete3DFunction*>(function)) {
            int xsize, ysize, zsize;
            vector<double> values;
            fn->getFunctionParameters(xsize, ysize, zsize, values);
            p.push_back(xsize);
            p.push_back(ysize);
            p.push_back(zsize);
            continue;
        }

        // None of the known function types matched.

        throw OpenMMException("computeFunctionParameters: Unknown function type");
    }
    return params;
}
960

961
/**
 * Select the placeholder custom function to use for a tabulated function.
 *
 * The one dimensional function types map to fp1, the two dimensional types to fp2,
 * and the three dimensional types to fp3.
 *
 * @param function  the tabulated function a placeholder is needed for
 * @return a pointer to the matching placeholder member (fp1, fp2, or fp3)
 * @throws OpenMMException if function is not one of the known tabulated function types
 */
Lepton::CustomFunction* ExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
    const TabulatedFunction* fn = &function;

    // The types are tested in a fixed order; the first match decides the result.

    if (dynamic_cast<const Continuous1DFunction*>(fn) || dynamic_cast<const ContinuousPeriodic1DFunction*>(fn))
        return &fp1;
    if (dynamic_cast<const Continuous2DFunction*>(fn))
        return &fp2;
    if (dynamic_cast<const Continuous3DFunction*>(fn))
        return &fp3;
    if (dynamic_cast<const Discrete1DFunction*>(fn))
        return &fp1;
    if (dynamic_cast<const Discrete2DFunction*>(fn))
        return &fp2;
    if (dynamic_cast<const Discrete3DFunction*>(fn))
        return &fp3;
    throw OpenMMException("getFunctionPlaceholder: Unknown function type");
}
978

979
/**
 * Get the placeholder custom function stored in the periodicDistance member.
 *
 * @return a pointer to the periodicDistance member of this object
 */
Lepton::CustomFunction* ExpressionUtilities::getPeriodicDistancePlaceholder() {
    Lepton::CustomFunction* placeholder = &periodicDistance;
    return placeholder;
}
982

983
/**
 * Write source code that calls a one argument math function.
 *
 * The function variant and the form of the call are chosen from the name of the
 * temporary type: a name starting with 'd' selects doubleFn instead of singleFn, and a
 * name ending with '3' produces a component-wise call on the .x, .y and .z members of
 * the argument, wrapped in "make_<tempType>(...)".  An empty tempType is treated as a
 * single precision scalar.
 *
 * @param out       the stream to write the generated code to
 * @param singleFn  the function name to use when the type name does not start with 'd'
 * @param doubleFn  the function name to use when the type name starts with 'd'
 * @param arg       the expression to pass as the argument
 * @param tempType  the name of the type used for temporary variables
 */
void ExpressionUtilities::callFunction(stringstream& out, string singleFn, string doubleFn, const string& arg, const string& tempType) {
    // Guard the character tests so an empty type name cannot index out of range.

    const bool isDouble = (!tempType.empty() && tempType[0] == 'd');
    const bool isVector = (!tempType.empty() && tempType[tempType.size()-1] == '3');
    const string& fn = (isDouble ? doubleFn : singleFn);
    if (isVector)
        out<<"make_"<<tempType<<"("<<fn<<"("<<arg<<".x), "<<fn<<"("<<arg<<".y), "<<fn<<"("<<arg<<".z))";
    else
        out<<fn<<"("<<arg<<")";
}
992

993
/**
 * Write source code that calls a two argument math function.
 *
 * The function variant and the form of the call are chosen from the name of the
 * temporary type: a name starting with 'd' selects doubleFn instead of singleFn, and a
 * name ending with '3' produces a component-wise call on the .x, .y and .z members of
 * the arguments, wrapped in "make_<tempType>(...)".  In the scalar case both arguments
 * are cast to tempType in the generated code.  An empty tempType is treated as a single
 * precision scalar.
 *
 * @param out       the stream to write the generated code to
 * @param singleFn  the function name to use when the type name does not start with 'd'
 * @param doubleFn  the function name to use when the type name starts with 'd'
 * @param arg1      the expression to pass as the first argument
 * @param arg2      the expression to pass as the second argument
 * @param tempType  the name of the type used for temporary variables
 */
void ExpressionUtilities::callFunction2(stringstream& out, string singleFn, string doubleFn, const string& arg1, const string& arg2, const string& tempType) {
    // Guard the character tests so an empty type name cannot index out of range.

    const bool isDouble = (!tempType.empty() && tempType[0] == 'd');
    const bool isVector = (!tempType.empty() && tempType[tempType.size()-1] == '3');
    const string& fn = (isDouble ? doubleFn : singleFn);
    if (isVector) {
        out<<"make_"<<tempType<<"(";
        out<<fn<<"("<<arg1<<".x, "<<arg2<<".x), ";
        out<<fn<<"("<<arg1<<".y, "<<arg2<<".y), ";
        out<<fn<<"("<<arg1<<".z, "<<arg2<<".z))";
    }
    else
        out<<fn<<"(("<<tempType<<") "<<arg1<<", ("<<tempType<<") "<<arg2<<")";
}