Commit 0b77ab91 authored by peastman's avatar peastman
Browse files

Began CUDA implementation of periodicdistance() function for CustomExternalForce

parent cceb3171
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2014 Stanford University and the Authors. *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -89,6 +89,10 @@ public:
* @param function the function for which to get a placeholder
*/
Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function);
/**
* Get a Lepton::CustomFunction that can be used to represent the periodicdistance() function when parsing expressions.
*/
Lepton::CustomFunction* getPeriodicDistancePlaceholder();
private:
class FunctionPlaceholder : public Lepton::CustomFunction {
public:
......@@ -114,13 +118,13 @@ private:
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
void findRelatedCustomFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
CudaContext& context;
FunctionPlaceholder fp1, fp2, fp3;
FunctionPlaceholder fp1, fp2, fp3, periodicDistance;
};
} // namespace OpenMM
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2014 Stanford University and the Authors. *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -33,7 +33,7 @@ using namespace OpenMM;
using namespace Lepton;
using namespace std;
CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context(context), fp1(1), fp2(2), fp3(3) {
CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context(context), fp1(1), fp2(2), fp3(3), periodicDistance(6) {
}
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
......@@ -79,11 +79,6 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
case Operation::CUSTOM:
{
int i;
for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
;
if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
out << "0.0f;\n";
temps.push_back(make_pair(node, name));
hasRecordedNode = true;
......@@ -93,7 +88,7 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
vector<const ExpressionTreeNode*> nodes;
for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), nodes);
findRelatedCustomFunctions(node, allExpressions[j].getRootNode(), nodes);
vector<string> nodeNames;
nodeNames.push_back(name);
for (int j = 1; j < (int) nodes.size(); j++) {
......@@ -103,175 +98,222 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
temps.push_back(make_pair(*nodes[j], name2));
}
out << "{\n";
vector<string> paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
paramsInt.push_back(context.intToString((int) functionParams[i][j]));
}
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0)
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else
out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
if (node.getOperation().getName() == "periodicdistance") {
// This is the periodicdistance() function.
out << tempType << " periodicDistance_delta = make_real3(";
for (int i = 0; i < 3; i++) {
if (i > 0)
out << ", ";
out << getTempName(node.getChildren()[i], temps) << "-" << getTempName(node.getChildren()[i+3], temps);
out << ");\n";
}
out << "}\n";
}
else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
out << "float4 c[4];\n";
for (int j = 0; j < 4; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
out << tempType << " periodicDistance_r = SQRT(periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z);\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0) {
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
out << nodeNames[j] << " *= " << paramsFloat[6] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
out << nodeNames[j] << " *= " << paramsFloat[7] << ";\n";
int argIndex = -1;
for (int k = 0; k < 6; k++) {
if (derivOrder[k] > 0) {
if (derivOrder[k] > 1 || argIndex != -1)
throw OpenMMException("Unsupported derivative of periodicdistance"); // Should be impossible for this to happen.
argIndex = k;
}
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
if (argIndex == -1)
out << nodeNames[j] << " = periodicDistance_r;\n";
else if (argIndex == 0)
out << nodeNames[j] << " = periodicDistance_delta.x/periodicDistance_r;\n";
else if (argIndex == 1)
out << nodeNames[j] << " = periodicDistance_delta.y/periodicDistance_r;\n";
else if (argIndex == 2)
out << nodeNames[j] << " = periodicDistance_delta.z/periodicDistance_r;\n";
else if (argIndex == 3)
out << nodeNames[j] << " = -periodicDistance_delta.x/periodicDistance_r;\n";
else if (argIndex == 4)
out << nodeNames[j] << " = -periodicDistance_delta.y/periodicDistance_r;\n";
else if (argIndex == 5)
out << nodeNames[j] << " = -periodicDistance_delta.z/periodicDistance_r;\n";
}
out << "}\n";
}
else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "real z = " << getTempName(node.getChildren()[2], temps) << ";\n";
out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << ");\n";
out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
out << "float4 c[16];\n";
for (int j = 0; j < 16; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
out << "real dc = z-u;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real value[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real derivx[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
}
out << nodeNames[j] << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[9] << ";\n";
else {
// This is a tabulated function.
int i;
for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
;
if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
vector<string> paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
paramsInt.push_back(context.intToString((int) functionParams[i][j]));
}
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0)
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else
out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) {
const string suffixes[] = {".x", ".y", ".z", ".w"};
out << "real derivy[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = 4*m;
string suffix = suffixes[m];
out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n";
}
out << nodeNames[j] << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[10] << ";\n";
out << "}\n";
}
else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
out << "float4 c[4];\n";
for (int j = 0; j < 4; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[3].w*db + c[3].z)*db + c[3].y)*db + c[3].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[2].w*db + c[2].z)*db + c[2].y)*db + c[2].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[1].w*db + c[1].z)*db + c[1].y)*db + c[1].x;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + ((c[0].w*db + c[0].z)*db + c[0].y)*db + c[0].x;\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0) {
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].w*da + 2.0f*c[2].w)*da + c[1].w;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].z*da + 2.0f*c[2].z)*da + c[1].z;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].y*da + 2.0f*c[2].y)*da + c[1].y;\n";
out << nodeNames[j] << " = db*" << nodeNames[j] << " + (3.0f*c[3].x*da + 2.0f*c[2].x)*da + c[1].x;\n";
out << nodeNames[j] << " *= " << paramsFloat[6] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1) {
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[3].w*db + 2.0f*c[3].z)*db + c[3].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[2].w*db + 2.0f*c[2].z)*db + c[2].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[1].w*db + 2.0f*c[1].z)*db + c[1].y;\n";
out << nodeNames[j] << " = da*" << nodeNames[j] << " + (3.0f*c[0].w*db + 2.0f*c[0].z)*db + c[0].y;\n";
out << nodeNames[j] << " *= " << paramsFloat[7] << ";\n";
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
}
else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) {
out << "real derivz[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n";
out << nodeNames[j] << " *= " << paramsFloat[11] << ";\n";
out << "}\n";
}
else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << ";\n";
out << "real z = " << getTempName(node.getChildren()[2], temps) << ";\n";
out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[11] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << ");\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << ");\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << ");\n";
out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
out << "float4 c[16];\n";
for (int j = 0; j < 16; j++)
out << "c[" << j << "] = " << functionNames[i].second << "[coeffIndex+" << j << "];\n";
out << "real da = x-s;\n";
out << "real db = y-t;\n";
out << "real dc = z-u;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real value[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "value[" << m << "] = db*value[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = value[0] + dc*(value[1] + dc*(value[2] + dc*value[3]));\n";
}
else if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "real derivx[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivx[" << m << "] = db*derivx[" << m << "] + (3*c[" << base << "].w*da + 2*c[" << base << "].z)*da + c[" << base << "].y;\n";
}
out << nodeNames[j] << " = derivx[0] + dc*(derivx[1] + dc*(derivx[2] + dc*derivx[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[9] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0) {
const string suffixes[] = {".x", ".y", ".z", ".w"};
out << "real derivy[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = 4*m;
string suffix = suffixes[m];
out << "derivy[" << m << "] = da*derivy[" << m << "] + (3*c[" << (base+3) << "]" << suffix << "*db + 2*c[" << (base+2) << "]" << suffix << ")*db + c[" << (base+1) << "]" << suffix << ";\n";
}
out << nodeNames[j] << " = derivy[0] + dc*(derivy[1] + dc*(derivy[2] + dc*derivy[3]));\n";
out << nodeNames[j] << " *= " << paramsFloat[10] << ";\n";
}
else if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1) {
out << "real derivz[4] = {0, 0, 0, 0};\n";
for (int k = 3; k >= 0; k--)
for (int m = 0; m < 4; m++) {
int base = k + 4*m;
out << "derivz[" << m << "] = db*derivz[" << m << "] + ((c[" << base << "].w*da + c[" << base << "].z)*da + c[" << base << "].y)*da + c[" << base << "].x;\n";
}
out << nodeNames[j] << " = derivz[1] + dc*(2*derivz[2] + dc*3*derivz[3]);\n";
out << nodeNames[j] << " *= " << paramsFloat[11] << ";\n";
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
}
else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
out << "}\n";
}
out << "}\n";
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
out << "int index = (int) floor(x+0.5f);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
out << "int index = (int) floor(x+0.5f);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
}
}
}
}
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n";
out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n";
out << "int xsize = (int) " << paramsInt[0] << ";\n";
out << "int ysize = (int) " << paramsInt[1] << ";\n";
out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n";
out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n";
out << "int xsize = (int) " << paramsInt[0] << ";\n";
out << "int ysize = (int) " << paramsInt[1] << ";\n";
out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
}
}
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n";
out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n";
out << "int z = (int) floor(" << getTempName(node.getChildren()[2], temps) << "+0.5f);\n";
out << "int xsize = (int) " << paramsInt[0] << ";\n";
out << "int ysize = (int) " << paramsInt[1] << ";\n";
out << "int zsize = (int) " << paramsInt[2] << ";\n";
out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "int x = (int) floor(" << getTempName(node.getChildren()[0], temps) << "+0.5f);\n";
out << "int y = (int) floor(" << getTempName(node.getChildren()[1], temps) << "+0.5f);\n";
out << "int z = (int) floor(" << getTempName(node.getChildren()[2], temps) << "+0.5f);\n";
out << "int xsize = (int) " << paramsInt[0] << ";\n";
out << "int ysize = (int) " << paramsInt[1] << ";\n";
out << "int zsize = (int) " << paramsInt[2] << ";\n";
out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
}
}
}
......@@ -483,7 +525,7 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons
throw OpenMMException(out.str());
}
void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
void CudaExpressionUtilities::findRelatedCustomFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
// Make sure the arguments are identical.
......@@ -504,7 +546,7 @@ void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTree
}
else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes);
findRelatedCustomFunctions(node, searchNode.getChildren()[i], nodes);
}
void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
......@@ -730,3 +772,7 @@ Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const Ta
return &fp3;
throw OpenMMException("getFunctionPlaceholder: Unknown function type");
}
Lepton::CustomFunction* CudaExpressionUtilities::getPeriodicDistancePlaceholder() {
return &periodicDistance;
}
\ No newline at end of file
......@@ -3652,7 +3652,9 @@ void CudaCalcCustomExternalForceKernel::initialize(const System& system, const C
globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
}
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
map<string, Lepton::CustomFunction*> customFunctions;
customFunctions["periodicdistance"] = cu.getExpressionUtilities().getPeriodicDistancePlaceholder();
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), customFunctions).optimize();
Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize();
Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize();
Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize();
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -161,6 +161,47 @@ void testParallelComputation() {
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5);
}
void testPeriodic() {
Vec3 vx(5, 0, 0);
Vec3 vy(0, 6, 0);
Vec3 vz(1, 2, 7);
double x0 = 51, y0 = -17, z0 = 11.2;
System system;
system.setDefaultPeriodicBoxVectors(vx, vy, vz);
system.addParticle(1.0);
CustomExternalForce* force = new CustomExternalForce("periodicdistance(x, y, z, x0, y0, z0)^2");
force->addPerParticleParameter("x0");
force->addPerParticleParameter("y0");
force->addPerParticleParameter("z0");
vector<double> params(3);
params[0] = x0;
params[1] = y0;
params[2] = z0;
force->addParticle(0, params);
system.addForce(force);
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(0, 2, 0);
context.setPositions(positions);
for (int i = 0; i < 100; i++) {
State state = context.getState(State::Positions | State::Forces | State::Energy);
// Apply periodic boundary conditions to the difference between the two positions.
Vec3 delta = Vec3(x0, y0, z0)-state.getPositions()[0];
delta -= vz*floor(delta[2]/vz[2]+0.5);
delta -= vy*floor(delta[1]/vy[1]+0.5);
delta -= vx*floor(delta[0]/vx[0]+0.5);
// Verify that the force and energy are correct.
ASSERT_EQUAL_VEC(delta*2, state.getForces()[0], 1e-6);
ASSERT_EQUAL_TOL(delta.dot(delta), state.getPotentialEnergy(), 1e-6);
integrator.step(1);
}
}
int main(int argc, char* argv[]) {
try {
if (argc > 1)
......@@ -168,6 +209,7 @@ int main(int argc, char* argv[]) {
testForce();
testManyParameters();
testParallelComputation();
testPeriodic();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2009 Stanford University and the Authors. *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment