// CudaExpressionUtilities.cpp
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2018 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "CudaExpressionUtilities.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include "lepton/Operation.h"

using namespace OpenMM;
using namespace Lepton;
using namespace std;

36
CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context(context), fp1(1), fp2(2), fp3(3), periodicDistance(6) {
37
38
}

39
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
40
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
41
42
43
    vector<pair<ExpressionTreeNode, string> > variableNodes;
    for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
        variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
44
    return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
45
46
47
}

/**
 * Generate source code that evaluates a set of expressions.
 *
 * Each expression tree is lowered into a sequence of temporary declarations via
 * processExpression(). Then one output line `<key><temp>;` is emitted per expression.
 * The map key is written verbatim, so callers presumably pass an assignment prefix
 * such as "value = " (NOTE(review): confirm with callers). Outputs are emitted in the
 * map's sorted-key order.
 *
 * @param expressions    map from output key to the parsed expression to emit
 * @param variables      (node, code) pairs for nodes that are already available; they
 *                       seed the temporaries table, so those nodes are never recomputed
 * @param functions      tabulated functions the expressions may call
 * @param functionNames  (function name, generated identifier) pairs; indexed in
 *                       parallel with `functions`
 * @param prefix         prefix used to name generated temporaries
 * @param tempType       type used to declare generated temporaries
 * @return the generated source code
 */
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    // processExpression scans every expression for related custom-function nodes
    // (e.g. a value and its derivative) so they can be computed together.
    vector<ParsedExpression> allExpressions;
    allExpressions.reserve(expressions.size());
    for (const auto& entry : expressions)
        allExpressions.push_back(entry.second);

    // Nodes that already have a name; any node found here is referenced by name
    // instead of being emitted again.
    vector<pair<ExpressionTreeNode, string> > temps = variables;
    const vector<vector<double> > functionParams = computeFunctionParameters(functions);
    stringstream out;
    for (const auto& entry : expressions) {
        const ExpressionTreeNode& root = entry.second.getRootNode();
        processExpression(out, root, temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
        out << entry.first << getTempName(root, temps) << ";\n";
    }
    return out.str();
}

void CudaExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
63
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
peastman's avatar
peastman committed
64
        const vector<ParsedExpression>& allExpressions, const string& tempType) {
65
66
67
68
    for (int i = 0; i < (int) temps.size(); i++)
        if (temps[i].first == node)
            return;
    for (int i = 0; i < (int) node.getChildren().size(); i++)
peastman's avatar
peastman committed
69
        processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
70
    string name = prefix+context.intToString(temps.size());
71
    bool hasRecordedNode = false;
72
    bool isVecType = (tempType[tempType.size()-1] == '3');
73
74
75
76
    
    out << tempType << " " << name << " = ";
    switch (node.getOperation().getId()) {
        case Operation::CONSTANT:
77
78
79
80
81
82
        {
            string value = context.doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
            if (isVecType)
                out << "make_" << tempType << "(" << value << ")";
            else
                out << value;
83
            break;
84
        }
85
86
87
88
        case Operation::VARIABLE:
            throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
        case Operation::CUSTOM:
        {
89
            out << "make_" << tempType << "(0);\n";
90
91
92
93
94
95
            temps.push_back(make_pair(node, name));
            hasRecordedNode = true;

            // If both the value and derivative of the function are needed, it's faster to calculate them both
            // at once, so check to see if both are needed.

96
            vector<const ExpressionTreeNode*> nodes;
97
            for (int j = 0; j < (int) allExpressions.size(); j++)
98
                findRelatedCustomFunctions(node, allExpressions[j].getRootNode(), nodes);
99
100
101
            vector<string> nodeNames;
            nodeNames.push_back(name);
            for (int j = 1; j < (int) nodes.size(); j++) {
102
                string name2 = prefix+context.intToString(temps.size());
103
                out << tempType << " " << name2 << " = 0.0f;\n";
104
105
                nodeNames.push_back(name2);
                temps.push_back(make_pair(*nodes[j], name2));
106
107
            }
            out << "{\n";
108
109
110
            if (node.getOperation().getName() == "periodicdistance") {
                // This is the periodicdistance() function.

111
                out << tempType << "3 periodicDistance_delta = make_real3(";
112
113
114
115
                for (int i = 0; i < 3; i++) {
                    if (i > 0)
                        out << ", ";
                    out << getTempName(node.getChildren()[i], temps) << "-" << getTempName(node.getChildren()[i+3], temps);
116
                }
117
                out << ");\n";
118
                out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
119
120
                out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n";
                out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n";
121
122
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
123
124
125
126
127
128
129
                    int argIndex = -1;
                    for (int k = 0; k < 6; k++) {
                        if (derivOrder[k] > 0) {
                            if (derivOrder[k] > 1 || argIndex != -1)
                                throw OpenMMException("Unsupported derivative of periodicdistance"); // Should be impossible for this to happen.
                            argIndex = k;
                        }
130
                    }
131
                    if (argIndex == -1)
132
                        out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
133
                    else if (argIndex == 0)
134
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
135
                    else if (argIndex == 1)
136
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
137
                    else if (argIndex == 2)
138
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
139
                    else if (argIndex == 3)
140
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
141
                    else if (argIndex == 4)
142
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
143
                    else if (argIndex == 5)
144
                        out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
145
146
                }
            }
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
            else if (node.getOperation().getName() == "dot") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    string child1 = getTempName(node.getChildren()[0], temps);
                    string child2 = getTempName(node.getChildren()[1], temps);
                    if (derivOrder[0] == 0 && derivOrder[1] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(dot(" << child1 << ", " << child2 << "));\n";
                    else
                        throw OpenMMException("Unsupported derivative order for cross()");
                }
            }
            else if (node.getOperation().getName() == "cross") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    string child1 = getTempName(node.getChildren()[0], temps);
                    string child2 = getTempName(node.getChildren()[1], temps);
                    if (derivOrder[0] == 0 && derivOrder[1] == 0)
                        out << nodeNames[j] << " = cross(" << child1 << ", " << child2 << ");\n";
                    else
                        throw OpenMMException("Unsupported derivative order for cross()");
                }
            }
            else if (node.getOperation().getName() == "vector") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                        out << nodeNames[j] << ".x = " << getTempName(node.getChildren()[0], temps) << ".x;\n";
                        out << nodeNames[j] << ".y = " << getTempName(node.getChildren()[1], temps) << ".y;\n";
                        out << nodeNames[j] << ".z = " << getTempName(node.getChildren()[2], temps) << ".z;\n";
                    }
                    else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0)
                        out << nodeNames[j] << ".x = 1;\n";
                    else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0)
                        out << nodeNames[j] << ".y = 1;\n";
                    else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1)
                        out << nodeNames[j] << ".z = 1;\n";
                }
            }
            else if (node.getOperation().getName() == "_x") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".x);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _x()");
                }
            }
            else if (node.getOperation().getName() == "_y") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".y);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _y()");
                }
            }
            else if (node.getOperation().getName() == "_z") {
                for (int j = 0; j < nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0)
                        out << nodeNames[j] << " = make_" << tempType << "(" << getTempName(node.getChildren()[0], temps) << ".z);\n";
                    else
                        throw OpenMMException("Unsupported derivative order for _z()");
                }
            }
212
213
214
215
216
217
218
219
220
221
222
223
224
            else {
                // This is a tabulated function.
                
                int i;
                for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
                    ;
                if (i == functionNames.size())
                    throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
                vector<string> paramsFloat, paramsInt;
                for (int j = 0; j < (int) functionParams[i].size(); j++) {
                    paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
                    paramsInt.push_back(context.intToString((int) functionParams[i][j]));
                }
225
226
227
228
229
                vector<string> suffixes;
                if (isVecType) {
                    suffixes.push_back(".x");
                    suffixes.push_back(".y");
                    suffixes.push_back(".z");
230
                }
231
232
                else {
                    suffixes.push_back("");
233
                }
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
                for (auto& suffix : suffixes) {
                    out << "{\n";
                    if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
                        out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
                        out << "int index = (int) (floor(x));\n";
                        out << "index = min(index, (int) " << paramsInt[3] << ");\n";
                        out << "float4 coeff = " << functionNames[i].second << "[index];\n";
                        out << "real b = x-index;\n";
                        out << "real a = 1.0f-b;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0)
                                out << nodeNames[j] << suffix << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
                            else
                                out << nodeNames[j] << suffix << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
251
                        }
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
                        out << "}\n";
                    }
                    else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
                        out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
                        out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
                        out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
                        out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
                        out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
                        out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
                        out << "float4 c[4];\n";
                        for (int j = 0; j < 4; j++)
                            out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
                        out << "real da = x-s;\n";
                        out << "real db = y-t;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0) {
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n";
                            }
                            else if (derivOrder[0] == 1 && derivOrder[1] == 0) {
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
                                out << nodeNames[j] << suffix << " = db*" << nodeNames[j] << suffix << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[6] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
                                out << nodeNames[j] << suffix << " = da*" << nodeNames[j] << suffix << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[7] << ";\n";
                            }
                            else
                                throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
292
                        }
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
                        out << "}\n";
                    }
                    else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
                        out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                        out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
                        out << "real z = " << getTempName(node.getChildren()[2], temps) << suffix << ";\n";
                        out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
                        out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
                        out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
                        out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
                        out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
                        out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
                        out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n";
                        out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
                        out << "float4 c[16];\n";
                        for (int j = 0; j < 16; j++)
                            out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
                        out << "real da = x-s;\n";
                        out << "real db = y-t;\n";
                        out << "real dc = z-u;\n";
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "real value[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
                                    }
                                out << nodeNames[j] << suffix << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n";
                            }
                            else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "real derivx[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[9] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) {
                                const string suffixes[] = {".x", ".y", ".z", ".w"};
                                out << "real derivy[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = 4*m;
                                        string suffix = suffixes[k];
                                        out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[10] << ";\n";
                            }
                            else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) {
                                out << "real derivz[4] = {0, 0, 0, 0};\n";
                                for (int k = 3; k >= 0; k--)
                                    for (int m = 0; m < 4; m++) {
                                        int base = k + 4*m;
                                        out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
                                    }
                                out << nodeNames[j] << suffix << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n";
                                out << nodeNames[j] << suffix << " *= " << paramsFloat[11] << ";\n";
                            }
                            else
                                throw OpenMMException("Unsupported derivative order for Continuous3DFunction");
358
                        }
359
                        out << "}\n";
360
                    }
361
362
363
364
365
366
367
368
369
370
                    else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0) {
                                out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
                                out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
                                out << "int index = (int) floor(x+0.5f);\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                                out << "}\n";
                            }
371
                        }
372
                    }
373
374
375
376
377
378
379
380
381
382
383
384
                    else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0) {
                                out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << suffix << "+0.5f);\n";
                                out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << suffix << "+0.5f);\n";
                                out << "int xsize = (int) " << paramsInt[0] << ";\n";
                                out << "int ysize = (int) " << paramsInt[1] << ";\n";
                                out << "int index = x+y*xsize;\n";
                                out << "if (index >= 0 && index < xsize*ysize)\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                            }
385
                        }
386
                    }
387
388
389
390
391
392
393
394
395
396
397
398
399
400
                    else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
                        for (int j = 0; j < nodes.size(); j++) {
                            const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                            if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                                out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << suffix << "+0.5f);\n";
                                out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << suffix << "+0.5f);\n";
                                out << "int z = (int) floor(" << getTempName(node.getChildren()[2], temps) << suffix << "+0.5f);\n";
                                out << "int xsize = (int) " << paramsInt[0] << ";\n";
                                out << "int ysize = (int) " << paramsInt[1] << ";\n";
                                out << "int zsize = (int) " << paramsInt[2] << ";\n";
                                out << "int index = x+(y+z*ysize)*xsize;\n";
                                out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
                                out << nodeNames[j] << suffix << " = " << functionNames[i].second << "[index];\n";
                            }
401
                        }
402
                    }
403
                    out << "}\n";
peastman's avatar
peastman committed
404
405
                }
            }
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
            out << "}";
            break;
        }
        case Operation::ADD:
            out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::SUBTRACT:
            out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::MULTIPLY:
            out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::DIVIDE:
        {
            bool haveReciprocal = false;
421
422
423
424
425
426
427
428
429
430
431
432
433
            if (node.getChildren()[1].getOperation().getId() == Operation::RECIPROCAL) {
                for (int i = 0; i < (int) temps.size(); i++)
                    if (temps[i].first == node.getChildren()[1].getChildren()[1]) {
                        haveReciprocal = true;
                        out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
                    }
            }
            if (!haveReciprocal)
                for (int i = 0; i < (int) temps.size(); i++)
                    if (temps[i].first.getOperation().getId() == Operation::RECIPROCAL && temps[i].first.getChildren()[0] == node.getChildren()[1]) {
                        haveReciprocal = true;
                        out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
                    }
434
435
436
437
438
            if (!haveReciprocal)
                out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
            break;
        }
        case Operation::POWER:
439
            out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
440
441
442
443
444
            break;
        case Operation::NEGATE:
            out << "-" << getTempName(node.getChildren()[0], temps);
            break;
        case Operation::SQRT:
445
            out << "SQRT(" << getTempName(node.getChildren()[0], temps) << ")";
446
447
448
449
450
451
452
453
            break;
        case Operation::EXP:
            out << "EXP(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::LOG:
            out << "LOG(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::SIN:
454
            out << "SIN(" << getTempName(node.getChildren()[0], temps) << ")";
455
456
            break;
        case Operation::COS:
457
            out << "COS(" << getTempName(node.getChildren()[0], temps) << ")";
458
459
            break;
        case Operation::SEC:
460
            out << "RECIP(COS(" << getTempName(node.getChildren()[0], temps) << "))";
461
462
            break;
        case Operation::CSC:
463
            out << "RECIP(SIN(" << getTempName(node.getChildren()[0], temps) << "))";
464
465
            break;
        case Operation::TAN:
466
            out << "TAN(" << getTempName(node.getChildren()[0], temps) << ")";
467
468
            break;
        case Operation::COT:
469
            out << "RECIP(TAN(" << getTempName(node.getChildren()[0], temps) << "))";
470
471
            break;
        case Operation::ASIN:
472
            out << "ASIN(" << getTempName(node.getChildren()[0], temps) << ")";
473
474
            break;
        case Operation::ACOS:
475
            out << "ACOS(" << getTempName(node.getChildren()[0], temps) << ")";
476
477
            break;
        case Operation::ATAN:
478
            out << "ATAN(" << getTempName(node.getChildren()[0], temps) << ")";
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
            break;
        case Operation::SINH:
            out << "sinh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::COSH:
            out << "cosh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::TANH:
            out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ERF:
            out << "erf(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ERFC:
            out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::STEP:
            out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f";
            break;
        case Operation::DELTA:
            out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? 1.0f : 0.0f";
            break;
        case Operation::SQUARE:
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg;
            break;
        }
        case Operation::CUBE:
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg << "*" << arg;
            break;
        }
        case Operation::RECIPROCAL:
            out << "RECIP(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ADD_CONSTANT:
517
            out << context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
518
519
            break;
        case Operation::MULTIPLY_CONSTANT:
520
            out << context.doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
            break;
        case Operation::POWER_CONSTANT:
        {
            double exponent = dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue();
            if (exponent == 0.0)
                out << "1.0f";
            else if (exponent == (int) exponent) {
                out << "0.0f;\n";
                temps.push_back(make_pair(node, name));
                hasRecordedNode = true;

                // If multiple integral powers of the same base are needed, it's faster to calculate all of them
                // at once, so check to see if others are also needed.

                map<int, const ExpressionTreeNode*> powers;
                powers[(int) exponent] = &node;
                for (int j = 0; j < (int) allExpressions.size(); j++)
                    findRelatedPowers(node, allExpressions[j].getRootNode(), powers);
                vector<int> exponents;
                vector<string> names;
                vector<bool> hasAssigned(powers.size(), false);
                exponents.push_back((int) fabs(exponent));
                names.push_back(name);
peastman's avatar
peastman committed
544
545
546
                for (auto& power : powers) {
                    if (power.first != exponent) {
                        exponents.push_back(power.first >= 0 ? power.first : -power.first);
547
                        string name2 = prefix+context.intToString(temps.size());
548
                        names.push_back(name2);
peastman's avatar
peastman committed
549
                        temps.push_back(make_pair(*power.second, name2));
550
551
552
553
                        out << tempType << " " << name2 << " = 0.0f;\n";
                    }
                }
                out << "{\n";
554
                out << "real multiplier = " << (exponent < 0.0 ? "RECIP(" : "(") << getTempName(node.getChildren()[0], temps) << ");\n";
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
                bool done = false;
                while (!done) {
                    done = true;
                    for (int i = 0; i < (int) exponents.size(); i++) {
                        if (exponents[i]%2 == 1) {
                            if (!hasAssigned[i])
                                out << names[i] << " = multiplier;\n";
                            else
                                out << names[i] << " *= multiplier;\n";
                            hasAssigned[i] = true;
                        }
                        exponents[i] >>= 1;
                        if (exponents[i] != 0)
                            done = false;
                    }
                    if (!done)
                        out << "multiplier *= multiplier;\n";
                }
                out << "}";
            }
            else
576
                out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << context.doubleToString(exponent) << ")";
577
578
579
            break;
        }
        case Operation::MIN:
580
            out << "min((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
581
582
            break;
        case Operation::MAX:
583
            out << "max((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
584
585
586
587
            break;
        case Operation::ABS:
            out << "fabs(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
588
589
590
591
592
593
        case Operation::FLOOR:
            out << "floor(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::CEIL:
            out << "ceil(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
594
        case Operation::SELECT:
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
        {
            string compareVal = getTempName(node.getChildren()[0], temps);
            string val1 = getTempName(node.getChildren()[1], temps);
            string val2 = getTempName(node.getChildren()[2], temps);
            if (isVecType) {
                out << "make_" << tempType << "(0);\n";
                out << "{\n";
                out << tempType<<" tempCompareValue = " << compareVal << ";\n";
                out << name << ".x = (tempCompareValue.x != 0 ? " << val1 << ".x : " << val2 << ".x);\n";
                out << name << ".y = (tempCompareValue.y != 0 ? " << val1 << ".y : " << val2 << ".y);\n";
                out << name << ".z = (tempCompareValue.z != 0 ? " << val1 << ".z : " << val2 << ".z);\n";
                out << "}\n";
            }
            else
                out << "(" << compareVal << " != 0 ? " << val1 << " : " << val2 << ")";
610
            break;
611
        }
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
        default:
            throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
    }
    out << ";\n";
    if (!hasRecordedNode)
        temps.push_back(make_pair(node, name));
}

/**
 * Find the name of the temporary variable that was recorded for an expression node.
 *
 * @param node   the node whose temporary variable is wanted
 * @param temps  (node, variable name) pairs recorded so far; searched front to back
 * @return the variable name of the first entry whose node compares equal to node
 * @throws OpenMMException if no recorded entry matches node
 */
string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
    for (const auto& entry : temps) {
        if (entry.first == node)
            return entry.second;
    }

    // Nothing was recorded for this node, so the caller asked for a value that was never generated.
    stringstream message;
    message << "Internal error: No temporary variable for expression node: " << node;
    throw OpenMMException(message.str());
}

629
void CudaExpressionUtilities::findRelatedCustomFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
630
            vector<const Lepton::ExpressionTreeNode*>& nodes) {
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
    if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
        // Make sure the arguments are identical.
        
        for (int i = 0; i < (int) node.getChildren().size(); i++)
            if (node.getChildren()[i] != searchNode.getChildren()[i])
                return;
        
        // See if we already have an identical node.
        
        for (int i = 0; i < (int) nodes.size(); i++)
            if (*nodes[i] == searchNode)
                return;
        
        // Add the node.
        
646
        nodes.push_back(&searchNode);
647
    }
648
649
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
650
            findRelatedCustomFunctions(node, searchNode.getChildren()[i], nodes);
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
}

/**
 * Collect the other integer powers of the same base that appear in a subtree,
 * so the caller can compute all of them at once by repeated squaring.
 *
 * @param node        the POWER_CONSTANT node being processed; its base is getChildren()[0]
 * @param searchNode  root of the subtree to search
 * @param powers      map from integer exponent to the node using it. The caller seeds it with
 *                    the exponent of node; further exponents are added only if they are
 *                    nonzero, not yet present, and have the same sign as the existing entries.
 */
void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
    if (searchNode.getOperation().getId() == Operation::POWER_CONSTANT && node.getChildren()[0] == searchNode.getChildren()[0]) {
        double realPower = dynamic_cast<const Operation::PowerConstant*>(&searchNode.getOperation())->getValue();
        int power = (int) realPower;
        if (power != realPower)
            return; // We are only interested in integer powers.
        if (power == 0)
            return; // x^0 is 1; the repeated-squaring code never assigns exponent 0 and would leave it at 0.
        if (powers.find(power) != powers.end())
            return; // This power is already in the map.
        if (!powers.empty() && (powers.begin()->first < 0) != (power < 0))
            return; // All powers must have the same sign (compared directly to avoid int overflow in a product).
        powers[power] = &searchNode;
    }
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
            findRelatedPowers(node, searchNode.getChildren()[i], powers);
}

peastman's avatar
peastman committed
670
/**
 * Build the flat single precision coefficient array that represents a tabulated function.
 *
 * - Continuous1DFunction: a natural spline is fit through evenly spaced points on
 *   [min, max]. Each interval contributes 4 floats: its two end values followed by
 *   the two SplineFitter derivative terms divided by 6.
 * - Continuous2DFunction / Continuous3DFunction: the 16 (2D) or 64 (3D) coefficients
 *   per cell produced by SplineFitter are concatenated in order.
 * - Discrete1D/2D/3DFunction: the table values are copied unchanged.
 *
 * @param function  the tabulated function to convert
 * @param width     set to 4 for continuous functions and 1 for discrete ones
 * @return the coefficients converted to float
 * @throws OpenMMException if function is none of the six supported types
 */
vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
    // n evenly spaced sample coordinates running from lo to hi.
    auto makeGrid = [](double lo, double hi, int n) -> vector<double> {
        vector<double> grid(n);
        for (int k = 0; k < n; k++)
            grid[k] = lo+k*(hi-lo)/(n-1);
        return grid;
    };

    // Concatenate fixed-length per-cell coefficient lists into one float array.
    auto flatten = [](const vector<vector<double> >& cells, int perCell) -> vector<float> {
        vector<float> result(perCell*cells.size());
        for (int cell = 0; cell < (int) cells.size(); cell++)
            for (int j = 0; j < perCell; j++)
                result[perCell*cell+j] = (float) cells[cell][j];
        return result;
    };

    if (const Continuous1DFunction* spline1D = dynamic_cast<const Continuous1DFunction*>(&function)) {
        vector<double> table;
        double lower, upper;
        spline1D->getFunctionParameters(table, lower, upper);
        int count = table.size();
        vector<double> grid = makeGrid(lower, upper, count);
        vector<double> splineDerivs;
        SplineFitter::createNaturalSpline(grid, table, splineDerivs);
        vector<float> coeff(4*(count-1));
        for (int k = 0; k+1 < (int) table.size(); k++) {
            coeff[4*k] = (float) table[k];
            coeff[4*k+1] = (float) table[k+1];
            coeff[4*k+2] = (float) (splineDerivs[k]/6.0);
            coeff[4*k+3] = (float) (splineDerivs[k+1]/6.0);
        }
        width = 4;
        return coeff;
    }
    if (const Continuous2DFunction* spline2D = dynamic_cast<const Continuous2DFunction*>(&function)) {
        vector<double> table;
        int xsize, ysize;
        double xmin, xmax, ymin, ymax;
        spline2D->getFunctionParameters(xsize, ysize, table, xmin, xmax, ymin, ymax);
        vector<double> xgrid = makeGrid(xmin, xmax, xsize);
        vector<double> ygrid = makeGrid(ymin, ymax, ysize);
        vector<vector<double> > cells;
        SplineFitter::create2DNaturalSpline(xgrid, ygrid, table, cells);
        width = 4;
        return flatten(cells, 16);
    }
    if (const Continuous3DFunction* spline3D = dynamic_cast<const Continuous3DFunction*>(&function)) {
        vector<double> table;
        int xsize, ysize, zsize;
        double xmin, xmax, ymin, ymax, zmin, zmax;
        spline3D->getFunctionParameters(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax);
        vector<double> xgrid = makeGrid(xmin, xmax, xsize);
        vector<double> ygrid = makeGrid(ymin, ymax, ysize);
        vector<double> zgrid = makeGrid(zmin, zmax, zsize);
        vector<vector<double> > cells;
        SplineFitter::create3DNaturalSpline(xgrid, ygrid, zgrid, table, cells);
        width = 4;
        return flatten(cells, 64);
    }

    // Discrete functions: the tabulated values are used as they are.
    vector<double> table;
    if (const Discrete1DFunction* discrete1D = dynamic_cast<const Discrete1DFunction*>(&function))
        discrete1D->getFunctionParameters(table);
    else if (const Discrete2DFunction* discrete2D = dynamic_cast<const Discrete2DFunction*>(&function)) {
        int xsize, ysize;
        discrete2D->getFunctionParameters(xsize, ysize, table);
    }
    else if (const Discrete3DFunction* discrete3D = dynamic_cast<const Discrete3DFunction*>(&function)) {
        int xsize, ysize, zsize;
        discrete3D->getFunctionParameters(xsize, ysize, zsize, table);
    }
    else
        throw OpenMMException("computeFunctionCoefficients: Unknown function type");
    int count = table.size();
    vector<float> result(count);
    for (int k = 0; k < count; k++)
        result[k] = (float) table[k];
    width = 1;
    return result;
}

785
786
/**
 * Compute the scalar parameters that describe each tabulated function.
 *
 * Per function type, the entry holds:
 * - Continuous1DFunction: min, max, (numValues-1)/(max-min), numValues-2
 * - Continuous2DFunction: xsize-1, ysize-1, xmin, xmax, ymin, ymax,
 *   (xsize-1)/(xmax-xmin), (ysize-1)/(ymax-ymin)
 * - Continuous3DFunction: xsize-1, ysize-1, zsize-1, xmin, xmax, ymin, ymax, zmin, zmax,
 *   (xsize-1)/(xmax-xmin), (ysize-1)/(ymax-ymin), (zsize-1)/(zmax-zmin)
 * - Discrete1DFunction: number of values
 * - Discrete2DFunction: xsize, ysize
 * - Discrete3DFunction: xsize, ysize, zsize
 *
 * @param functions  the functions to describe
 * @return one parameter list per function, in the same order
 * @throws OpenMMException if any function is none of the six supported types
 */
vector<vector<double> > CudaExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
    vector<vector<double> > params(functions.size());
    for (int index = 0; index < (int) functions.size(); index++) {
        const TabulatedFunction* function = functions[index];
        vector<double>& entry = params[index];
        if (const Continuous1DFunction* c1 = dynamic_cast<const Continuous1DFunction*>(function)) {
            vector<double> table;
            double lower, upper;
            c1->getFunctionParameters(table, lower, upper);
            entry.push_back(lower);
            entry.push_back(upper);
            entry.push_back((table.size()-1)/(upper-lower));
            entry.push_back(table.size()-2);
        }
        else if (const Continuous2DFunction* c2 = dynamic_cast<const Continuous2DFunction*>(function)) {
            vector<double> table;
            int xsize, ysize;
            double xmin, xmax, ymin, ymax;
            c2->getFunctionParameters(xsize, ysize, table, xmin, xmax, ymin, ymax);
            entry.push_back(xsize-1);
            entry.push_back(ysize-1);
            entry.push_back(xmin);
            entry.push_back(xmax);
            entry.push_back(ymin);
            entry.push_back(ymax);
            entry.push_back((xsize-1)/(xmax-xmin));
            entry.push_back((ysize-1)/(ymax-ymin));
        }
        else if (const Continuous3DFunction* c3 = dynamic_cast<const Continuous3DFunction*>(function)) {
            vector<double> table;
            int xsize, ysize, zsize;
            double xmin, xmax, ymin, ymax, zmin, zmax;
            c3->getFunctionParameters(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax);
            entry.push_back(xsize-1);
            entry.push_back(ysize-1);
            entry.push_back(zsize-1);
            entry.push_back(xmin);
            entry.push_back(xmax);
            entry.push_back(ymin);
            entry.push_back(ymax);
            entry.push_back(zmin);
            entry.push_back(zmax);
            entry.push_back((xsize-1)/(xmax-xmin));
            entry.push_back((ysize-1)/(ymax-ymin));
            entry.push_back((zsize-1)/(zmax-zmin));
        }
        else if (const Discrete1DFunction* d1 = dynamic_cast<const Discrete1DFunction*>(function)) {
            vector<double> table;
            d1->getFunctionParameters(table);
            entry.push_back(table.size());
        }
        else if (const Discrete2DFunction* d2 = dynamic_cast<const Discrete2DFunction*>(function)) {
            int xsize, ysize;
            vector<double> table;
            d2->getFunctionParameters(xsize, ysize, table);
            entry.push_back(xsize);
            entry.push_back(ysize);
        }
        else if (const Discrete3DFunction* d3 = dynamic_cast<const Discrete3DFunction*>(function)) {
            int xsize, ysize, zsize;
            vector<double> table;
            d3->getFunctionParameters(xsize, ysize, zsize, table);
            entry.push_back(xsize);
            entry.push_back(ysize);
            entry.push_back(zsize);
        }
        else
            throw OpenMMException("computeFunctionParameters: Unknown function type");
    }
    return params;
}
860
861
862
863

/**
 * Select the placeholder custom function matching a tabulated function's number of arguments.
 *
 * @param function  the tabulated function
 * @return &fp1 for 1D functions, &fp2 for 2D functions, &fp3 for 3D functions
 *         (continuous and discrete alike)
 * @throws OpenMMException if function is none of the six supported types
 */
Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
    const TabulatedFunction* candidate = &function;
    if (dynamic_cast<const Continuous1DFunction*>(candidate) != NULL || dynamic_cast<const Discrete1DFunction*>(candidate) != NULL)
        return &fp1;
    if (dynamic_cast<const Continuous2DFunction*>(candidate) != NULL || dynamic_cast<const Discrete2DFunction*>(candidate) != NULL)
        return &fp2;
    if (dynamic_cast<const Continuous3DFunction*>(candidate) != NULL || dynamic_cast<const Discrete3DFunction*>(candidate) != NULL)
        return &fp3;
    throw OpenMMException("getFunctionPlaceholder: Unknown function type");
}
876
877
878

/**
 * Return a pointer to this object's periodicDistance placeholder function.
 */
Lepton::CustomFunction* CudaExpressionUtilities::getPeriodicDistancePlaceholder() {
    Lepton::CustomFunction* placeholder = &periodicDistance;
    return placeholder;
}