/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include "lepton/Operation.h"
using namespace OpenMM;
using namespace Lepton;
using namespace std;
/**
 * Create an expression code generator bound to an OpenCL context.
 *
 * fp1/fp2/fp3 are the placeholders handed out by getFunctionPlaceholder() for 1-, 2- and 3-argument
 * tabulated functions, and periodicDistance is the placeholder for periodicdistance(), whose generated
 * code reads six arguments.  The integer passed to each is presumably its argument count (TODO confirm
 * against the placeholder class declaration).
 */
OpenCLExpressionUtilities::OpenCLExpressionUtilities(OpenCLContext& context)
        : context(context),
          fp1(1),
          fp2(2),
          fp3(3),
          periodicDistance(6) {
}
/**
 * Generate code for a set of expressions whose variables are identified by name.
 *
 * Each entry of "variables" maps a variable name to the code that should be used wherever that
 * variable appears.  The names are wrapped in Variable nodes and the work is forwarded to the
 * overload that takes (node, code) pairs.
 */
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    vector<pair<ExpressionTreeNode, string> > variableNodes;
    for (map<string, string>::const_iterator var = variables.begin(); var != variables.end(); ++var) {
        ExpressionTreeNode variableNode(new Operation::Variable(var->first));
        variableNodes.push_back(make_pair(variableNode, var->second));
    }
    return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
}
/**
 * Generate code that evaluates every expression in "expressions".
 *
 * "variables" seeds the table of already-available nodes, so each variable node resolves to the
 * code supplied for it.  Expressions are processed in map order.  After each one, its key is written
 * verbatim followed by the temporary holding its value and ";\n" (keys are therefore typically
 * assignment prefixes supplied by the caller -- TODO confirm).
 *
 * @return the generated source code
 */
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
        const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
    typedef map<string, ParsedExpression>::const_iterator ExpressionIterator;

    // processExpression() searches all expressions for related nodes, so collect them up front.
    vector<ParsedExpression> allExpressions;
    for (ExpressionIterator it = expressions.begin(); it != expressions.end(); ++it)
        allExpressions.push_back(it->second);

    vector<pair<ExpressionTreeNode, string> > temps(variables);
    const vector<vector<double> > functionParams = computeFunctionParameters(functions);
    stringstream code;
    for (ExpressionIterator it = expressions.begin(); it != expressions.end(); ++it) {
        const ExpressionTreeNode& root = it->second.getRootNode();
        processExpression(code, root, temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
        code << it->first << getTempName(root, temps) << ";\n";
    }
    return code.str();
}
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector >& temps,
const vector& functions, const vector >& functionNames, const string& prefix, const vector >& functionParams,
const vector& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return;
for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
string name = prefix+context.intToString(temps.size());
bool hasRecordedNode = false;
out << tempType << " " << name << " = ";
switch (node.getOperation().getId()) {
case Operation::CONSTANT:
out << context.doubleToString(dynamic_cast(&node.getOperation())->getValue());
break;
case Operation::VARIABLE:
throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
case Operation::CUSTOM:
{
out << "0.0f;\n";
temps.push_back(make_pair(node, name));
hasRecordedNode = true;
// If both the value and derivative of the function are needed, it's faster to calculate them both
// at once, so check to see if both are needed.
vector nodes;
for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedCustomFunctions(node, allExpressions[j].getRootNode(), nodes);
vector nodeNames;
nodeNames.push_back(name);
for (int j = 1; j < (int) nodes.size(); j++) {
string name2 = prefix+context.intToString(temps.size());
out << tempType << " " << name2 << " = 0.0f;\n";
nodeNames.push_back(name2);
temps.push_back(make_pair(*nodes[j], name2));
}
out << "{\n";
if (node.getOperation().getName() == "periodicdistance") {
// This is the periodicdistance() function.
out << tempType << "3 periodicDistance_delta = (real3) (";
for (int i = 0; i < 3; i++) {
if (i > 0)
out << ", ";
out << getTempName(node.getChildren()[i], temps) << "-" << getTempName(node.getChildren()[i+3], temps);
}
out << ");\n";
out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n";
out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n";
for (int j = 0; j < nodes.size(); j++) {
const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder();
int argIndex = -1;
for (int k = 0; k < 6; k++) {
if (derivOrder[k] > 0) {
if (derivOrder[k] > 1 || argIndex != -1)
throw OpenMMException("Unsupported derivative of periodicdistance"); // Should be impossible for this to happen.
argIndex = k;
}
}
if (argIndex == -1)
out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
else if (argIndex == 0)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
else if (argIndex == 1)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
else if (argIndex == 2)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
else if (argIndex == 3)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n";
else if (argIndex == 4)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n";
else if (argIndex == 5)
out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n";
}
}
else {
// This is a tabulated function.
int i;
for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
;
if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
vector paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
paramsInt.push_back(context.intToString((int) functionParams[i][j]));
}
if (dynamic_cast(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0)
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else
out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
}
out << "}\n";
}
else if (dynamic_cast(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
out << "float4 c[4];\n";
for (int j = 0; j < 4; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0) {
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
out << nodeNames[j] << " *= " << paramsFloat[6] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
out << nodeNames[j] << " *= " << paramsFloat[7] << ";\n";
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
}
out << "}\n";
}
else if (dynamic_cast(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "real z = " << getTempName(node.getChildren()[2], temps) << ";\n";
out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << ");\n";
out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
out << "float4 c[16];\n";
for (int j = 0; j < 16; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
out << "real dc = z-u;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real value[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real derivx[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
}
out << nodeNames[j] << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[9] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) {
const string suffixes[] = {".x", ".y", ".z", ".w"};
out << "real derivy[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = 4*m;
string suffix = suffixes[m];
out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n";
}
out << nodeNames[j] << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[10] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) {
out << "real derivz[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n";
out << nodeNames[j] << " *= " << paramsFloat[11] << ";\n";
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
}
out << "}\n";
}
else if (dynamic_cast(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
out << "int index = (int) floor(x+0.5f);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
}
}
}
else if (dynamic_cast(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n";
out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n";
out << "int xsize = " << paramsInt[0] << ";\n";
out << "int ysize = " << paramsInt[1] << ";\n";
out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
}
}
else if (dynamic_cast(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n";
out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n";
out << "int z = (int) floor(" << getTempName(node.getChildren()[2], temps) << "+0.5f);\n";
out << "int xsize = " << paramsInt[0] << ";\n";
out << "int ysize = " << paramsInt[1] << ";\n";
out << "int zsize = " << paramsInt[2] << ";\n";
out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
}
}
}
out << "}";
break;
}
case Operation::ADD:
out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps);
break;
case Operation::SUBTRACT:
out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps);
break;
case Operation::MULTIPLY:
out << getTempName(node.getChildren()[0], temps) << "*" << getTempName(node.getChildren()[1], temps);
break;
case Operation::DIVIDE:
{
bool haveReciprocal = false;
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first.getOperation().getId() == Operation::RECIPROCAL && temps[i].first.getChildren()[0] == node.getChildren()[1]) {
haveReciprocal = true;
out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second;
}
if (!haveReciprocal)
out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps);
break;
}
case Operation::POWER:
out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::NEGATE:
out << "-" << getTempName(node.getChildren()[0], temps);
break;
case Operation::SQRT:
out << "sqrt(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::EXP:
out << "EXP(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::LOG:
out << "LOG(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SIN:
out << "sin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COS:
out << "cos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SEC:
out << "1.0f/cos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::CSC:
out << "1.0f/sin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::TAN:
out << "tan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COT:
out << "1.0f/tan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ASIN:
out << "asin(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ACOS:
out << "acos(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ATAN:
out << "atan(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SINH:
out << "sinh(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::COSH:
out << "cosh(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::TANH:
out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ERF:
out << "erf(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ERFC:
out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::STEP:
out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f";
break;
case Operation::DELTA:
out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? 1.0f : 0.0f";
break;
case Operation::SQUARE:
{
string arg = getTempName(node.getChildren()[0], temps);
out << arg << "*" << arg;
break;
}
case Operation::CUBE:
{
string arg = getTempName(node.getChildren()[0], temps);
out << arg << "*" << arg << "*" << arg;
break;
}
case Operation::RECIPROCAL:
out << "RECIP(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::ADD_CONSTANT:
out << context.doubleToString(dynamic_cast(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps);
break;
case Operation::MULTIPLY_CONSTANT:
out << context.doubleToString(dynamic_cast(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps);
break;
case Operation::POWER_CONSTANT:
{
double exponent = dynamic_cast(&node.getOperation())->getValue();
if (exponent == 0.0)
out << "1.0f";
else if (exponent == (int) exponent) {
out << "0.0f;\n";
temps.push_back(make_pair(node, name));
hasRecordedNode = true;
// If multiple integral powers of the same base are needed, it's faster to calculate all of them
// at once, so check to see if others are also needed.
map powers;
powers[(int) exponent] = &node;
for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedPowers(node, allExpressions[j].getRootNode(), powers);
vector exponents;
vector names;
vector hasAssigned(powers.size(), false);
exponents.push_back((int) fabs(exponent));
names.push_back(name);
for (map::const_iterator iter = powers.begin(); iter != powers.end(); ++iter) {
if (iter->first != exponent) {
exponents.push_back(iter->first >= 0 ? iter->first : -iter->first);
string name2 = prefix+context.intToString(temps.size());
names.push_back(name2);
temps.push_back(make_pair(*iter->second, name2));
out << tempType << " " << name2 << " = 0.0f;\n";
}
}
out << "{\n";
out << "float multiplier = " << (exponent < 0.0 ? "1.0f/" : "") << getTempName(node.getChildren()[0], temps) << ";\n";
bool done = false;
while (!done) {
done = true;
for (int i = 0; i < (int) exponents.size(); i++) {
if (exponents[i]%2 == 1) {
if (!hasAssigned[i])
out << names[i] << " = multiplier;\n";
else
out << names[i] << " *= multiplier;\n";
hasAssigned[i] = true;
}
exponents[i] >>= 1;
if (exponents[i] != 0)
done = false;
}
if (!done)
out << "multiplier *= multiplier;\n";
}
out << "}";
}
else
out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << context.doubleToString(exponent) << ")";
break;
}
case Operation::MIN:
out << "min((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::MAX:
out << "max((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")";
break;
case Operation::ABS:
out << "fabs(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::FLOOR:
out << "floor(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::CEIL:
out << "ceil(" << getTempName(node.getChildren()[0], temps) << ")";
break;
case Operation::SELECT:
out << "(" << getTempName(node.getChildren()[0], temps) << " != 0 ? " << getTempName(node.getChildren()[1], temps) << " : " << getTempName(node.getChildren()[2], temps) << ")";
break;
default:
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
}
out << ";\n";
if (!hasRecordedNode)
temps.push_back(make_pair(node, name));
}
/**
 * Look up the name of the temporary variable holding the value of "node".
 *
 * Returns the name from the earliest entry in "temps" whose node compares equal to "node".
 * @throws OpenMMException if no such entry exists (an internal code-generation error)
 */
string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector<pair<ExpressionTreeNode, string> >& temps) {
    typedef vector<pair<ExpressionTreeNode, string> >::const_iterator TempIterator;
    for (TempIterator entry = temps.begin(); entry != temps.end(); ++entry) {
        if (entry->first == node)
            return entry->second;
    }
    stringstream message;
    message << "Internal error: No temporary variable for expression node: " << node;
    throw OpenMMException(message.str());
}
void OpenCLExpressionUtilities::findRelatedCustomFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
vector& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
// Make sure the arguments are identical.
for (int i = 0; i < (int) node.getChildren().size(); i++)
if (node.getChildren()[i] != searchNode.getChildren()[i])
return;
// See if we already have an identical node.
for (int i = 0; i < (int) nodes.size(); i++)
if (*nodes[i] == searchNode)
return;
// Add the node.
nodes.push_back(&searchNode);
}
else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedCustomFunctions(node, searchNode.getChildren()[i], nodes);
}
void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map& powers) {
if (searchNode.getOperation().getId() == Operation::POWER_CONSTANT && node.getChildren()[0] == searchNode.getChildren()[0]) {
double realPower = dynamic_cast(&searchNode.getOperation())->getValue();
int power = (int) realPower;
if (power != realPower)
return; // We are only interested in integer powers.
if (powers.find(power) != powers.end())
return; // This power is already in the map.
if (powers.begin()->first*power < 0)
return; // All powers must have the same sign.
powers[power] = &searchNode;
}
else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedPowers(node, searchNode.getChildren()[i], powers);
}
/**
 * Return "size" evenly spaced sample points: minValue + i*(maxValue-minValue)/(size-1) for i = 0..size-1.
 */
static vector<double> createUniformGridPoints(double minValue, double maxValue, int size) {
    vector<double> points(size);
    for (int i = 0; i < size; i++)
        points[i] = minValue+i*(maxValue-minValue)/(size-1);
    return points;
}

/**
 * Flatten per-cell spline coefficients into a single float array holding "perCell" values per cell.
 */
static vector<float> flattenSplineCoefficients(const vector<vector<double> >& cells, int perCell) {
    vector<float> flat(perCell*cells.size());
    for (int cell = 0; cell < (int) cells.size(); cell++)
        for (int k = 0; k < perCell; k++)
            flat[perCell*cell+k] = (float) cells[cell][k];
    return flat;
}

/**
 * Build the table of values that the generated code reads for a tabulated function.
 *
 * Continuous functions are converted to natural spline coefficients on a uniform grid: four floats per
 * interval in 1D (the two end values and the two entries of "derivs" divided by 6), and 16 or 64 floats
 * per cell in 2D/3D as produced by SplineFitter.  Discrete functions are copied as floats.
 *
 * @param function the tabulated function to convert
 * @param width    set to the number of floats per table element (4 for continuous, 1 for discrete)
 * @return the flattened coefficient table
 * @throws OpenMMException if the function type is not recognized
 */
vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
    if (const Continuous1DFunction* fn = dynamic_cast<const Continuous1DFunction*>(&function)) {
        vector<double> values;
        double min, max;
        fn->getFunctionParameters(values, min, max);
        const int numValues = values.size();
        vector<double> x = createUniformGridPoints(min, max, numValues);
        vector<double> derivs;
        SplineFitter::createNaturalSpline(x, values, derivs);
        vector<float> f(4*(numValues-1));
        for (int interval = 0; interval < numValues-1; interval++) {
            const int base = 4*interval;
            f[base] = (float) values[interval];
            f[base+1] = (float) values[interval+1];
            f[base+2] = (float) (derivs[interval]/6.0);
            f[base+3] = (float) (derivs[interval+1]/6.0);
        }
        width = 4;
        return f;
    }
    if (const Continuous2DFunction* fn = dynamic_cast<const Continuous2DFunction*>(&function)) {
        vector<double> values;
        int xsize, ysize;
        double xmin, xmax, ymin, ymax;
        fn->getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
        vector<double> x = createUniformGridPoints(xmin, xmax, xsize);
        vector<double> y = createUniformGridPoints(ymin, ymax, ysize);
        vector<vector<double> > c;
        SplineFitter::create2DNaturalSpline(x, y, values, c);
        vector<float> f = flattenSplineCoefficients(c, 16);
        width = 4;
        return f;
    }
    if (const Continuous3DFunction* fn = dynamic_cast<const Continuous3DFunction*>(&function)) {
        vector<double> values;
        int xsize, ysize, zsize;
        double xmin, xmax, ymin, ymax, zmin, zmax;
        fn->getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
        vector<double> x = createUniformGridPoints(xmin, xmax, xsize);
        vector<double> y = createUniformGridPoints(ymin, ymax, ysize);
        vector<double> z = createUniformGridPoints(zmin, zmax, zsize);
        vector<vector<double> > c;
        SplineFitter::create3DNaturalSpline(x, y, z, values, c);
        vector<float> f = flattenSplineCoefficients(c, 64);
        width = 4;
        return f;
    }
    if (const Discrete1DFunction* fn = dynamic_cast<const Discrete1DFunction*>(&function)) {
        vector<double> values;
        fn->getFunctionParameters(values);
        vector<float> f(values.begin(), values.end());
        width = 1;
        return f;
    }
    if (const Discrete2DFunction* fn = dynamic_cast<const Discrete2DFunction*>(&function)) {
        int xsize, ysize;
        vector<double> values;
        fn->getFunctionParameters(xsize, ysize, values);
        vector<float> f(values.begin(), values.end());
        width = 1;
        return f;
    }
    if (const Discrete3DFunction* fn = dynamic_cast<const Discrete3DFunction*>(&function)) {
        int xsize, ysize, zsize;
        vector<double> values;
        fn->getFunctionParameters(xsize, ysize, zsize, values);
        vector<float> f(values.begin(), values.end());
        width = 1;
        return f;
    }
    throw OpenMMException("computeFunctionCoefficients: Unknown function type");
}
/**
 * Compute, for each tabulated function, the constants the generated code embeds when evaluating it.
 *
 * Layout of each entry:
 *  - Continuous1D: [min, max, inverse grid spacing, values.size()-2 (index of the last spline interval)]
 *  - Continuous2D: [xsize-1, ysize-1, xmin, xmax, ymin, ymax, x inverse spacing, y inverse spacing]
 *  - Continuous3D: [xsize-1, ysize-1, zsize-1, xmin, xmax, ymin, ymax, zmin, zmax, x/y/z inverse spacings]
 *  - Discrete1D:   [number of values]
 *  - Discrete2D:   [xsize, ysize]
 *  - Discrete3D:   [xsize, ysize, zsize]
 *
 * @throws OpenMMException if a function type is not recognized
 */
vector<vector<double> > OpenCLExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
    vector<vector<double> > params(functions.size());
    for (int i = 0; i < (int) functions.size(); i++) {
        const TabulatedFunction* function = functions[i];
        vector<double>& p = params[i];
        if (const Continuous1DFunction* fn = dynamic_cast<const Continuous1DFunction*>(function)) {
            vector<double> values;
            double min, max;
            fn->getFunctionParameters(values, min, max);
            p.push_back(min);
            p.push_back(max);
            p.push_back((values.size()-1)/(max-min));
            p.push_back(values.size()-2);
        }
        else if (const Continuous2DFunction* fn = dynamic_cast<const Continuous2DFunction*>(function)) {
            vector<double> values;
            int xsize, ysize;
            double xmin, xmax, ymin, ymax;
            fn->getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
            p.push_back(xsize-1);
            p.push_back(ysize-1);
            p.push_back(xmin);
            p.push_back(xmax);
            p.push_back(ymin);
            p.push_back(ymax);
            p.push_back((xsize-1)/(xmax-xmin));
            p.push_back((ysize-1)/(ymax-ymin));
        }
        else if (const Continuous3DFunction* fn = dynamic_cast<const Continuous3DFunction*>(function)) {
            vector<double> values;
            int xsize, ysize, zsize;
            double xmin, xmax, ymin, ymax, zmin, zmax;
            fn->getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
            p.push_back(xsize-1);
            p.push_back(ysize-1);
            p.push_back(zsize-1);
            p.push_back(xmin);
            p.push_back(xmax);
            p.push_back(ymin);
            p.push_back(ymax);
            p.push_back(zmin);
            p.push_back(zmax);
            p.push_back((xsize-1)/(xmax-xmin));
            p.push_back((ysize-1)/(ymax-ymin));
            p.push_back((zsize-1)/(zmax-zmin));
        }
        else if (const Discrete1DFunction* fn = dynamic_cast<const Discrete1DFunction*>(function)) {
            vector<double> values;
            fn->getFunctionParameters(values);
            p.push_back(values.size());
        }
        else if (const Discrete2DFunction* fn = dynamic_cast<const Discrete2DFunction*>(function)) {
            int xsize, ysize;
            vector<double> values;
            fn->getFunctionParameters(xsize, ysize, values);
            p.push_back(xsize);
            p.push_back(ysize);
        }
        else if (const Discrete3DFunction* fn = dynamic_cast<const Discrete3DFunction*>(function)) {
            int xsize, ysize, zsize;
            vector<double> values;
            fn->getFunctionParameters(xsize, ysize, zsize, values);
            p.push_back(xsize);
            p.push_back(ysize);
            p.push_back(zsize);
        }
        else
            throw OpenMMException("computeFunctionParameters: Unknown function type");
    }
    return params;
}
/**
 * Return the Lepton placeholder to register for a tabulated function when parsing expressions.
 *
 * The placeholder depends only on the function's dimensionality: fp1 for 1D, fp2 for 2D and fp3 for 3D
 * functions, whether continuous or discrete.  The concrete type is examined again in processExpression()
 * when code is generated.
 *
 * @throws OpenMMException if the function type is not recognized
 */
Lepton::CustomFunction* OpenCLExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
    const TabulatedFunction* fn = &function;
    if (dynamic_cast<const Continuous1DFunction*>(fn) != NULL || dynamic_cast<const Discrete1DFunction*>(fn) != NULL)
        return &fp1;
    if (dynamic_cast<const Continuous2DFunction*>(fn) != NULL || dynamic_cast<const Discrete2DFunction*>(fn) != NULL)
        return &fp2;
    if (dynamic_cast<const Continuous3DFunction*>(fn) != NULL || dynamic_cast<const Discrete3DFunction*>(fn) != NULL)
        return &fp3;
    throw OpenMMException("getFunctionPlaceholder: Unknown function type");
}
/**
 * Return the Lepton placeholder for periodicdistance().  Calls to it are recognized by name in
 * processExpression(), which generates the periodic-boundary distance code inline.
 */
Lepton::CustomFunction* OpenCLExpressionUtilities::getPeriodicDistancePlaceholder() {
    Lepton::CustomFunction* placeholder = &periodicDistance;
    return placeholder;
}