// OpenCLExpressionUtilities.cpp
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2014 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLExpressionUtilities.h"

#include <cmath>
#include <limits>
#include <sstream>

#include "lepton/Operation.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"

using namespace OpenMM;
using namespace Lepton;
using namespace std;

/**
 * Create a code generator that emits source for the given OpenCL context.
 *
 * The function placeholders fp1, fp2 and fp3 are constructed with 1, 2 and 3; getFunctionPlaceholder()
 * hands them out for one-, two- and three-dimensional tabulated functions respectively (so the argument
 * is presumably the number of arguments the placeholder accepts -- confirm against the header).
 */
OpenCLExpressionUtilities::OpenCLExpressionUtilities(OpenCLContext& context)
        : context(context),
          fp1(1),
          fp2(2),
          fp3(3) {
}

/**
 * Generate source code for a set of expressions whose variables are identified by name.
 *
 * This is a convenience overload: each entry of variables maps a variable name to the code that holds
 * its value.  The names are wrapped in Variable nodes and the work is delegated to the overload that
 * takes (node, code) pairs.
 *
 * @return the generated source code
 */
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    vector<pair<ExpressionTreeNode, string> > variableNodes;
    variableNodes.reserve(variables.size());
    map<string, string>::const_iterator entry = variables.begin();
    while (entry != variables.end()) {
        ExpressionTreeNode variableNode(new Operation::Variable(entry->first));
        variableNodes.push_back(make_pair(variableNode, entry->second));
        ++entry;
    }
    return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
}

/**
 * Generate source code that evaluates a set of expressions.
 *
 * For each entry of expressions (in map order), code computing the expression is generated, followed by
 * the entry's key, the name of the temporary holding the result, and ";".  The keys are therefore expected
 * to be statement prefixes such as an assignment target (presumably something like "value = "; confirm
 * against callers).  Temporaries are shared across all the expressions, so common subexpressions are
 * only computed once.
 *
 * @param variables  nodes whose values are already available, paired with the code that refers to them
 * @return the generated source code
 */
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    typedef map<string, ParsedExpression>::const_iterator ExpressionIterator;

    // processExpression() searches every expression for related nodes, so collect them all first.
    vector<ParsedExpression> allExpressions;
    for (ExpressionIterator entry = expressions.begin(); entry != expressions.end(); ++entry)
        allExpressions.push_back(entry->second);

    // The known variables seed the list of available temporaries.
    vector<pair<ExpressionTreeNode, string> > temps(variables);
    vector<vector<double> > functionParams = computeFunctionParameters(functions);
    stringstream source;
    for (ExpressionIterator entry = expressions.begin(); entry != expressions.end(); ++entry) {
        const ExpressionTreeNode& root = entry->second.getRootNode();
        processExpression(source, root, temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
        source << entry->first << getTempName(root, temps) << ";\n";
    }
    return source.str();
}

/**
 * Append OpenCL statements to out that evaluate node, declaring a new temporary variable for its value.
 *
 * Children are evaluated first (recursively).  Every node that gets a temporary is recorded in temps as a
 * (node, variable name) pair, and a node that is already in temps generates no code at all.  Calls to
 * tabulated functions and integer powers are handled specially.  Related nodes found anywhere in
 * allExpressions (other derivatives of the same function call, or other integer powers of the same base)
 * are evaluated in the same block of generated code and recorded in temps at the same time.
 *
 * @param out             stream receiving the generated source
 * @param node            the expression node to evaluate
 * @param temps           nodes that already have temporaries; new entries are appended
 * @param functions       tabulated functions that CUSTOM operations may refer to
 * @param functionNames   for each entry of functions, the name it has in expressions and the name of its
 *                        coefficient array in the generated code
 * @param prefix          prefix for the names of generated temporaries
 * @param functionParams  for each entry of functions, the values returned by computeFunctionParameters()
 * @param allExpressions  every expression being generated; searched for related nodes
 * @param tempType        type used to declare the temporaries
 * @throws OpenMMException if the expression contains an unknown variable, function or operation
 */
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
        const vector<ParsedExpression>& allExpressions, const string& tempType) {
    // Nothing to do if this node (or an identical one) already has a temporary.
    for (int i = 0; i < (int) temps.size(); i++)
        if (temps[i].first == node)
            return;
    for (int i = 0; i < (int) node.getChildren().size(); i++)
        processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
    string name = prefix+context.intToString(temps.size());
    bool hasRecordedNode = false;

    out << tempType << " " << name << " = ";
    switch (node.getOperation().getId()) {
        case Operation::CONSTANT:
            out << context.doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
            break;
        case Operation::VARIABLE:
            throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
        case Operation::CUSTOM:
        {
            int i;
            for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
                ;
            if (i == (int) functionNames.size())
                throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
            out << "0.0f;\n";
            temps.push_back(make_pair(node, name));
            hasRecordedNode = true;

            // If the value and derivatives of the function are all needed, it's faster to calculate them
            // together, so collect every call to this same function with identical arguments.  The node being
            // processed must come first, because nodeNames[0] is its temporary.  Each distinct call must appear
            // only once, because each one gets its own temporary and its own generated code.

            vector<const ExpressionTreeNode*> candidates;
            for (int j = 0; j < (int) allExpressions.size(); j++)
                findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), candidates);
            vector<const ExpressionTreeNode*> nodes;
            nodes.push_back(&node);
            for (int j = 0; j < (int) candidates.size(); j++) {
                const ExpressionTreeNode& candidate = *candidates[j];
                if (candidate.getOperation().getId() != Operation::CUSTOM || candidate.getOperation().getName() != node.getOperation().getName())
                    continue; // A call to a different function.
                bool sameArguments = (candidate.getChildren().size() == node.getChildren().size());
                for (int k = 0; sameArguments && k < (int) node.getChildren().size(); k++)
                    sameArguments = (candidate.getChildren()[k] == node.getChildren()[k]);
                if (!sameArguments)
                    continue; // Evaluated at different arguments.
                bool alreadyListed = false;
                for (int k = 0; !alreadyListed && k < (int) nodes.size(); k++)
                    alreadyListed = (*nodes[k] == candidate);
                if (!alreadyListed)
                    nodes.push_back(&candidate);
            }
            vector<string> nodeNames;
            nodeNames.push_back(name);
            for (int j = 1; j < (int) nodes.size(); j++) {
                string name2 = prefix+context.intToString(temps.size());
                out << tempType << " " << name2 << " = 0.0f;\n";
                nodeNames.push_back(name2);
                temps.push_back(make_pair(*nodes[j], name2));
            }
            out << "{\n";
            vector<string> paramsFloat, paramsInt;
            for (int j = 0; j < (int) functionParams[i].size(); j++) {
                paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
                paramsInt.push_back(context.intToString((int) functionParams[i][j]));
            }
            if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
                // paramsFloat = {min, max, scale}, paramsInt[3] = index of the last spline interval.
                out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
                out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
                out << "x = (x-" << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
                out << "int index = (int) (floor(x));\n";
                out << "index = min(index, " << paramsInt[3] << ");\n";
                out << "float4 coeff = " << functionNames[i].second << "[index];\n";
                out << "real b = x-index;\n";
                out << "real a = 1.0f-b;\n";

                // Each interval is a cubic.  coeff.z and coeff.w hold the spline's second derivatives at the
                // two ends of the interval divided by 6 (see computeFunctionCoefficients()), so derivatives
                // beyond the third are zero.
                for (int j = 0; j < (int) nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    out << nodeNames[j] << " = ";
                    if (derivOrder[0] == 0)
                        out << "a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ")";
                    else if (derivOrder[0] == 1)
                        out << "(coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2];
                    else if (derivOrder[0] == 2)
                        out << "6.0f*(a*coeff.z+b*coeff.w)";
                    else if (derivOrder[0] == 3)
                        out << "6.0f*(coeff.w-coeff.z)*" << paramsFloat[2];
                    else
                        out << "0.0f";
                    out << ";\n";
                }
                out << "}\n";
            }

            // For the discrete functions, only the value is computed.  Derivative temporaries keep the 0.0f they
            // were initialized with.  Arguments are rounded to the nearest table entry before being range
            // checked, so any argument that rounds outside the table leaves the value at 0.  Each lookup is
            // emitted in its own scope so its local declarations cannot collide.

            else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
                for (int j = 0; j < (int) nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0) {
                        out << "{\n";
                        out << "real x = round(" << getTempName(node.getChildren()[0], temps) << ");\n";
                        out << "if (x >= 0 && x < " << paramsInt[0] << ")\n";
                        out << nodeNames[j] << " = " << functionNames[i].second << "[(int) x];\n";
                        out << "}\n";
                    }
                }
            }
            else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
                for (int j = 0; j < (int) nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0 && derivOrder[1] == 0) {
                        out << "{\n";
                        out << "real x = round(" << getTempName(node.getChildren()[0], temps) << ");\n";
                        out << "real y = round(" << getTempName(node.getChildren()[1], temps) << ");\n";
                        out << "int xsize = " << paramsInt[0] << ";\n";
                        out << "int ysize = " << paramsInt[1] << ";\n";
                        // Check each axis separately: checking only the flattened index would let an
                        // out-of-range x wrap around into a neighboring row.
                        out << "if (x >= 0 && x < xsize && y >= 0 && y < ysize)\n";
                        out << nodeNames[j] << " = " << functionNames[i].second << "[(int) x+((int) y)*xsize];\n";
                        out << "}\n";
                    }
                }
            }
            else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
                for (int j = 0; j < (int) nodes.size(); j++) {
                    const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
                    if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
                        out << "{\n";
                        out << "real x = round(" << getTempName(node.getChildren()[0], temps) << ");\n";
                        out << "real y = round(" << getTempName(node.getChildren()[1], temps) << ");\n";
                        out << "real z = round(" << getTempName(node.getChildren()[2], temps) << ");\n";
                        out << "int xsize = " << paramsInt[0] << ";\n";
                        out << "int ysize = " << paramsInt[1] << ";\n";
                        out << "int zsize = " << paramsInt[2] << ";\n";
                        out << "if (x >= 0 && x < xsize && y >= 0 && y < ysize && z >= 0 && z < zsize)\n";
                        out << nodeNames[j] << " = " << functionNames[i].second << "[(int) x+((int) y+((int) z)*ysize)*xsize];\n";
                        out << "}\n";
                    }
                }
            }
            out << "}";
            break;
        }
        case Operation::ADD:
            out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::SUBTRACT:
            out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::MULTIPLY:
            out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
            break;
        case Operation::DIVIDE:
        {
            // If the reciprocal of the denominator is already available, multiply by it instead of dividing.
            // Stop at the first match so only one product is ever emitted.
            bool haveReciprocal = false;
            for (int i = 0; i < (int) temps.size() && !haveReciprocal; i++)
                if (temps[i].first.getOperation().getId() == Operation::RECIPROCAL && temps[i].first.getChildren()[0] == node.getChildren()[1]) {
                    haveReciprocal = true;
                    out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
                }
            if (!haveReciprocal)
                out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
            break;
        }
        case Operation::POWER:
            out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
            break;
        case Operation::NEGATE:
            out << "-" << getTempName(node.getChildren()[0], temps);
            break;
        case Operation::SQRT:
            out << "sqrt(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::EXP:
            out << "EXP(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::LOG:
            out << "LOG(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::SIN:
            out << "sin(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::COS:
            out << "cos(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::SEC:
            out << "1.0f/cos(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::CSC:
            out << "1.0f/sin(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::TAN:
            out << "tan(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::COT:
            out << "1.0f/tan(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ASIN:
            out << "asin(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ACOS:
            out << "acos(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ATAN:
            out << "atan(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::SINH:
            out << "sinh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::COSH:
            out << "cosh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::TANH:
            out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ERF:
            out << "erf(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ERFC:
            out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::STEP:
            out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f";
            break;
        case Operation::DELTA:
            out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? 1.0f : 0.0f";
            break;
        case Operation::SQUARE:
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg;
            break;
        }
        case Operation::CUBE:
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg << "*" << arg;
            break;
        }
        case Operation::RECIPROCAL:
            out << "RECIP(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ADD_CONSTANT:
            out << context.doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
            break;
        case Operation::MULTIPLY_CONSTANT:
            out << context.doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
            break;
        case Operation::POWER_CONSTANT:
        {
            double exponent = dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue();

            // The range check keeps the (int) conversions below well defined.  Huge or non-integral
            // exponents fall through to pow().
            bool isIntegral = (exponent == floor(exponent) && fabs(exponent) <= (double) numeric_limits<int>::max());
            if (exponent == 0.0)
                out << "1.0f";
            else if (isIntegral) {
                out << "0.0f;\n";
                temps.push_back(make_pair(node, name));
                hasRecordedNode = true;

                // If multiple integral powers of the same base are needed, it's faster to calculate all of them
                // at once, so check to see if others are also needed.

                map<int, const ExpressionTreeNode*> powers;
                powers[(int) exponent] = &node;
                for (int j = 0; j < (int) allExpressions.size(); j++)
                    findRelatedPowers(node, allExpressions[j].getRootNode(), powers);
                vector<int> exponents;
                vector<string> names;
                exponents.push_back((int) fabs(exponent));
                names.push_back(name);
                for (map<int, const ExpressionTreeNode*>::const_iterator iter = powers.begin(); iter != powers.end(); ++iter) {
                    if (iter->first == (int) exponent)
                        continue; // This node, already recorded above.

                    // A zero power would never be assigned by the repeated multiplication below, and a power of
                    // the opposite sign would be computed from the wrong multiplier.  Leave both to be evaluated
                    // on their own.
                    if (iter->first == 0 || (iter->first < 0) != (exponent < 0.0))
                        continue;
                    exponents.push_back(iter->first >= 0 ? iter->first : -iter->first);
                    string name2 = prefix+context.intToString(temps.size());
                    names.push_back(name2);
                    temps.push_back(make_pair(*iter->second, name2));
                    out << tempType << " " << name2 << " = 0.0f;\n";
                }
                vector<bool> hasAssigned(exponents.size(), false);
                out << "{\n";

                // Use the temporaries' type so double precision builds don't lose accuracy here.
                out << tempType << " multiplier = " << (exponent < 0.0 ? "1.0f/" : "") << getTempName(node.getChildren()[0], temps) << ";\n";

                // Exponentiation by squaring: multiplier holds base^(2^k) on pass k.  It is folded into each
                // power whose k-th bit is set.
                bool done = false;
                while (!done) {
                    done = true;
                    for (int k = 0; k < (int) exponents.size(); k++) {
                        if (exponents[k]%2 == 1) {
                            if (!hasAssigned[k])
                                out << names[k] << " = multiplier;\n";
                            else
                                out << names[k] << " *= multiplier;\n";
                            hasAssigned[k] = true;
                        }
                        exponents[k] >>= 1;
                        if (exponents[k] != 0)
                            done = false;
                    }
                    if (!done)
                        out << "multiplier *= multiplier;\n";
                }
                out << "}";
            }
            else
                out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << context.doubleToString(exponent) << ")";
            break;
        }
        case Operation::MIN:
            out << "min(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
            break;
        case Operation::MAX:
            out << "max(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
            break;
        case Operation::ABS:
            out << "fabs(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        default:
            throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
    }
    out << ";\n";
    if (!hasRecordedNode)
        temps.push_back(make_pair(node, name));
}

/**
 * Look up the name of the temporary variable that holds the value of a node.
 *
 * @param node   the node to look up (compared structurally with the recorded nodes)
 * @param temps  recorded (node, variable name) pairs; the first match wins
 * @return the variable name recorded for node
 * @throws OpenMMException if no temporary has been recorded for the node
 */
string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
    typedef vector<pair<ExpressionTreeNode, string> >::const_iterator TempIterator;
    for (TempIterator temp = temps.begin(); temp != temps.end(); ++temp) {
        if (temp->first == node)
            return temp->second;
    }
    stringstream message;
    message << "Internal error: No temporary variable for expression node: " << node;
    throw OpenMMException(message.str());
}

void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
363
364
365
            vector<const Lepton::ExpressionTreeNode*>& nodes) {
    if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0])
        nodes.push_back(&searchNode);
366
367
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
368
            findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes);
369
}
370
371
372

void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
    if (searchNode.getOperation().getId() == Operation::POWER_CONSTANT && node.getChildren()[0] == searchNode.getChildren()[0]) {
373
374
375
376
        double realPower = dynamic_cast<const Operation::PowerConstant*>(&searchNode.getOperation())->getValue();
        int power = (int) realPower;
        if (power != realPower)
            return; // We are only interested in integer powers.
377
378
379
380
381
382
383
384
385
386
        if (powers.find(power) != powers.end())
            return; // This power is already in the map.
        if (powers.begin()->first*power < 0)
            return; // All powers must have the same sign.
        powers[power] = &searchNode;
    }
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
            findRelatedPowers(node, searchNode.getChildren()[i], powers);
}
387

peastman's avatar
peastman committed
388
/**
 * Build the coefficient table that generated code reads for a tabulated function.
 *
 * Continuous1DFunction: a natural cubic spline is fitted through the values at equally spaced points from
 * min to max.  Four floats are stored per interval: the values at its two ends, then the spline's second
 * derivatives at the two ends divided by 6.  width is set to 4.
 * Discrete1D/2D/3DFunction: the tabulated values are stored as is, and width is set to 1.
 *
 * @param function  the function to tabulate
 * @param width     set to the number of floats per table entry
 * @return the table contents
 * @throws OpenMMException if the function type is not recognized
 */
vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
    if (const Continuous1DFunction* continuous = dynamic_cast<const Continuous1DFunction*>(&function)) {
        vector<double> values;
        double min, max;
        continuous->getFunctionParameters(values, min, max);
        const int numValues = values.size();
        vector<double> x(numValues), derivs;
        for (int i = 0; i < numValues; i++)
            x[i] = min+i*(max-min)/(numValues-1);
        SplineFitter::createNaturalSpline(x, values, derivs);
        vector<float> coefficients(4*(numValues-1));
        for (int i = 0; i+1 < numValues; i++) {
            float* interval = &coefficients[4*i];
            interval[0] = (float) values[i];
            interval[1] = (float) values[i+1];
            interval[2] = (float) (derivs[i]/6.0);
            interval[3] = (float) (derivs[i+1]/6.0);
        }
        width = 4;
        return coefficients;
    }

    // The discrete functions simply record their tabulated values, one float per entry.
    vector<double> values;
    if (const Discrete1DFunction* discrete1 = dynamic_cast<const Discrete1DFunction*>(&function))
        discrete1->getFunctionParameters(values);
    else if (const Discrete2DFunction* discrete2 = dynamic_cast<const Discrete2DFunction*>(&function)) {
        int xsize, ysize;
        discrete2->getFunctionParameters(xsize, ysize, values);
    }
    else if (const Discrete3DFunction* discrete3 = dynamic_cast<const Discrete3DFunction*>(&function)) {
        int xsize, ysize, zsize;
        discrete3->getFunctionParameters(xsize, ysize, zsize, values);
    }
    else
        throw OpenMMException("computeFunctionCoefficients: Unknown function type");
    width = 1;
    return vector<float>(values.begin(), values.end());
}

/**
 * Compute, for each tabulated function, the constants that processExpression() embeds in generated code.
 *
 * Continuous1DFunction: {min, max, (number of values - 1)/(max - min), number of values - 2}, i.e. the
 * range, the scale from x to table intervals, and the index of the last interval.
 * Discrete1DFunction: {number of values}.
 * Discrete2DFunction: {xsize, ysize}.
 * Discrete3DFunction: {xsize, ysize, zsize}.
 *
 * @return one parameter list per entry of functions, in the same order
 * @throws OpenMMException if a function type is not recognized
 */
vector<vector<double> > OpenCLExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
    vector<vector<double> > params(functions.size());
    for (int i = 0; i < (int) functions.size(); i++) {
        const TabulatedFunction* function = functions[i];
        vector<double>& p = params[i];
        if (const Continuous1DFunction* continuous = dynamic_cast<const Continuous1DFunction*>(function)) {
            vector<double> values;
            double min, max;
            continuous->getFunctionParameters(values, min, max);
            p.push_back(min);
            p.push_back(max);
            p.push_back((values.size()-1)/(max-min));
            p.push_back(values.size()-2);
        }
        else if (const Discrete1DFunction* discrete1 = dynamic_cast<const Discrete1DFunction*>(function)) {
            vector<double> values;
            discrete1->getFunctionParameters(values);
            p.push_back(values.size());
        }
        else if (const Discrete2DFunction* discrete2 = dynamic_cast<const Discrete2DFunction*>(function)) {
            int xsize, ysize;
            vector<double> values;
            discrete2->getFunctionParameters(xsize, ysize, values);
            p.push_back(xsize);
            p.push_back(ysize);
        }
        else if (const Discrete3DFunction* discrete3 = dynamic_cast<const Discrete3DFunction*>(function)) {
            int xsize, ysize, zsize;
            vector<double> values;
            discrete3->getFunctionParameters(xsize, ysize, zsize, values);
            p.push_back(xsize);
            p.push_back(ysize);
            p.push_back(zsize);
        }
        else
            throw OpenMMException("computeFunctionParameters: Unknown function type");
    }
    return params;
}

/**
 * Return the Lepton placeholder to register for a tabulated function, chosen by its dimensionality:
 * fp1 for Continuous1DFunction and Discrete1DFunction, fp2 for Discrete2DFunction, fp3 for
 * Discrete3DFunction.
 *
 * @throws OpenMMException if the function type is not recognized
 */
Lepton::CustomFunction* OpenCLExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
    const TabulatedFunction* f = &function;
    if (dynamic_cast<const Continuous1DFunction*>(f) != NULL || dynamic_cast<const Discrete1DFunction*>(f) != NULL)
        return &fp1;
    if (dynamic_cast<const Discrete2DFunction*>(f) != NULL)
        return &fp2;
    if (dynamic_cast<const Discrete3DFunction*>(f) != NULL)
        return &fp3;
    throw OpenMMException("getFunctionPlaceholder: Unknown function type");
}