Commit 1763a764 authored by Peter Eastman's avatar Peter Eastman
Browse files

Fixed bug in computing chain rule terms for CustomGBForce

parent e918cb6f
...@@ -1793,16 +1793,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1793,16 +1793,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
// Record derivatives of expressions needed for the chain rule terms. // Record derivatives of expressions needed for the chain rule terms.
vector<Lepton::ParsedExpression> valueDerivExpressions; vector<vector<Lepton::ParsedExpression> > energyDerivExpressions(force.getNumEnergyTerms());
vector<vector<Lepton::ParsedExpression> > energyDerivExpressions;
for (int i = 0; i < force.getNumComputedValues(); i++) {
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
if (i == 0)
valueDerivExpressions.push_back(ex.differentiate("r").optimize());
else
valueDerivExpressions.push_back(ex.differentiate(computedValueNames[i-1]).optimize());
}
energyDerivExpressions.resize(force.getNumEnergyTerms());
for (int i = 0; i < force.getNumEnergyTerms(); i++) { for (int i = 0; i < force.getNumEnergyTerms(); i++) {
string expression; string expression;
CustomGBForce::ComputationType type; CustomGBForce::ComputationType type;
...@@ -2118,11 +2109,11 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2118,11 +2109,11 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
map<string, Lepton::ParsedExpression> derivExpressions; map<string, Lepton::ParsedExpression> derivExpressions;
stringstream chainSource; stringstream chainSource;
Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize(); Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
derivExpressions["float dVdR1 = "] = dVdR; derivExpressions["float dV0dR1 = "] = dVdR;
derivExpressions["float dVdR2 = "] = dVdR.renameVariables(rename); derivExpressions["float dV0dR2 = "] = dVdR.renameVariables(rename);
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams"); chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams");
chainSource << "tempForce -= dVdR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "1") << ";\n"; chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "1") << ";\n";
chainSource << "tempForce -= dVdR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "2") << ";\n"; chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "2") << ";\n";
variables = globalVariables; variables = globalVariables;
map<string, string> rename1; map<string, string> rename1;
map<string, string> rename2; map<string, string> rename2;
...@@ -2141,16 +2132,17 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2141,16 +2132,17 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
rename2[name] = name+"2"; rename2[name] = name+"2";
if (i == 0) if (i == 0)
continue; continue;
Lepton::ParsedExpression dVdV = Lepton::Parser::parse(computedValueExpressions[1], functions).differentiate(computedValueNames[i-1]).optimize(); chainSource << "float dV"+intToString(i)+"dR1 = 0;\n";
string var = "dV"+intToString(i+1)+"dV"+intToString(i)+"_"; chainSource << "float dV"+intToString(i)+"dR2 = 0;\n";
for (int j = 0; j < i; j++) {
Lepton::ParsedExpression dVdV = Lepton::Parser::parse(computedValueExpressions[i], functions).differentiate(computedValueNames[j]).optimize();
derivExpressions.clear(); derivExpressions.clear();
derivExpressions["float "+var+"1 = "] = dVdV.renameVariables(rename1); derivExpressions["dV"+intToString(i)+"dR1 += dV"+intToString(j)+"dR1*"] = dVdV.renameVariables(rename1);
derivExpressions["float "+var+"2 = "] = dVdV.renameVariables(rename2); derivExpressions["dV"+intToString(i)+"dR2 += dV"+intToString(j)+"dR2*"] = dVdV.renameVariables(rename2);
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp"+intToString(i)+"_", prefix+"functionParams"); chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp"+intToString(i)+"_"+intToString(j)+"_", prefix+"functionParams");
chainSource << "dVdR1 *= "+var+"1;\n"; }
chainSource << "dVdR2 *= "+var+"2;\n"; chainSource << "tempForce -= dV"<< intToString(i) << "dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "1") << ";\n";
chainSource << "tempForce -= dVdR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "1") << ";\n"; chainSource << "tempForce -= dV"<< intToString(i) << "dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "2") << ";\n";
chainSource << "tempForce -= dVdR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "2") << ";\n";
} }
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = chainSource.str(); replacements["COMPUTE_FORCE"] = chainSource.str();
......
...@@ -172,6 +172,34 @@ void testTabulatedFunction(bool interpolating) { ...@@ -172,6 +172,34 @@ void testTabulatedFunction(bool interpolating) {
} }
} }
void testMultipleChainRules() {
OpenCLPlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomGBForce* force = new CustomGBForce();
force->addComputedValue("a", "2*r", CustomGBForce::ParticlePair);
force->addComputedValue("b", "a+1", CustomGBForce::SingleParticle);
force->addComputedValue("c", "2*b+a", CustomGBForce::SingleParticle);
force->addEnergyTerm("0.1*a+1*b+10*c", CustomGBForce::SingleParticle); // 0.1*(2*r) + 2*r+1 + 10*(3*a+2) = 0.2*r + 2*r+1 + 40*r+20+20*r = 62.2*r+21
force->addParticle(vector<double>());
force->addParticle(vector<double>());
system.addForce(force);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 1; i < 5; i++) {
positions[1] = Vec3(i, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(124.4, 0, 0), forces[0], 1e-4);
ASSERT_EQUAL_VEC(Vec3(-124.4, 0, 0), forces[1], 1e-4);
ASSERT_EQUAL_TOL(2*(62.2*i+21), state.getPotentialEnergy(), 0.02);
}
}
int main() { int main() {
try { try {
testOBC(GBSAOBCForce::NoCutoff, CustomGBForce::NoCutoff); testOBC(GBSAOBCForce::NoCutoff, CustomGBForce::NoCutoff);
...@@ -179,6 +207,7 @@ int main() { ...@@ -179,6 +207,7 @@ int main() {
testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic); testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic);
testTabulatedFunction(true); testTabulatedFunction(true);
testTabulatedFunction(false); testTabulatedFunction(false);
testMultipleChainRules();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -1093,6 +1093,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1093,6 +1093,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
// Parse the expressions for computed values. // Parse the expressions for computed values.
valueDerivExpressions.resize(force.getNumComputedValues());
for (int i = 0; i < force.getNumComputedValues(); i++) { for (int i = 0; i < force.getNumComputedValues(); i++) {
string name, expression; string name, expression;
CustomGBForce::ComputationType type; CustomGBForce::ComputationType type;
...@@ -1101,10 +1102,12 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1101,10 +1102,12 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
valueExpressions.push_back(ex.createProgram()); valueExpressions.push_back(ex.createProgram());
valueTypes.push_back(type); valueTypes.push_back(type);
valueNames.push_back(name); valueNames.push_back(name);
if (type == CustomGBForce::SingleParticle) if (i == 0)
valueDerivExpressions.push_back(ex.differentiate(valueNames[i-1]).optimize().createProgram()); valueDerivExpressions[i].push_back(ex.differentiate("r").optimize().createProgram());
else else {
valueDerivExpressions.push_back(ex.differentiate("r").optimize().createProgram()); for (int j = 0; j < i; j++)
valueDerivExpressions[i].push_back(ex.differentiate(valueNames[j]).optimize().createProgram());
}
} }
// Parse the expressions for energy terms. // Parse the expressions for energy terms.
......
...@@ -584,7 +584,7 @@ private: ...@@ -584,7 +584,7 @@ private:
std::vector<std::set<int> > exclusions; std::vector<std::set<int> > exclusions;
std::vector<std::string> particleParameterNames, globalParameterNames, valueNames; std::vector<std::string> particleParameterNames, globalParameterNames, valueNames;
std::vector<Lepton::ExpressionProgram> valueExpressions; std::vector<Lepton::ExpressionProgram> valueExpressions;
std::vector<Lepton::ExpressionProgram> valueDerivExpressions; std::vector<std::vector<Lepton::ExpressionProgram> > valueDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes; std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<Lepton::ExpressionProgram> energyExpressions; std::vector<Lepton::ExpressionProgram> energyExpressions;
std::vector<std::vector<Lepton::ExpressionProgram> > energyDerivExpressions; std::vector<std::vector<Lepton::ExpressionProgram> > energyDerivExpressions;
......
...@@ -44,7 +44,7 @@ using std::vector; ...@@ -44,7 +44,7 @@ using std::vector;
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
ReferenceCustomGBIxn::ReferenceCustomGBIxn(const vector<Lepton::ExpressionProgram>& valueExpressions, ReferenceCustomGBIxn::ReferenceCustomGBIxn(const vector<Lepton::ExpressionProgram>& valueExpressions,
const vector<Lepton::ExpressionProgram>& valueDerivExpressions, const vector<vector<Lepton::ExpressionProgram> > valueDerivExpressions,
const vector<string>& valueNames, const vector<string>& valueNames,
const vector<OpenMM::CustomGBForce::ComputationType>& valueTypes, const vector<OpenMM::CustomGBForce::ComputationType>& valueTypes,
const vector<Lepton::ExpressionProgram>& energyExpressions, const vector<Lepton::ExpressionProgram>& energyExpressions,
...@@ -371,10 +371,11 @@ void ReferenceCustomGBIxn::calculateOnePairChainRule(int atom1, int atom2, RealO ...@@ -371,10 +371,11 @@ void ReferenceCustomGBIxn::calculateOnePairChainRule(int atom1, int atom2, RealO
// Evaluate the derivative of each parameter with respect to position and apply forces. // Evaluate the derivative of each parameter with respect to position and apply forces.
RealOpenMM dVdR = (RealOpenMM) valueDerivExpressions[0].evaluate(variables); vector<RealOpenMM> dVdR(valueDerivExpressions.size(), 0.0);
dVdR[0] = (RealOpenMM) valueDerivExpressions[0][0].evaluate(variables);
RealOpenMM rinv = 1/r; RealOpenMM rinv = 1/r;
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
RealOpenMM f = dEdV[0][atom1]*dVdR*deltaR[i]*rinv; RealOpenMM f = dEdV[0][atom1]*dVdR[0]*deltaR[i]*rinv;
forces[atom1][i] -= f; forces[atom1][i] -= f;
forces[atom2][i] += f; forces[atom2][i] += f;
} }
...@@ -384,9 +385,10 @@ void ReferenceCustomGBIxn::calculateOnePairChainRule(int atom1, int atom2, RealO ...@@ -384,9 +385,10 @@ void ReferenceCustomGBIxn::calculateOnePairChainRule(int atom1, int atom2, RealO
variables[valueNames[0]] = values[0][atom1]; variables[valueNames[0]] = values[0][atom1];
for (int i = 1; i < (int) valueNames.size(); i++) { for (int i = 1; i < (int) valueNames.size(); i++) {
variables[valueNames[i]] = values[i][atom1]; variables[valueNames[i]] = values[i][atom1];
dVdR *= (RealOpenMM) valueDerivExpressions[i].evaluate(variables); for (int j = 0; j < i; j++)
dVdR[i] += (RealOpenMM) (valueDerivExpressions[i][j].evaluate(variables)*dVdR[j]);
for (int j = 0; j < 3; j++) { for (int j = 0; j < 3; j++) {
RealOpenMM f = dEdV[i][atom1]*dVdR*deltaR[j]*rinv; RealOpenMM f = dEdV[i][atom1]*dVdR[i]*deltaR[j]*rinv;
forces[atom1][j] -= f; forces[atom1][j] -= f;
forces[atom2][j] += f; forces[atom2][j] += f;
} }
......
...@@ -44,7 +44,7 @@ class ReferenceCustomGBIxn { ...@@ -44,7 +44,7 @@ class ReferenceCustomGBIxn {
RealOpenMM periodicBoxSize[3]; RealOpenMM periodicBoxSize[3];
RealOpenMM cutoffDistance; RealOpenMM cutoffDistance;
std::vector<Lepton::ExpressionProgram> valueExpressions; std::vector<Lepton::ExpressionProgram> valueExpressions;
std::vector<Lepton::ExpressionProgram> valueDerivExpressions; std::vector<std::vector<Lepton::ExpressionProgram> > valueDerivExpressions;
std::vector<std::string> valueNames; std::vector<std::string> valueNames;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes; std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<Lepton::ExpressionProgram> energyExpressions; std::vector<Lepton::ExpressionProgram> energyExpressions;
...@@ -221,7 +221,7 @@ class ReferenceCustomGBIxn { ...@@ -221,7 +221,7 @@ class ReferenceCustomGBIxn {
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
ReferenceCustomGBIxn(const std::vector<Lepton::ExpressionProgram>& valueExpressions, ReferenceCustomGBIxn(const std::vector<Lepton::ExpressionProgram>& valueExpressions,
const std::vector<Lepton::ExpressionProgram>& valueDerivExpressions, const std::vector<std::vector<Lepton::ExpressionProgram> > valueDerivExpressions,
const std::vector<std::string>& valueNames, const std::vector<std::string>& valueNames,
const std::vector<OpenMM::CustomGBForce::ComputationType>& valueTypes, const std::vector<OpenMM::CustomGBForce::ComputationType>& valueTypes,
const std::vector<Lepton::ExpressionProgram>& energyExpressions, const std::vector<Lepton::ExpressionProgram>& energyExpressions,
......
...@@ -172,6 +172,34 @@ void testTabulatedFunction(bool interpolating) { ...@@ -172,6 +172,34 @@ void testTabulatedFunction(bool interpolating) {
} }
} }
void testMultipleChainRules() {
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomGBForce* force = new CustomGBForce();
force->addComputedValue("a", "2*r", CustomGBForce::ParticlePair);
force->addComputedValue("b", "a+1", CustomGBForce::SingleParticle);
force->addComputedValue("c", "2*b+a", CustomGBForce::SingleParticle);
force->addEnergyTerm("0.1*a+1*b+10*c", CustomGBForce::SingleParticle); // 0.1*(2*r) + 2*r+1 + 10*(3*a+2) = 0.2*r + 2*r+1 + 40*r+20+20*r = 62.2*r+21
force->addParticle(vector<double>());
force->addParticle(vector<double>());
system.addForce(force);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 1; i < 5; i++) {
positions[1] = Vec3(i, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(124.4, 0, 0), forces[0], 1e-4);
ASSERT_EQUAL_VEC(Vec3(-124.4, 0, 0), forces[1], 1e-4);
ASSERT_EQUAL_TOL(2*(62.2*i+21), state.getPotentialEnergy(), 0.02);
}
}
int main() { int main() {
try { try {
testOBC(GBSAOBCForce::NoCutoff, CustomGBForce::NoCutoff); testOBC(GBSAOBCForce::NoCutoff, CustomGBForce::NoCutoff);
...@@ -179,6 +207,7 @@ int main() { ...@@ -179,6 +207,7 @@ int main() {
testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic); testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic);
testTabulatedFunction(true); testTabulatedFunction(true);
testTabulatedFunction(false); testTabulatedFunction(false);
testMultipleChainRules();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment