/* -------------------------------------------------------------------------- * * OpenMM * * -------------------------------------------------------------------------- * * This is part of the OpenMM molecular simulation toolkit originating from * * Simbios, the NIH National Center for Physics-Based Simulation of * * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * * Portions copyright (c) 2009-2016 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * * This program is free software: you can redistribute it and/or modify * * it under the terms of the GNU Lesser General Public License as published * * by the Free Software Foundation, either version 3 of the License, or * * (at your option) any later version. * * * * This program is distributed in the hope that it will be useful, * * but WITHOUT ANY WARRANTY; without even the implied warranty of * * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * * GNU Lesser General Public License for more details. * * * * You should have received a copy of the GNU Lesser General Public License * * along with this program. If not, see <http://www.gnu.org/licenses/>.
* * -------------------------------------------------------------------------- */ #include "CudaExpressionUtilities.h" #include "openmm/OpenMMException.h" #include "openmm/internal/SplineFitter.h" #include "lepton/Operation.h" using namespace OpenMM; using namespace Lepton; using namespace std; CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context(context), fp1(1), fp2(2), fp3(3), periodicDistance(6) { } string CudaExpressionUtilities::createExpressions(const map& expressions, const map& variables, const vector& functions, const vector >& functionNames, const string& prefix, const string& tempType) { vector > variableNodes; for (map::const_iterator iter = variables.begin(); iter != variables.end(); ++iter) variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second)); return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType); } string CudaExpressionUtilities::createExpressions(const map& expressions, const vector >& variables, const vector& functions, const vector >& functionNames, const string& prefix, const string& tempType) { stringstream out; vector allExpressions; for (map::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) allExpressions.push_back(iter->second); vector > temps = variables; vector > functionParams = computeFunctionParameters(functions); for (map::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) { processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType); out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n"; } return out.str(); } void CudaExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector >& temps, const vector& functions, const vector >& functionNames, const string& prefix, const vector >& functionParams, const vector& allExpressions, 
const string& tempType) { for (int i = 0; i < (int) temps.size(); i++) if (temps[i].first == node) return; for (int i = 0; i < (int) node.getChildren().size(); i++) processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType); string name = prefix+context.intToString(temps.size()); bool hasRecordedNode = false; out << tempType << " " << name << " = "; switch (node.getOperation().getId()) { case Operation::CONSTANT: out << context.doubleToString(dynamic_cast(&node.getOperation())->getValue()); break; case Operation::VARIABLE: throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName()); case Operation::CUSTOM: { out << "0.0f;\n"; temps.push_back(make_pair(node, name)); hasRecordedNode = true; // If both the value and derivative of the function are needed, it's faster to calculate them both // at once, so check to see if both are needed. vector nodes; for (int j = 0; j < (int) allExpressions.size(); j++) findRelatedCustomFunctions(node, allExpressions[j].getRootNode(), nodes); vector nodeNames; nodeNames.push_back(name); for (int j = 1; j < (int) nodes.size(); j++) { string name2 = prefix+context.intToString(temps.size()); out << tempType << " " << name2 << " = 0.0f;\n"; nodeNames.push_back(name2); temps.push_back(make_pair(*nodes[j], name2)); } out << "{\n"; if (node.getOperation().getName() == "periodicdistance") { // This is the periodicdistance() function. 
out << tempType << "3 periodicDistance_delta = make_real3("; for (int i = 0; i < 3; i++) { if (i > 0) out << ", "; out << getTempName(node.getChildren()[i], temps) << "-" << getTempName(node.getChildren()[i+3], temps); } out << ");\n"; out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n"; out << tempType << " periodicDistance_r2 = periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z;\n"; out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_r2);\n"; for (int j = 0; j < nodes.size(); j++) { const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder(); int argIndex = -1; for (int k = 0; k < 6; k++) { if (derivOrder[k] > 0) { if (derivOrder[k] > 1 || argIndex != -1) throw OpenMMException("Unsupported derivative of periodicdistance"); // Should be impossible for this to happen. argIndex = k; } } if (argIndex == -1) out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n"; else if (argIndex == 0) out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.x*periodicDistance_rinv : 0);\n"; else if (argIndex == 1) out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.y*periodicDistance_rinv : 0);\n"; else if (argIndex == 2) out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? periodicDistance_delta.z*periodicDistance_rinv : 0);\n"; else if (argIndex == 3) out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.x*periodicDistance_rinv : 0);\n"; else if (argIndex == 4) out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.y*periodicDistance_rinv : 0);\n"; else if (argIndex == 5) out << nodeNames[j] << " = (periodicDistance_r2 > 0 ? -periodicDistance_delta.z*periodicDistance_rinv : 0);\n"; } } else { // This is a tabulated function. 
int i; for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++) ; if (i == functionNames.size()) throw OpenMMException("Unknown function in expression: "+node.getOperation().getName()); vector paramsFloat, paramsInt; for (int j = 0; j < (int) functionParams[i].size(); j++) { paramsFloat.push_back(context.doubleToString(functionParams[i][j])); paramsInt.push_back(context.intToString((int) functionParams[i][j])); } if (dynamic_cast(functions[i]) != NULL) { out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n"; out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n"; out << "int index = (int) (floor(x));\n"; out << "index = min(index, (int) " << paramsInt[3] << ");\n"; out << "float4 coeff = " << functionNames[i].second << "[index];\n"; out << "real b = x-index;\n"; out << "real a = 1.0f-b;\n"; for (int j = 0; j < nodes.size(); j++) { const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder(); if (derivOrder[0] == 0) out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n"; else out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n"; } out << "}\n"; } else if (dynamic_cast(functions[i]) != NULL) { out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n"; out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n"; out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n"; out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n"; out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n"; out << 
"int t = min((int) floor(y), " << paramsInt[1] << "-1);\n"; out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n"; out << "float4 c[4];\n"; for (int j = 0; j < 4; j++) out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n"; out << "real da = x-s;\n"; out << "real db = y-t;\n"; for (int j = 0; j < nodes.size(); j++) { const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder(); if (derivOrder[0] == 0 && derivOrder[1] == 0) { out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n"; out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n"; out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n"; out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n"; } else if (derivOrder[0] == 1 && derivOrder[1] == 0) { out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n"; out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n"; out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n"; out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n"; out << nodeNames[j] << " *= " << paramsFloat[6] << ";\n"; } else if (derivOrder[0] == 0 && derivOrder[1] == 1) { out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n"; out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n"; out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n"; out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n"; out << nodeNames[j] << " *= " << paramsFloat[7] << ";\n"; } else throw 
OpenMMException("Unsupported derivative order for Continuous2DFunction"); } out << "}\n"; } else if (dynamic_cast(functions[i]) != NULL) { out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n"; out << "real z = " << getTempName(node.getChildren()[2], temps) << ";\n"; out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n"; out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n"; out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n"; out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n"; out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n"; out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n"; out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n"; out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n"; out << "float4 c[16];\n"; for (int j = 0; j < 16; j++) out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n"; out << "real da = x-s;\n"; out << "real db = y-t;\n"; out << "real dc = z-u;\n"; for (int j = 0; j < nodes.size(); j++) { const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder(); if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) { out << "real value[4] = {0, 0, 0, 0};\n"; for (int k = 3; k >= 0; k--) for (int m = 0; m < 4; m++) { int base = k + 4*m; out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n"; } out << nodeNames[j] << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n"; } else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) { out << "real derivx[4] = {0, 0, 0, 0};\n"; for 
(int k = 3; k >= 0; k--) for (int m = 0; m < 4; m++) { int base = k + 4*m; out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n"; } out << nodeNames[j] << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n"; out << nodeNames[j] << " *= " << paramsFloat[9] << ";\n"; } else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) { const string suffixes[] = {".x", ".y", ".z", ".w"}; out << "real derivy[4] = {0, 0, 0, 0};\n"; for (int k = 3; k >= 0; k--) for (int m = 0; m < 4; m++) { int base = 4*m; string suffix = suffixes[k]; out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n"; } out << nodeNames[j] << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n"; out << nodeNames[j] << " *= " << paramsFloat[10] << ";\n"; } else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) { out << "real derivz[4] = {0, 0, 0, 0};\n"; for (int k = 3; k >= 0; k--) for (int m = 0; m < 4; m++) { int base = k + 4*m; out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n"; } out << nodeNames[j] << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n"; out << nodeNames[j] << " *= " << paramsFloat[11] << ";\n"; } else throw OpenMMException("Unsupported derivative order for Continuous3DFunction"); } out << "}\n"; } else if (dynamic_cast(functions[i]) != NULL) { for (int j = 0; j < nodes.size(); j++) { const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder(); if (derivOrder[0] == 0) { out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n"; out << "int index = (int) floor(x+0.5f);\n"; out << nodeNames[j] << " = " << 
functionNames[i].second << "[index];\n"; out << "}\n"; } } } else if (dynamic_cast(functions[i]) != NULL) { for (int j = 0; j < nodes.size(); j++) { const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder(); if (derivOrder[0] == 0 && derivOrder[1] == 0) { out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n"; out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n"; out << "int xsize = (int) " << paramsInt[0] << ";\n"; out << "int ysize = (int) " << paramsInt[1] << ";\n"; out << "int index = x+y*xsize;\n"; out << "if (index >= 0 && index < xsize*ysize)\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; } } } else if (dynamic_cast(functions[i]) != NULL) { for (int j = 0; j < nodes.size(); j++) { const vector& derivOrder = dynamic_cast(&nodes[j]->getOperation())->getDerivOrder(); if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) { out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n"; out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n"; out << "int z = (int) floor(" << getTempName(node.getChildren()[2], temps) << "+0.5f);\n"; out << "int xsize = (int) " << paramsInt[0] << ";\n"; out << "int ysize = (int) " << paramsInt[1] << ";\n"; out << "int zsize = (int) " << paramsInt[2] << ";\n"; out << "int index = x+(y+z*ysize)*xsize;\n"; out << "if (index >= 0 && index < xsize*ysize*zsize)\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; } } } } out << "}"; break; } case Operation::ADD: out << getTempName(node.getChildren()[0], temps) << "+" << getTempName(node.getChildren()[1], temps); break; case Operation::SUBTRACT: out << getTempName(node.getChildren()[0], temps) << "-" << getTempName(node.getChildren()[1], temps); break; case Operation::MULTIPLY: out << getTempName(node.getChildren()[0], temps) << "*" << 
getTempName(node.getChildren()[1], temps); break; case Operation::DIVIDE: { bool haveReciprocal = false; if (node.getChildren()[1].getOperation().getId() == Operation::RECIPROCAL) { for (int i = 0; i < (int) temps.size(); i++) if (temps[i].first == node.getChildren()[1].getChildren()[1]) { haveReciprocal = true; out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second; } } if (!haveReciprocal) for (int i = 0; i < (int) temps.size(); i++) if (temps[i].first.getOperation().getId() == Operation::RECIPROCAL && temps[i].first.getChildren()[0] == node.getChildren()[1]) { haveReciprocal = true; out << getTempName(node.getChildren()[0], temps) << "*" << temps[i].second; } if (!haveReciprocal) out << getTempName(node.getChildren()[0], temps) << "/" << getTempName(node.getChildren()[1], temps); break; } case Operation::POWER: out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")"; break; case Operation::NEGATE: out << "-" << getTempName(node.getChildren()[0], temps); break; case Operation::SQRT: out << "SQRT(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::EXP: out << "EXP(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::LOG: out << "LOG(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::SIN: out << "SIN(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::COS: out << "COS(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::SEC: out << "RECIP(COS(" << getTempName(node.getChildren()[0], temps) << "))"; break; case Operation::CSC: out << "RECIP(SIN(" << getTempName(node.getChildren()[0], temps) << "))"; break; case Operation::TAN: out << "TAN(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::COT: out << "RECIP(TAN(" << getTempName(node.getChildren()[0], temps) << "))"; break; case 
Operation::ASIN: out << "ASIN(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::ACOS: out << "ACOS(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::ATAN: out << "ATAN(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::SINH: out << "sinh(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::COSH: out << "cosh(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::TANH: out << "tanh(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::ERF: out << "erf(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::ERFC: out << "erfc(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::STEP: out << getTempName(node.getChildren()[0], temps) << " >= 0.0f ? 1.0f : 0.0f"; break; case Operation::DELTA: out << getTempName(node.getChildren()[0], temps) << " == 0.0f ? 1.0f : 0.0f"; break; case Operation::SQUARE: { string arg = getTempName(node.getChildren()[0], temps); out << arg << "*" << arg; break; } case Operation::CUBE: { string arg = getTempName(node.getChildren()[0], temps); out << arg << "*" << arg << "*" << arg; break; } case Operation::RECIPROCAL: out << "RECIP(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::ADD_CONSTANT: out << context.doubleToString(dynamic_cast(&node.getOperation())->getValue()) << "+" << getTempName(node.getChildren()[0], temps); break; case Operation::MULTIPLY_CONSTANT: out << context.doubleToString(dynamic_cast(&node.getOperation())->getValue()) << "*" << getTempName(node.getChildren()[0], temps); break; case Operation::POWER_CONSTANT: { double exponent = dynamic_cast(&node.getOperation())->getValue(); if (exponent == 0.0) out << "1.0f"; else if (exponent == (int) exponent) { out << "0.0f;\n"; temps.push_back(make_pair(node, name)); hasRecordedNode = true; // If multiple integral powers of the same base are needed, it's 
faster to calculate all of them // at once, so check to see if others are also needed. map powers; powers[(int) exponent] = &node; for (int j = 0; j < (int) allExpressions.size(); j++) findRelatedPowers(node, allExpressions[j].getRootNode(), powers); vector exponents; vector names; vector hasAssigned(powers.size(), false); exponents.push_back((int) fabs(exponent)); names.push_back(name); for (map::const_iterator iter = powers.begin(); iter != powers.end(); ++iter) { if (iter->first != exponent) { exponents.push_back(iter->first >= 0 ? iter->first : -iter->first); string name2 = prefix+context.intToString(temps.size()); names.push_back(name2); temps.push_back(make_pair(*iter->second, name2)); out << tempType << " " << name2 << " = 0.0f;\n"; } } out << "{\n"; out << "real multiplier = " << (exponent < 0.0 ? "RECIP(" : "(") << getTempName(node.getChildren()[0], temps) << ");\n"; bool done = false; while (!done) { done = true; for (int i = 0; i < (int) exponents.size(); i++) { if (exponents[i]%2 == 1) { if (!hasAssigned[i]) out << names[i] << " = multiplier;\n"; else out << names[i] << " *= multiplier;\n"; hasAssigned[i] = true; } exponents[i] >>= 1; if (exponents[i] != 0) done = false; } if (!done) out << "multiplier *= multiplier;\n"; } out << "}"; } else out << "pow((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << context.doubleToString(exponent) << ")"; break; } case Operation::MIN: out << "min((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")"; break; case Operation::MAX: out << "max((" << tempType << ") " << getTempName(node.getChildren()[0], temps) << ", (" << tempType << ") " << getTempName(node.getChildren()[1], temps) << ")"; break; case Operation::ABS: out << "fabs(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::FLOOR: out << "floor(" << getTempName(node.getChildren()[0], temps) << 
")"; break; case Operation::CEIL: out << "ceil(" << getTempName(node.getChildren()[0], temps) << ")"; break; case Operation::SELECT: out << "(" << getTempName(node.getChildren()[0], temps) << " != 0 ? " << getTempName(node.getChildren()[1], temps) << " : " << getTempName(node.getChildren()[2], temps) << ")"; break; default: throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName()); } out << ";\n"; if (!hasRecordedNode) temps.push_back(make_pair(node, name)); } string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, const vector >& temps) { for (int i = 0; i < (int) temps.size(); i++) if (temps[i].first == node) return temps[i].second; stringstream out; out << "Internal error: No temporary variable for expression node: " << node; throw OpenMMException(out.str()); } void CudaExpressionUtilities::findRelatedCustomFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, vector& nodes) { if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) { // Make sure the arguments are identical. for (int i = 0; i < (int) node.getChildren().size(); i++) if (node.getChildren()[i] != searchNode.getChildren()[i]) return; // See if we already have an identical node. for (int i = 0; i < (int) nodes.size(); i++) if (*nodes[i] == searchNode) return; // Add the node. 
nodes.push_back(&searchNode); } else for (int i = 0; i < (int) searchNode.getChildren().size(); i++) findRelatedCustomFunctions(node, searchNode.getChildren()[i], nodes); } void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map& powers) { if (searchNode.getOperation().getId() == Operation::POWER_CONSTANT && node.getChildren()[0] == searchNode.getChildren()[0]) { double realPower = dynamic_cast(&searchNode.getOperation())->getValue(); int power = (int) realPower; if (power != realPower) return; // We are only interested in integer powers. if (powers.find(power) != powers.end()) return; // This power is already in the map. if (powers.begin()->first*power < 0) return; // All powers must have the same sign. powers[power] = &searchNode; } else for (int i = 0; i < (int) searchNode.getChildren().size(); i++) findRelatedPowers(node, searchNode.getChildren()[i], powers); } vector CudaExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) { if (dynamic_cast(&function) != NULL) { // Compute the spline coefficients. const Continuous1DFunction& fn = dynamic_cast(function); vector values; double min, max; fn.getFunctionParameters(values, min, max); int numValues = values.size(); vector x(numValues), derivs; for (int i = 0; i < numValues; i++) x[i] = min+i*(max-min)/(numValues-1); SplineFitter::createNaturalSpline(x, values, derivs); vector f(4*(numValues-1)); for (int i = 0; i < (int) values.size()-1; i++) { f[4*i] = (float) values[i]; f[4*i+1] = (float) values[i+1]; f[4*i+2] = (float) (derivs[i]/6.0); f[4*i+3] = (float) (derivs[i+1]/6.0); } width = 4; return f; } if (dynamic_cast(&function) != NULL) { // Compute the spline coefficients. 
const Continuous2DFunction& fn = dynamic_cast(function); vector values; int xsize, ysize; double xmin, xmax, ymin, ymax; fn.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax); vector x(xsize), y(ysize); for (int i = 0; i < xsize; i++) x[i] = xmin+i*(xmax-xmin)/(xsize-1); for (int i = 0; i < ysize; i++) y[i] = ymin+i*(ymax-ymin)/(ysize-1); vector > c; SplineFitter::create2DNaturalSpline(x, y, values, c); vector f(16*c.size()); for (int i = 0; i < (int) c.size(); i++) { for (int j = 0; j < 16; j++) f[16*i+j] = (float) c[i][j]; } width = 4; return f; } if (dynamic_cast(&function) != NULL) { // Compute the spline coefficients. const Continuous3DFunction& fn = dynamic_cast(function); vector values; int xsize, ysize, zsize; double xmin, xmax, ymin, ymax, zmin, zmax; fn.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax); vector x(xsize), y(ysize), z(zsize); for (int i = 0; i < xsize; i++) x[i] = xmin+i*(xmax-xmin)/(xsize-1); for (int i = 0; i < ysize; i++) y[i] = ymin+i*(ymax-ymin)/(ysize-1); for (int i = 0; i < zsize; i++) z[i] = zmin+i*(zmax-zmin)/(zsize-1); vector > c; SplineFitter::create3DNaturalSpline(x, y, z, values, c); vector f(64*c.size()); for (int i = 0; i < (int) c.size(); i++) { for (int j = 0; j < 64; j++) f[64*i+j] = (float) c[i][j]; } width = 4; return f; } if (dynamic_cast(&function) != NULL) { // Record the tabulated values. const Discrete1DFunction& fn = dynamic_cast(function); vector values; fn.getFunctionParameters(values); int numValues = values.size(); vector f(numValues); for (int i = 0; i < numValues; i++) f[i] = (float) values[i]; width = 1; return f; } if (dynamic_cast(&function) != NULL) { // Record the tabulated values. 
const Discrete2DFunction& fn = dynamic_cast(function); int xsize, ysize; vector values; fn.getFunctionParameters(xsize, ysize, values); int numValues = values.size(); vector f(numValues); for (int i = 0; i < numValues; i++) f[i] = (float) values[i]; width = 1; return f; } if (dynamic_cast(&function) != NULL) { // Record the tabulated values. const Discrete3DFunction& fn = dynamic_cast(function); int xsize, ysize, zsize; vector values; fn.getFunctionParameters(xsize, ysize, zsize, values); int numValues = values.size(); vector f(numValues); for (int i = 0; i < numValues; i++) f[i] = (float) values[i]; width = 1; return f; } throw OpenMMException("computeFunctionCoefficients: Unknown function type"); } vector > CudaExpressionUtilities::computeFunctionParameters(const vector& functions) { vector > params(functions.size()); for (int i = 0; i < (int) functions.size(); i++) { if (dynamic_cast(functions[i]) != NULL) { const Continuous1DFunction& fn = dynamic_cast(*functions[i]); vector values; double min, max; fn.getFunctionParameters(values, min, max); params[i].push_back(min); params[i].push_back(max); params[i].push_back((values.size()-1)/(max-min)); params[i].push_back(values.size()-2); } else if (dynamic_cast(functions[i]) != NULL) { const Continuous2DFunction& fn = dynamic_cast(*functions[i]); vector values; int xsize, ysize; double xmin, xmax, ymin, ymax; fn.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax); params[i].push_back(xsize-1); params[i].push_back(ysize-1); params[i].push_back(xmin); params[i].push_back(xmax); params[i].push_back(ymin); params[i].push_back(ymax); params[i].push_back((xsize-1)/(xmax-xmin)); params[i].push_back((ysize-1)/(ymax-ymin)); } else if (dynamic_cast(functions[i]) != NULL) { const Continuous3DFunction& fn = dynamic_cast(*functions[i]); vector values; int xsize, ysize, zsize; double xmin, xmax, ymin, ymax, zmin, zmax; fn.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax); 
params[i].push_back(xsize-1); params[i].push_back(ysize-1); params[i].push_back(zsize-1); params[i].push_back(xmin); params[i].push_back(xmax); params[i].push_back(ymin); params[i].push_back(ymax); params[i].push_back(zmin); params[i].push_back(zmax); params[i].push_back((xsize-1)/(xmax-xmin)); params[i].push_back((ysize-1)/(ymax-ymin)); params[i].push_back((zsize-1)/(zmax-zmin)); } else if (dynamic_cast(functions[i]) != NULL) { const Discrete1DFunction& fn = dynamic_cast(*functions[i]); vector values; fn.getFunctionParameters(values); params[i].push_back(values.size()); } else if (dynamic_cast(functions[i]) != NULL) { const Discrete2DFunction& fn = dynamic_cast(*functions[i]); int xsize, ysize; vector values; fn.getFunctionParameters(xsize, ysize, values); params[i].push_back(xsize); params[i].push_back(ysize); } else if (dynamic_cast(functions[i]) != NULL) { const Discrete3DFunction& fn = dynamic_cast(*functions[i]); int xsize, ysize, zsize; vector values; fn.getFunctionParameters(xsize, ysize, zsize, values); params[i].push_back(xsize); params[i].push_back(ysize); params[i].push_back(zsize); } else throw OpenMMException("computeFunctionParameters: Unknown function type"); } return params; } Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) { if (dynamic_cast(&function) != NULL) return &fp1; if (dynamic_cast(&function) != NULL) return &fp2; if (dynamic_cast(&function) != NULL) return &fp3; if (dynamic_cast(&function) != NULL) return &fp1; if (dynamic_cast(&function) != NULL) return &fp2; if (dynamic_cast(&function) != NULL) return &fp3; throw OpenMMException("getFunctionPlaceholder: Unknown function type"); } Lepton::CustomFunction* CudaExpressionUtilities::getPeriodicDistancePlaceholder() { return &periodicDistance; }