OpenCLExpressionUtilities.cpp 11 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009 Stanford University and the Authors.           *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h"
#include "lepton/Operation.h"
#include <cmath>
#include <locale>
#include <sstream>

using namespace OpenMM;
using namespace Lepton;
using namespace std;

/**
 * Format a value as an OpenCL single precision literal, e.g. 1.5 -> "1.50000000e+00f".
 *
 * Scientific notation with 8 digits after the point gives 9 significant digits, which is
 * enough to represent any float exactly.
 *
 * The classic "C" locale is imbued explicitly. Otherwise a global locale installed by the
 * host program (e.g. one using ',' as the decimal point) would leak into the generated
 * kernel source and produce an invalid literal.
 */
static string doubleToString(double value) {
    stringstream s;
    s.imbue(std::locale::classic());
    s.precision(8);
    s << scientific << value << "f";
    return s.str();
}

/**
 * Format an integer in plain decimal, for use in generated identifiers and array indices.
 *
 * The classic "C" locale is imbued explicitly so that a global locale with digit grouping
 * (e.g. "1,234") cannot corrupt the generated kernel source.
 */
static string intToString(int value) {
    stringstream s;
    s.imbue(std::locale::classic());
    s << value;
    return s.str();
}

48
49
50
51
52
53
54
55
56
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
        const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
    stringstream out;
    vector<pair<ExpressionTreeNode, string> > temps;
    for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
        processExpression(out, iter->second.getRootNode(), temps, variables, functions, prefix, functionParams);
        out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
    }
    return out.str();
57
58
}

59
60
61
62
63
64
65
66
67
68
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
        const map<string, string>& variables, const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
    for (int i = 0; i < (int) temps.size(); i++)
        if (temps[i].first == node)
            return;
    for (int i = 0; i < (int) node.getChildren().size(); i++)
        processExpression(out, node.getChildren()[i], temps, variables, functions, prefix, functionParams);
    string name = prefix+intToString(temps.size());
    
    out << "float " << name << " = ";
69
70
    switch (node.getOperation().getId()) {
        case Operation::CONSTANT:
71
72
            out << doubleToString(dynamic_cast<const Operation::Constant*>(&node.getOperation())->getValue());
            break;
73
74
75
76
77
        case Operation::VARIABLE:
        {
            map<string, string>::const_iterator iter = variables.find(node.getOperation().getName());
            if (iter == variables.end())
                throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
            out << iter->second;
            break;
        }
        case Operation::CUSTOM:
        {
            int i;
            for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++)
                ;
            if (i == functions.size())
                throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
            out << "0.0f;\n";
            out << "{\n";
            out << "float4 params = " << functionParams << "[" << i << "];\n";
            out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
            out << "if (x >= params.x && x <= params.y) {\n";
            out << "int index = (int) (floor((x-params.x)*params.z));\n";
            out << "float4 coeff = " << functions[i].second << "[index];\n";
            out << "x = (x-params.x)*params.z-index;\n";
            if (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 0)
                out << name << " = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));\n";
            else
                out << name << " = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;\n";
            out << "}\n";
            out << "}";
            break;
103
104
        }
        case Operation::ADD:
105
106
            out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
            break;
107
        case Operation::SUBTRACT:
108
109
            out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
            break;
110
        case Operation::MULTIPLY:
111
112
            out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
            break;
113
        case Operation::DIVIDE:
114
115
            out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
            break;
116
        case Operation::POWER:
117
118
            out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << getTempName(node.getChildren()[1], temps) << ")";
            break;
119
        case Operation::NEGATE:
120
121
            out << "-" << getTempName(node.getChildren()[0], temps);
            break;
122
        case Operation::SQRT:
123
124
            out << "sqrt(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
125
        case Operation::EXP:
126
127
            out << "exp(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
128
        case Operation::LOG:
129
130
            out << "log(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
131
        case Operation::SIN:
132
133
            out << "sin(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
134
        case Operation::COS:
135
136
            out << "cos(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
137
        case Operation::SEC:
138
139
            out << "1.0f/cos(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
140
        case Operation::CSC:
141
142
            out << "1.0f/sin(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
143
        case Operation::TAN:
144
145
            out << "tan(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
146
        case Operation::COT:
147
148
            out << "1.0f/tan(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
149
        case Operation::ASIN:
150
151
            out << "asin(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
152
        case Operation::ACOS:
153
154
            out << "acos(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
155
        case Operation::ATAN:
156
157
            out << "atan(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
158
159
160
161
162
163
164
165
166
        case Operation::SINH:
            out << "sinh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::COSH:
            out << "cosh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
        case Operation::TANH:
            out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")";
            break;
167
        case Operation::SQUARE:
168
169
170
171
172
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg;
            break;
        }
173
        case Operation::CUBE:
174
175
176
177
178
        {
            string arg = getTempName(node.getChildren()[0], temps);
            out << arg << "*" << arg << "*" << arg;
            break;
        }
179
        case Operation::RECIPROCAL:
180
181
            out << "1.0f/" << getTempName(node.getChildren()[0], temps);
            break;
182
        case Operation::ADD_CONSTANT:
183
184
            out << doubleToString(dynamic_cast<const Operation::AddConstant*>(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
            break;
185
        case Operation::MULTIPLY_CONSTANT:
186
187
            out << doubleToString(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
            break;
188
        case Operation::POWER_CONSTANT:
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
        {
            double exponent = dynamic_cast<const Operation::PowerConstant*>(&node.getOperation())->getValue();
            if (exponent == 0.0)
                out << "1.0f";
            else if (exponent == (int) exponent) {
                out << "0.0f;\n";
                out << "{\n";
                out << "float multiplier = " << (exponent < 0.0 ? "1.0f/" : "") << getTempName(node.getChildren()[0], temps) << ";\n";
                int exp = (int) fabs(exponent);
                bool hasAssigned = false;
                while (exp != 0) {
                    if (exp%2 == 1) {
                        if (!hasAssigned)
                            out << name << " = multiplier;\n";
                        else
                            out << name << " *= multiplier;\n";
                        hasAssigned = true;
                    }
                    exp >>= 1;
                    if (exp != 0)
                        out << "multiplier *= multiplier;\n";
                }
                out << "}";
            }
            else
                out << "pow(" << getTempName(node.getChildren()[0], temps) << ", " << doubleToString(exponent) << ")";
215
            break;
216
        }
217
218
        default:
            throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
219
    }
220
221
222
223
224
225
226
227
228
229
230
    out << ";\n";
    temps.push_back(make_pair(node, name));
}

string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
    for (int i = 0; i < (int) temps.size(); i++)
        if (temps[i].first == node)
            return temps[i].second;
    stringstream out;
    out << "Internal error: No temporary variable for expression node: " << node;
    throw OpenMMException(out.str());
231
}