Commit eaef52d9 authored by peastman
Browse files

Created OpenCL implementation of periodicdistance()

parent 91a8cc49
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2014 Stanford University and the Authors. * * Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -89,6 +89,10 @@ public: ...@@ -89,6 +89,10 @@ public:
* @param function the function for which to get a placeholder * @param function the function for which to get a placeholder
*/ */
Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function); Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function);
/**
* Get a Lepton::CustomFunction that can be used to represent the periodicdistance() function when parsing expressions.
*/
Lepton::CustomFunction* getPeriodicDistancePlaceholder();
private: private:
class FunctionPlaceholder : public Lepton::CustomFunction { class FunctionPlaceholder : public Lepton::CustomFunction {
public: public:
...@@ -114,13 +118,13 @@ private: ...@@ -114,13 +118,13 @@ private:
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType); const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedCustomFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::vector<const Lepton::ExpressionTreeNode*>& nodes); std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers); std::map<int, const Lepton::ExpressionTreeNode*>& powers);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions); std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
OpenCLContext& context; OpenCLContext& context;
FunctionPlaceholder fp1, fp2, fp3; FunctionPlaceholder fp1, fp2, fp3, periodicDistance;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -181,7 +181,7 @@ void OpenCLBondedUtilities::initialize(const System& system) { ...@@ -181,7 +181,7 @@ void OpenCLBondedUtilities::initialize(const System& system) {
for (int i = 0; i < (int) prefixCode.size(); i++) for (int i = 0; i < (int) prefixCode.size(); i++)
s<<prefixCode[i]; s<<prefixCode[i];
string bufferType = (context.getSupports64BitGlobalAtomics() ? "long" : "real4"); string bufferType = (context.getSupports64BitGlobalAtomics() ? "long" : "real4");
s<<"__kernel void computeBondedForces(__global "<<bufferType<<"* restrict forceBuffers, __global real* restrict energyBuffer, __global const real4* restrict posq, int groups"; s<<"__kernel void computeBondedForces(__global "<<bufferType<<"* restrict forceBuffers, __global real* restrict energyBuffer, __global const real4* restrict posq, int groups, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ";
for (int i = 0; i < setSize; i++) { for (int i = 0; i < setSize; i++) {
int force = set[i]; int force = set[i];
string indexType = "uint"+(indexWidth[force] == 1 ? "" : context.intToString(indexWidth[force])); string indexType = "uint"+(indexWidth[force] == 1 ? "" : context.intToString(indexWidth[force]));
...@@ -267,7 +267,7 @@ void OpenCLBondedUtilities::computeInteractions(int groups) { ...@@ -267,7 +267,7 @@ void OpenCLBondedUtilities::computeInteractions(int groups) {
kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
index++; index += 6;
for (int j = 0; j < (int) forceSets[i].size(); j++) { for (int j = 0; j < (int) forceSets[i].size(); j++) {
kernel.setArg<cl::Buffer>(index++, atomIndices[forceSets[i][j]]->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, atomIndices[forceSets[i][j]]->getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, bufferIndices[forceSets[i][j]]->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, bufferIndices[forceSets[i][j]]->getDeviceBuffer());
...@@ -277,7 +277,22 @@ void OpenCLBondedUtilities::computeInteractions(int groups) { ...@@ -277,7 +277,22 @@ void OpenCLBondedUtilities::computeInteractions(int groups) {
} }
} }
for (int i = 0; i < (int) kernels.size(); i++) { for (int i = 0; i < (int) kernels.size(); i++) {
kernels[i].setArg<cl_int>(3, groups); cl::Kernel& kernel = kernels[i];
kernel.setArg<cl_int>(3, groups);
if (context.getUseDoublePrecision()) {
kernel.setArg<mm_double4>(4, context.getPeriodicBoxSizeDouble());
kernel.setArg<mm_double4>(5, context.getInvPeriodicBoxSizeDouble());
kernel.setArg<mm_double4>(6, context.getPeriodicBoxVecXDouble());
kernel.setArg<mm_double4>(7, context.getPeriodicBoxVecYDouble());
kernel.setArg<mm_double4>(8, context.getPeriodicBoxVecZDouble());
}
else {
kernel.setArg<mm_float4>(4, context.getPeriodicBoxSize());
kernel.setArg<mm_float4>(5, context.getInvPeriodicBoxSize());
kernel.setArg<mm_float4>(6, context.getPeriodicBoxVecX());
kernel.setArg<mm_float4>(7, context.getPeriodicBoxVecY());
kernel.setArg<mm_float4>(8, context.getPeriodicBoxVecZ());
}
context.executeKernel(kernels[i], maxBonds); context.executeKernel(kernels[i], maxBonds);
} }
} }
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2014 Stanford University and the Authors. * * Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -33,7 +33,7 @@ using namespace OpenMM; ...@@ -33,7 +33,7 @@ using namespace OpenMM;
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
OpenCLExpressionUtilities::OpenCLExpressionUtilities(OpenCLContext& context) : context(context), fp1(1), fp2(2), fp3(3) { OpenCLExpressionUtilities::OpenCLExpressionUtilities(OpenCLContext& context) : context(context), fp1(1), fp2(2), fp3(3), periodicDistance(6) {
} }
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables, string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
...@@ -79,11 +79,6 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -79,11 +79,6 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
case Operation::CUSTOM: case Operation::CUSTOM:
{ {
int i;
for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
;
if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
out << "0.0f;\n"; out << "0.0f;\n";
temps.push_back(make_pair(node, name)); temps.push_back(make_pair(node, name));
hasRecordedNode = true; hasRecordedNode = true;
...@@ -93,7 +88,7 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -93,7 +88,7 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
vector<const ExpressionTreeNode*> nodes; vector<const ExpressionTreeNode*> nodes;
for (int j = 0; j < (int) allExpressions.size(); j++) for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), nodes); findRelatedCustomFunctions(node, allExpressions[j].getRootNode(), nodes);
vector<string> nodeNames; vector<string> nodeNames;
nodeNames.push_back(name); nodeNames.push_back(name);
for (int j = 1; j < (int) nodes.size(); j++) { for (int j = 1; j < (int) nodes.size(); j++) {
...@@ -103,175 +98,222 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -103,175 +98,222 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
temps.push_back(make_pair(*nodes[j], name2)); temps.push_back(make_pair(*nodes[j], name2));
} }
out << "{\n"; out << "{\n";
vector<string> paramsFloat, paramsInt; if (node.getOperation().getName() == "periodicdistance") {
for (int j = 0; j < (int) functionParams[i].size(); j++) { // This is the periodicdistance() function.
paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
paramsInt.push_back(context.intToString((int) functionParams[i][j])); out << tempType << "3 periodicDistance_delta = (real3) (";
} for (int i = 0; i < 3; i++) {
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) { if (i > 0)
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << ", ";
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n"; out << getTempName(node.getChildren()[i], temps) << "-" << getTempName(node.getChildren()[i+3], temps);
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0)
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else
out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
} }
out << "}\n"; out << ");\n";
} out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) { out << tempType << " periodicDistance_rinv = RSQRT(periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z);\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
out << "float4 c[4];\n";
for (int j = 0; j < 4; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) { int argIndex = -1;
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n"; for (int k = 0; k < 6; k++) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n"; if (derivOrder[k] > 0) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n"; if (derivOrder[k] > 1 || argIndex != -1)
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n"; throw OpenMMException("Unsupported derivative of periodicdistance"); // Should be impossible for this to happen.
} argIndex = k;
else if (derivOrder[0] == 1 && derivOrder[1] == 0) { }
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
out << nodeNames[j] << " *= " << paramsFloat[6] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
out << nodeNames[j] << " *= " << paramsFloat[7] << ";\n";
} }
else if (argIndex == -1)
throw OpenMMException("Unsupported derivative order for Continuous2DFunction"); out << nodeNames[j] << " = RECIP(periodicDistance_rinv);\n";
else if (argIndex == 0)
out << nodeNames[j] << " = periodicDistance_delta.x*periodicDistance_rinv;\n";
else if (argIndex == 1)
out << nodeNames[j] << " = periodicDistance_delta.y*periodicDistance_rinv;\n";
else if (argIndex == 2)
out << nodeNames[j] << " = periodicDistance_delta.z*periodicDistance_rinv;\n";
else if (argIndex == 3)
out << nodeNames[j] << " = -periodicDistance_delta.x*periodicDistance_rinv;\n";
else if (argIndex == 4)
out << nodeNames[j] << " = -periodicDistance_delta.y*periodicDistance_rinv;\n";
else if (argIndex == 5)
out << nodeNames[j] << " = -periodicDistance_delta.z*periodicDistance_rinv;\n";
} }
out << "}\n";
} }
else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) { else {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; // This is a tabulated function.
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "real z = " << getTempName(node.getChildren()[2], temps) << ";\n"; int i;
out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n"; for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n"; ;
out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n"; if (i == functionNames.size())
out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n"; throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n"; vector<string> paramsFloat, paramsInt;
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n"; for (int j = 0; j < (int) functionParams[i].size(); j++) {
out << "int u = min((int) floor(z), " << paramsInt[2] << ");\n"; paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n"; paramsInt.push_back(context.intToString((int) functionParams[i][j]));
out << "float4 c[16];\n"; }
for (int j = 0; j < 16; j++) if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real da = x-s;\n"; out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "real db = y-t;\n"; out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "real dc = z-u;\n"; out << "int index = (int) (floor(x));\n";
for (int j = 0; j < nodes.size(); j++) { out << "index = min(index, " << paramsInt[3] << ");\n";
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); out << "float4 coeff = " << functionNames[i].second << "[index];\n";
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) { out << "real b = x-index;\n";
out << "real value[4] = {0, 0, 0, 0};\n"; out << "real a = 1.0f-b;\n";
for (int k = 3; k >= 0; k--) for (int j = 0; j < nodes.size(); j++) {
for (int m = 0; m < 4; m++) { const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
int base = k + 4*m; if (derivOrder[0] == 0)
out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n"; out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
} else
out << nodeNames[j] << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n"; out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real derivx[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
}
out << nodeNames[j] << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[9] << ";\n";
} }
else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) { out << "}\n";
const string suffixes[] = {".x", ".y", ".z", ".w"}; }
out << "real derivy[4] = {0, 0, 0, 0};\n"; else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
for (int k = 3; k >= 0; k--) out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
for (int m = 0; m < 4; m++) { out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
int base = 4*m; out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
string suffix = suffixes[m]; out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n"; out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
} out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << nodeNames[j] << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n"; out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << nodeNames[j] << " *= " << paramsFloat[10] << ";\n"; out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
out << "float4 c[4];\n";
for (int j = 0; j < 4; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0) {
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
out << nodeNames[j] << " *= " << paramsFloat[6] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
out << nodeNames[j] << " *= " << paramsFloat[7] << ";\n";
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
} }
else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) { out << "}\n";
out << "real derivz[4] = {0, 0, 0, 0};\n"; }
for (int k = 3; k >= 0; k--) else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
for (int m = 0; m < 4; m++) { out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
int base = k + 4*m; out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n"; out << "real z = " << getTempName(node.getChildren()[2], temps) << ";\n";
} out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
out << nodeNames[j] << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n"; out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
out << nodeNames[j] << " *= " << paramsFloat[11] << ";\n"; out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << ");\n";
out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
out << "float4 c[16];\n";
for (int j = 0; j < 16; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
out << "real dc = z-u;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real value[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real derivx[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
}
out << nodeNames[j] << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[9] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) {
const string suffixes[] = {".x", ".y", ".z", ".w"};
out << "real derivy[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = 4*m;
string suffix = suffixes[m];
out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n";
}
out << nodeNames[j] << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[10] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) {
out << "real derivz[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n";
out << nodeNames[j] << " *= " << paramsFloat[11] << ";\n";
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
} }
else out << "}\n";
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
} }
out << "}\n"; else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
} for (int j = 0; j < nodes.size(); j++) {
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) { const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
for (int j = 0; j < nodes.size(); j++) { if (derivOrder[0] == 0) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
if (derivOrder[0] == 0) { out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "int index = (int) floor(x+0.5f);\n";
out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "int index = (int) floor(x+0.5f);\n"; out << "}\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; }
out << "}\n";
} }
} }
} else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) { for (int j = 0; j < nodes.size(); j++) {
for (int j = 0; j < nodes.size(); j++) { const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); if (derivOrder[0] == 0 && derivOrder[1] == 0) {
if (derivOrder[0] == 0 && derivOrder[1] == 0) { out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n";
out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n"; out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n";
out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n"; out << "int xsize = " << paramsInt[0] << ";\n";
out << "int xsize = " << paramsInt[0] << ";\n"; out << "int ysize = " << paramsInt[1] << ";\n";
out << "int ysize = " << paramsInt[1] << ";\n"; out << "int index = x+y*xsize;\n";
out << "int index = x+y*xsize;\n"; out << "if (index >= 0 && index < xsize*ysize)\n";
out << "if (index >= 0 && index < xsize*ysize)\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; }
} }
} }
} else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) { for (int j = 0; j < nodes.size(); j++) {
for (int j = 0; j < nodes.size(); j++) { const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) { out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n";
out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n"; out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n";
out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n"; out << "int z = (int) floor(" << getTempName(node.getChildren()[2], temps) << "+0.5f);\n";
out << "int z = (int) floor(" << getTempName(node.getChildren()[2], temps) << "+0.5f);\n"; out << "int xsize = " << paramsInt[0] << ";\n";
out << "int xsize = " << paramsInt[0] << ";\n"; out << "int ysize = " << paramsInt[1] << ";\n";
out << "int ysize = " << paramsInt[1] << ";\n"; out << "int zsize = " << paramsInt[2] << ";\n";
out << "int zsize = " << paramsInt[2] << ";\n"; out << "int index = x+(y+z*ysize)*xsize;\n";
out << "int index = x+(y+z*ysize)*xsize;\n"; out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; }
} }
} }
} }
...@@ -475,7 +517,7 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co ...@@ -475,7 +517,7 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
throw OpenMMException(out.str()); throw OpenMMException(out.str());
} }
void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, void OpenCLExpressionUtilities::findRelatedCustomFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
vector<const Lepton::ExpressionTreeNode*>& nodes) { vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) { if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
// Make sure the arguments are identical. // Make sure the arguments are identical.
...@@ -496,7 +538,7 @@ void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTr ...@@ -496,7 +538,7 @@ void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTr
} }
else else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++) for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes); findRelatedCustomFunctions(node, searchNode.getChildren()[i], nodes);
} }
void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) { void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
...@@ -722,3 +764,7 @@ Lepton::CustomFunction* OpenCLExpressionUtilities::getFunctionPlaceholder(const ...@@ -722,3 +764,7 @@ Lepton::CustomFunction* OpenCLExpressionUtilities::getFunctionPlaceholder(const
return &fp3; return &fp3;
throw OpenMMException("getFunctionPlaceholder: Unknown function type"); throw OpenMMException("getFunctionPlaceholder: Unknown function type");
} }
/**
 * Get a Lepton::CustomFunction that can stand in for the periodicdistance()
 * function while an expression is being parsed.
 *
 * @return the address of the existing periodicDistance object; no new object
 *         is allocated here, so callers presumably must not delete the result
 *         (NOTE(review): confirm against the declaration of periodicDistance)
 */
Lepton::CustomFunction* OpenCLExpressionUtilities::getPeriodicDistancePlaceholder() {
    Lepton::CustomFunction* placeholder = &periodicDistance;
    return placeholder;
}
...@@ -3821,7 +3821,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const ...@@ -3821,7 +3821,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const
globalParamNames[i] = force.getGlobalParameterName(i); globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i); globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
} }
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize(); map<string, Lepton::CustomFunction*> customFunctions;
customFunctions["periodicdistance"] = cl.getExpressionUtilities().getPeriodicDistancePlaceholder();
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), customFunctions).optimize();
Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize(); Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize();
Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize(); Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize();
Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize(); Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize();
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2009 Stanford University and the Authors. * * Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -161,6 +161,47 @@ void testParallelComputation() { ...@@ -161,6 +161,47 @@ void testParallelComputation() {
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5); ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5);
} }
void testPeriodic() {
Vec3 vx(5, 0, 0);
Vec3 vy(0, 6, 0);
Vec3 vz(1, 2, 7);
double x0 = 51, y0 = -17, z0 = 11.2;
System system;
system.setDefaultPeriodicBoxVectors(vx, vy, vz);
system.addParticle(1.0);
CustomExternalForce* force = new CustomExternalForce("periodicdistance(x, y, z, x0, y0, z0)^2");
force->addPerParticleParameter("x0");
force->addPerParticleParameter("y0");
force->addPerParticleParameter("z0");
vector<double> params(3);
params[0] = x0;
params[1] = y0;
params[2] = z0;
force->addParticle(0, params);
system.addForce(force);
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(0, 2, 0);
context.setPositions(positions);
for (int i = 0; i < 100; i++) {
State state = context.getState(State::Positions | State::Forces | State::Energy);
// Apply periodic boundary conditions to the difference between the two positions.
Vec3 delta = Vec3(x0, y0, z0)-state.getPositions()[0];
delta -= vz*floor(delta[2]/vz[2]+0.5);
delta -= vy*floor(delta[1]/vy[1]+0.5);
delta -= vx*floor(delta[0]/vx[0]+0.5);
// Verify that the force and energy are correct.
ASSERT_EQUAL_VEC(delta*2, state.getForces()[0], 1e-5);
ASSERT_EQUAL_TOL(delta.dot(delta), state.getPotentialEnergy(), 1e-5);
integrator.step(1);
}
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
...@@ -168,6 +209,7 @@ int main(int argc, char* argv[]) { ...@@ -168,6 +209,7 @@ int main(int argc, char* argv[]) {
testForce(); testForce();
testManyParameters(); testManyParameters();
testParallelComputation(); testParallelComputation();
testPeriodic();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment