Commit 0b77ab91 authored by peastman's avatar peastman
Browse files

Began CUDA implementation of periodicdistance() function for CustomExternalForce

parent cceb3171
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2014 Stanford University and the Authors. * * Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -89,6 +89,10 @@ public: ...@@ -89,6 +89,10 @@ public:
* @param function the function for which to get a placeholder * @param function the function for which to get a placeholder
*/ */
Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function); Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function);
/**
* Get a Lepton::CustomFunction that can be used to represent the periodicdistance() function when parsing expressions.
*/
Lepton::CustomFunction* getPeriodicDistancePlaceholder();
private: private:
class FunctionPlaceholder : public Lepton::CustomFunction { class FunctionPlaceholder : public Lepton::CustomFunction {
public: public:
...@@ -114,13 +118,13 @@ private: ...@@ -114,13 +118,13 @@ private:
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType); const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedCustomFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::vector<const Lepton::ExpressionTreeNode*>& nodes); std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers); std::map<int, const Lepton::ExpressionTreeNode*>& powers);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions); std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
CudaContext& context; CudaContext& context;
FunctionPlaceholder fp1, fp2, fp3; FunctionPlaceholder fp1, fp2, fp3, periodicDistance;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2014 Stanford University and the Authors. * * Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -33,7 +33,7 @@ using namespace OpenMM; ...@@ -33,7 +33,7 @@ using namespace OpenMM;
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context(context), fp1(1), fp2(2), fp3(3) { CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context(context), fp1(1), fp2(2), fp3(3), periodicDistance(6) {
} }
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables, string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
...@@ -79,11 +79,6 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -79,11 +79,6 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown variable in expression: "+node.getOperation().getName());
case Operation::CUSTOM: case Operation::CUSTOM:
{ {
int i;
for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
;
if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
out << "0.0f;\n"; out << "0.0f;\n";
temps.push_back(make_pair(node, name)); temps.push_back(make_pair(node, name));
hasRecordedNode = true; hasRecordedNode = true;
...@@ -93,7 +88,7 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -93,7 +88,7 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
vector<const ExpressionTreeNode*> nodes; vector<const ExpressionTreeNode*> nodes;
for (int j = 0; j < (int) allExpressions.size(); j++) for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), nodes); findRelatedCustomFunctions(node, allExpressions[j].getRootNode(), nodes);
vector<string> nodeNames; vector<string> nodeNames;
nodeNames.push_back(name); nodeNames.push_back(name);
for (int j = 1; j < (int) nodes.size(); j++) { for (int j = 1; j < (int) nodes.size(); j++) {
...@@ -103,6 +98,52 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -103,6 +98,52 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
temps.push_back(make_pair(*nodes[j], name2)); temps.push_back(make_pair(*nodes[j], name2));
} }
out << "{\n"; out << "{\n";
if (node.getOperation().getName() == "periodicdistance") {
// This is the periodicdistance() function.
out << tempType << " periodicDistance_delta = make_real3(";
for (int i = 0; i < 3; i++) {
if (i > 0)
out << ", ";
out << getTempName(node.getChildren()[i], temps) << "-" << getTempName(node.getChildren()[i+3], temps);
out << ");\n";
}
out << "APPLY_PERIODIC_TO_DELTA(periodicDistance_delta)\n";
out << tempType << " periodicDistance_r = SQRT(periodicDistance_delta.x*periodicDistance_delta.x + periodicDistance_delta.y*periodicDistance_delta.y + periodicDistance_delta.z*periodicDistance_delta.z);\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
int argIndex = -1;
for (int k = 0; k < 6; k++) {
if (derivOrder[k] > 0) {
if (derivOrder[k] > 1 || argIndex != -1)
throw OpenMMException("Unsupported derivative of periodicdistance"); // Should be impossible for this to happen.
argIndex = k;
}
}
if (argIndex == -1)
out << nodeNames[j] << " = periodicDistance_r;\n";
else if (argIndex == 0)
out << nodeNames[j] << " = periodicDistance_delta.x/periodicDistance_r;\n";
else if (argIndex == 1)
out << nodeNames[j] << " = periodicDistance_delta.y/periodicDistance_r;\n";
else if (argIndex == 2)
out << nodeNames[j] << " = periodicDistance_delta.z/periodicDistance_r;\n";
else if (argIndex == 3)
out << nodeNames[j] << " = -periodicDistance_delta.x/periodicDistance_r;\n";
else if (argIndex == 4)
out << nodeNames[j] << " = -periodicDistance_delta.y/periodicDistance_r;\n";
else if (argIndex == 5)
out << nodeNames[j] << " = -periodicDistance_delta.z/periodicDistance_r;\n";
}
}
else {
// This is a tabulated function.
int i;
for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
;
if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
vector<string> paramsFloat, paramsInt; vector<string> paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) { for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j])); paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
...@@ -275,6 +316,7 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -275,6 +316,7 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
} }
} }
} }
}
out << "}"; out << "}";
break; break;
} }
...@@ -483,7 +525,7 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons ...@@ -483,7 +525,7 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons
throw OpenMMException(out.str()); throw OpenMMException(out.str());
} }
void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, void CudaExpressionUtilities::findRelatedCustomFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
vector<const Lepton::ExpressionTreeNode*>& nodes) { vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) { if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getOperation().getName() == searchNode.getOperation().getName()) {
// Make sure the arguments are identical. // Make sure the arguments are identical.
...@@ -504,7 +546,7 @@ void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTree ...@@ -504,7 +546,7 @@ void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTree
} }
else else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++) for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes); findRelatedCustomFunctions(node, searchNode.getChildren()[i], nodes);
} }
void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) { void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
...@@ -730,3 +772,7 @@ Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const Ta ...@@ -730,3 +772,7 @@ Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const Ta
return &fp3; return &fp3;
throw OpenMMException("getFunctionPlaceholder: Unknown function type"); throw OpenMMException("getFunctionPlaceholder: Unknown function type");
} }
Lepton::CustomFunction* CudaExpressionUtilities::getPeriodicDistancePlaceholder() {
return &periodicDistance;
}
\ No newline at end of file
...@@ -3652,7 +3652,9 @@ void CudaCalcCustomExternalForceKernel::initialize(const System& system, const C ...@@ -3652,7 +3652,9 @@ void CudaCalcCustomExternalForceKernel::initialize(const System& system, const C
globalParamNames[i] = force.getGlobalParameterName(i); globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i); globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
} }
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize(); map<string, Lepton::CustomFunction*> customFunctions;
customFunctions["periodicdistance"] = cu.getExpressionUtilities().getPeriodicDistancePlaceholder();
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), customFunctions).optimize();
Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize(); Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize();
Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize(); Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize();
Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize(); Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize();
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2012 Stanford University and the Authors. * * Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -161,6 +161,47 @@ void testParallelComputation() { ...@@ -161,6 +161,47 @@ void testParallelComputation() {
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5); ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5);
} }
void testPeriodic() {
Vec3 vx(5, 0, 0);
Vec3 vy(0, 6, 0);
Vec3 vz(1, 2, 7);
double x0 = 51, y0 = -17, z0 = 11.2;
System system;
system.setDefaultPeriodicBoxVectors(vx, vy, vz);
system.addParticle(1.0);
CustomExternalForce* force = new CustomExternalForce("periodicdistance(x, y, z, x0, y0, z0)^2");
force->addPerParticleParameter("x0");
force->addPerParticleParameter("y0");
force->addPerParticleParameter("z0");
vector<double> params(3);
params[0] = x0;
params[1] = y0;
params[2] = z0;
force->addParticle(0, params);
system.addForce(force);
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(0, 2, 0);
context.setPositions(positions);
for (int i = 0; i < 100; i++) {
State state = context.getState(State::Positions | State::Forces | State::Energy);
// Apply periodic boundary conditions to the difference between the two positions.
Vec3 delta = Vec3(x0, y0, z0)-state.getPositions()[0];
delta -= vz*floor(delta[2]/vz[2]+0.5);
delta -= vy*floor(delta[1]/vy[1]+0.5);
delta -= vx*floor(delta[0]/vx[0]+0.5);
// Verify that the force and energy are correct.
ASSERT_EQUAL_VEC(delta*2, state.getForces()[0], 1e-6);
ASSERT_EQUAL_TOL(delta.dot(delta), state.getPotentialEnergy(), 1e-6);
integrator.step(1);
}
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
...@@ -168,6 +209,7 @@ int main(int argc, char* argv[]) { ...@@ -168,6 +209,7 @@ int main(int argc, char* argv[]) {
testForce(); testForce();
testManyParameters(); testManyParameters();
testParallelComputation(); testParallelComputation();
testPeriodic();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2009 Stanford University and the Authors. * * Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment