OpenCLExpressionUtilities.cpp 18.4 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2011 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h"
29
#include "openmm/internal/SplineFitter.h"
30
31
32
33
34
35
#include "lepton/Operation.h"
#include <cmath>
#include <locale>
#include <sstream>

using namespace OpenMM;
using namespace Lepton;
using namespace std;

Peter Eastman's avatar
Peter Eastman committed
36
/**
 * Format a double as a single-precision OpenCL floating point literal.
 *
 * The value is written in scientific notation with 8 digits after the decimal point
 * (9 significant digits, enough to round-trip a 32-bit float) followed by an "f" suffix,
 * e.g. 1.5 becomes "1.50000000e+00f".
 *
 * The stream is imbued with the classic "C" locale so that the generated kernel source does not
 * depend on the process-wide C++ locale: a global locale using ',' as the decimal separator would
 * otherwise produce text that is not a valid OpenCL literal.
 *
 * NOTE(review): non-finite values are not special-cased and would produce text such as "inff".
 */
string OpenCLExpressionUtilities::doubleToString(double value) {
    stringstream s;
    s.imbue(std::locale::classic());
    s.precision(8);
    s << scientific << value << "f";
    return s.str();
}

Peter Eastman's avatar
Peter Eastman committed
43
/**
 * Format an integer in plain decimal notation.
 *
 * The result is used to build names of generated temporaries (a prefix followed by a counter),
 * so the classic "C" locale is imposed: a global locale with digit grouping could otherwise
 * insert separators such as "1,000" and corrupt the generated identifiers.
 */
string OpenCLExpressionUtilities::intToString(int value) {
    stringstream s;
    s.imbue(std::locale::classic());
    s << value;
    return s.str();
}

49
/**
 * Convenience overload in which the known variables are identified by name.
 *
 * Each entry of `variables` maps a variable name used in the expressions to the text that is
 * emitted into the generated code wherever that variable is referenced. Every name is wrapped in
 * a Variable expression node, and the work is delegated to the overload that takes
 * (node, text) pairs.
 */
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
        const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) {
    typedef map<string, string>::const_iterator VariableIter;
    vector<pair<ExpressionTreeNode, string> > knownNodes;
    knownNodes.reserve(variables.size());
    for (VariableIter var = variables.begin(); var != variables.end(); ++var) {
        const string& variableName = var->first;
        const string& replacementText = var->second;
        ExpressionTreeNode variableNode(new Operation::Variable(variableName));
        knownNodes.push_back(pair<ExpressionTreeNode, string>(variableNode, replacementText));
    }
    return createExpressions(expressions, knownNodes, functions, prefix, functionParams, tempType);
}

/**
 * Generate OpenCL source code that evaluates a set of expressions.
 *
 * Sub-expressions shared between the expressions are computed only once, because all expressions
 * share one list of already-computed temporaries.
 *
 * @param expressions     maps a piece of text to an expression. For each entry the generated code contains the
 *                        key, then the name of the temporary holding the expression's value, then ";".
 *                        NOTE(review): the key presumably ends in an assignment operator; confirm with callers.
 * @param variables       (expression node, text) pairs for values already available in the kernel.
 *                        Matching nodes are referenced by that text instead of being computed.
 * @param functions       (function name, coefficient array name) pairs for tabulated functions
 * @param prefix          prefix for the names of generated temporaries
 * @param functionParams  name of the array holding one float4 of parameters per tabulated function
 * @param tempType        type used to declare the generated temporaries
 * @return the generated source
 * @throws OpenMMException if an expression uses an unknown variable, function, or operation
 */
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
        const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) {
    stringstream out;
    // Numbers streamed directly into this buffer end up in kernel source, so they must not depend
    // on the process-wide locale (e.g. digit grouping).
    out.imbue(std::locale::classic());

    // Every expression is made visible to processExpression so it can find related nodes
    // (value/derivative pairs, powers of a common base) across all of them.
    vector<ParsedExpression> allExpressions;
    allExpressions.reserve(expressions.size());
    for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
        allExpressions.push_back(iter->second);

    // The temporaries start out as the caller-supplied variables; processExpression appends to this list.
    vector<pair<ExpressionTreeNode, string> > temps = variables;
    for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
        processExpression(out, iter->second.getRootNode(), temps, functions, prefix, functionParams, allExpressions, tempType);
        out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
    }
    return out.str();
}

71
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
72
        const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const vector<ParsedExpression>& allExpressions, const string& tempType) {
73
74
75
76
    for (int i = 0; i < (int) temps.size(); i++)
        if (temps[i].first == node)
            return;
    for (int i = 0; i < (int) node.getChildren().size(); i++)
77
        processExpression(out, node.getChildren()[i], temps, functions, prefix, functionParams, allExpressions, tempType);
78
    string name = prefix+intToString(temps.size());
79
    bool hasRecordedNode = false;
80
    
81
    out << tempType << " " << name << " = ";
82
83
    switch (node.getOperation().getId()) {
        case Operation::CONSTANT:
84
85
            out << doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
            break;
86
        case Operation::VARIABLE:
87
            throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
88
89
90
91
92
93
94
        case Operation::CUSTOM:
        {
            int i;
            for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++)
                ;
            if (i == functions.size())
                throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
95
            bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1);
96
            out << "0.0f;\n";
97
98
99
100
101
102
103
104
105
106
107
108
109
110
            temps.push_back(make_pair(node, name));
            hasRecordedNode = true;

            // If both the value and derivative of the function are needed, it's faster to calculate them both
            // at once, so check to see if both are needed.

            const ExpressionTreeNode* valueNode = NULL;
            const ExpressionTreeNode* derivNode = NULL;
            for (int j = 0; j < (int) allExpressions.size(); j++)
                findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), valueNode, derivNode);
            string valueName = name;
            string derivName = name;
            if (valueNode != NULL && derivNode != NULL) {
                string name2 = prefix+intToString(temps.size());
111
                out << tempType << " " << name2 << " = 0.0f;\n";
112
113
114
115
116
117
118
119
120
                if (isDeriv) {
                    valueName = name2;
                    temps.push_back(make_pair(*valueNode, name2));
                }
                else {
                    derivName = name2;
                    temps.push_back(make_pair(*derivNode, name2));
                }
            }
121
122
123
124
            out << "{\n";
            out << "float4 params = " << functionParams << "[" << i << "];\n";
            out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
            out << "if (x >= params.x && x <= params.y) {\n";
125
126
            out << "x = (x-params.x)*params.z;\n";
            out << "int index = (int) (floor(x));\n";
127
            out << "index = min(index, (int) params.w);\n";
128
            out << "float4 coeff = " << functions[i].second << "[index];\n";
129
130
            out << "float b = x-index;\n";
            out << "float a = 1.0f-b;\n";
131
            if (valueNode != NULL)
132
                out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
133
            if (derivNode != NULL)
134
                out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n";
135
136
137
            out << "}\n";
            out << "}";
            break;
138
139
        }
        case Operation::ADD:
140
141
            out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
            break;
142
        case Operation::SUBTRACT:
143
144
            out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
            break;
145
        case Operation::MULTIPLY:
146
147
            out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
            break;
148
        case Operation::DIVIDE:
149
150
151
152
153
154
155
156
157
        {
            bool haveReciprocal = false;
            for (int i = 0; i < (int) temps.size(); i++)
                if (temps[i].first.getOperation().getId() == Operation::RECIPROCAL && temps[i].first.getChildren()[0] == node.getChildren()[1]) {
                    haveReciprocal = true;
                    out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
                }
            if (!haveReciprocal)
                out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
158
            break;
159
        }
160
        case Operation::POWER:
161
162
            out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
            break;
163
        case Operation::NEGATE:
164
165
            out << "-" << getTempName(node.getChildren()[0], temps);
            break;
166
        case Operation::SQRT:
167
168
            out << "sqrt(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
169
        case Operation::EXP:
170
            out << "EXP(" << getTempName(node.getChildren()[0], temps) << ")";
171
            break;
172
        case Operation::LOG:
173
            out << "LOG(" << getTempName(node.getChildren()[0], temps) << ")";
174
            break;
175
        case Operation::SIN:
176
177
            out << "sin(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
178
        case Operation::COS:
179
180
            out << "cos(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
181
        case Operation::SEC:
182
183
            out << "1.0f/cos(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
184
        case Operation::CSC:
185
186
            out << "1.0f/sin(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
187
        case Operation::TAN:
188
189
            out << "tan(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
190
        case Operation::COT:
191
192
            out << "1.0f/tan(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
193
        case Operation::ASIN:
194
195
            out << "asin(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
196
        case Operation::ACOS:
197
198
            out << "acos(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
199
        case Operation::ATAN:
200
201
            out << "atan(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
202
203
204
205
206
207
208
209
210
        case Operation::SINH:
            out << "sinh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::COSH:
            out << "cosh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::TANH:
            out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
211
212
213
214
215
216
        case Operation::ERF:
            out << "erf(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::ERFC:
            out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
217
218
219
        case Operation::STEP:
            out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f";
            break;
220
        case Operation::SQUARE:
221
222
223
224
225
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg;
            break;
        }
226
        case Operation::CUBE:
227
228
229
230
231
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg << "*" << arg;
            break;
        }
232
        case Operation::RECIPROCAL:
233
            out << "RECIP(" << getTempName(node.getChildren()[0], temps) << ")";
234
            break;
235
        case Operation::ADD_CONSTANT:
236
237
            out << doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
            break;
238
        case Operation::MULTIPLY_CONSTANT:
239
240
            out << doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
            break;
241
        case Operation::POWER_CONSTANT:
242
243
244
245
246
247
        {
            double exponent = dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue();
            if (exponent == 0.0)
                out << "1.0f";
            else if (exponent == (int) exponent) {
                out << "0.0f;\n";
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
                temps.push_back(make_pair(node, name));
                hasRecordedNode = true;

                // If multiple integral powers of the same base are needed, it's faster to calculate all of them
                // at once, so check to see if others are also needed.

                map<int, const ExpressionTreeNode*> powers;
                powers[(int) exponent] = &node;
                for (int j = 0; j < (int) allExpressions.size(); j++)
                    findRelatedPowers(node, allExpressions[j].getRootNode(), powers);
                vector<int> exponents;
                vector<string> names;
                vector<bool> hasAssigned(powers.size(), false);
                exponents.push_back((int) fabs(exponent));
                names.push_back(name);
                for (map<int, const ExpressionTreeNode*>::const_iterator iter = powers.begin(); iter != powers.end(); ++iter) {
                    if (iter->first != exponent) {
265
                        exponents.push_back(iter->first >= 0 ? iter->first : -iter->first);
266
267
268
                        string name2 = prefix+intToString(temps.size());
                        names.push_back(name2);
                        temps.push_back(make_pair(*iter->second, name2));
269
                        out << tempType << " " << name2 << " = 0.0f;\n";
270
271
                    }
                }
272
273
                out << "{\n";
                out << "float multiplier = " << (exponent < 0.0 ? "1.0f/" : "") << getTempName(node.getChildren()[0], temps) << ";\n";
274
275
276
                bool done = false;
                while (!done) {
                    done = true;
277
                    for (int i = 0; i < (int) exponents.size(); i++) {
278
279
280
281
282
283
284
285
286
287
                        if (exponents[i]%2 == 1) {
                            if (!hasAssigned[i])
                                out << names[i] << " = multiplier;\n";
                            else
                                out << names[i] << " *= multiplier;\n";
                            hasAssigned[i] = true;
                        }
                        exponents[i] >>= 1;
                        if (exponents[i] != 0)
                            done = false;
288
                    }
289
                    if (!done)
290
291
292
293
294
295
                        out << "multiplier *= multiplier;\n";
                }
                out << "}";
            }
            else
                out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << doubleToString(exponent) << ")";
296
            break;
297
        }
298
299
300
301
302
303
304
305
306
        case Operation::MIN:
            out << "min(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
            break;
        case Operation::MAX:
            out << "max(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
            break;
        case Operation::ABS:
            out << "fabs(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
307
308
        default:
            throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
309
    }
310
    out << ";\n";
311
312
    if (!hasRecordedNode)
        temps.push_back(make_pair(node, name));
313
314
315
316
317
318
319
320
321
}

/**
 * Look up the name (or caller-supplied text) holding the value of `node`.
 *
 * The first entry of `temps` whose node compares equal to `node` wins.
 *
 * @throws OpenMMException if `node` is not recorded in `temps` (reported as an internal error)
 */
string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
    typedef vector<pair<ExpressionTreeNode, string> >::const_iterator TempIter;
    for (TempIter entry = temps.begin(); entry != temps.end(); ++entry) {
        if (entry->first == node)
            return entry->second;
    }
    stringstream message;
    message << "Internal error: No temporary variable for expression node: " << node;
    throw OpenMMException(message.str());
}
323
324
325
326
327
328
329
330
331
332
333
334
335

void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
            const ExpressionTreeNode*& valueNode, const ExpressionTreeNode*& derivNode) {
    if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0]) {
        if (dynamic_cast<const Operation::Custom*>(&searchNode.getOperation())->getDerivOrder()[0] == 0)
            valueNode = &searchNode;
        else
            derivNode = &searchNode;
    }
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
            findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], valueNode, derivNode);
}
336
337
338

void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
    if (searchNode.getOperation().getId() == Operation::POWER_CONSTANT && node.getChildren()[0] == searchNode.getChildren()[0]) {
339
340
341
342
        double realPower = dynamic_cast<const Operation::PowerConstant*>(&searchNode.getOperation())->getValue();
        int power = (int) realPower;
        if (power != realPower)
            return; // We are only interested in integer powers.
343
344
345
346
347
348
349
350
351
352
        if (powers.find(power) != powers.end())
            return; // This power is already in the map.
        if (powers.begin()->first*power < 0)
            return; // All powers must have the same sign.
        powers[power] = &searchNode;
    }
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
            findRelatedPowers(node, searchNode.getChildren()[i], powers);
}
353

354
355
/**
 * Compute natural cubic spline coefficients for a function tabulated at evenly spaced points.
 *
 * @param values  function values at values.size() points spaced uniformly from min to max (inclusive)
 * @param min     x coordinate of the first point
 * @param max     x coordinate of the last point
 * @return one float4 per interval i: (values[i], values[i+1], derivs[i]/6, derivs[i+1]/6), where derivs
 *         is the output of SplineFitter::createNaturalSpline. These are the spline's second
 *         derivatives, as the evaluation formula generated for tabulated functions expects.
 * @throws OpenMMException if fewer than two values are given
 */
vector<mm_float4> OpenCLExpressionUtilities::computeFunctionCoefficients(const vector<double>& values, double min, double max) {
    int numValues = values.size();
    // At least one interval is required. With fewer points the spacing below divides by zero, and
    // numValues-1 would be passed to the vector constructor as a negative (wrapped) size.
    if (numValues < 2)
        throw OpenMMException("computeFunctionCoefficients: at least two tabulated values are required");

    // Compute the spline coefficients.
    vector<double> x(numValues), derivs;
    for (int i = 0; i < numValues; i++)
        x[i] = min+i*(max-min)/(numValues-1);
    SplineFitter::createNaturalSpline(x, values, derivs);
    vector<mm_float4> f(numValues-1);
    for (int i = 0; i < numValues-1; i++)
        f[i] = mm_float4((cl_float) values[i], (cl_float) values[i+1], (cl_float) (derivs[i]/6.0), (cl_float) (derivs[i+1]/6.0));
    return f;
}