"olla/vscode:/vscode.git/clone" did not exist on "fa43cfd7e474f0c64fcf70e81f39b327ae88ae07"
Commit 1763a764 authored by Peter Eastman's avatar Peter Eastman
Browse files

Fixed bug in computing chain rule terms for CustomGBForce

parent e918cb6f
......@@ -1793,16 +1793,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
// Record derivatives of expressions needed for the chain rule terms.
vector<Lepton::ParsedExpression> valueDerivExpressions;
vector<vector<Lepton::ParsedExpression> > energyDerivExpressions;
for (int i = 0; i < force.getNumComputedValues(); i++) {
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
if (i == 0)
valueDerivExpressions.push_back(ex.differentiate("r").optimize());
else
valueDerivExpressions.push_back(ex.differentiate(computedValueNames[i-1]).optimize());
}
energyDerivExpressions.resize(force.getNumEnergyTerms());
vector<vector<Lepton::ParsedExpression> > energyDerivExpressions(force.getNumEnergyTerms());
for (int i = 0; i < force.getNumEnergyTerms(); i++) {
string expression;
CustomGBForce::ComputationType type;
......@@ -2118,11 +2109,11 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
map<string, Lepton::ParsedExpression> derivExpressions;
stringstream chainSource;
Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
derivExpressions["float dVdR1 = "] = dVdR;
derivExpressions["float dVdR2 = "] = dVdR.renameVariables(rename);
derivExpressions["float dV0dR1 = "] = dVdR;
derivExpressions["float dV0dR2 = "] = dVdR.renameVariables(rename);
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams");
chainSource << "tempForce -= dVdR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "1") << ";\n";
chainSource << "tempForce -= dVdR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "2") << ";\n";
chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "1") << ";\n";
chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "2") << ";\n";
variables = globalVariables;
map<string, string> rename1;
map<string, string> rename2;
......@@ -2141,16 +2132,17 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
rename2[name] = name+"2";
if (i == 0)
continue;
Lepton::ParsedExpression dVdV = Lepton::Parser::parse(computedValueExpressions[1], functions).differentiate(computedValueNames[i-1]).optimize();
string var = "dV"+intToString(i+1)+"dV"+intToString(i)+"_";
derivExpressions.clear();
derivExpressions["float "+var+"1 = "] = dVdV.renameVariables(rename1);
derivExpressions["float "+var+"2 = "] = dVdV.renameVariables(rename2);
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp"+intToString(i)+"_", prefix+"functionParams");
chainSource << "dVdR1 *= "+var+"1;\n";
chainSource << "dVdR2 *= "+var+"2;\n";
chainSource << "tempForce -= dVdR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "1") << ";\n";
chainSource << "tempForce -= dVdR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "2") << ";\n";
chainSource << "float dV"+intToString(i)+"dR1 = 0;\n";
chainSource << "float dV"+intToString(i)+"dR2 = 0;\n";
for (int j = 0; j < i; j++) {
Lepton::ParsedExpression dVdV = Lepton::Parser::parse(computedValueExpressions[i], functions).differentiate(computedValueNames[j]).optimize();
derivExpressions.clear();
derivExpressions["dV"+intToString(i)+"dR1 += dV"+intToString(j)+"dR1*"] = dVdV.renameVariables(rename1);
derivExpressions["dV"+intToString(i)+"dR2 += dV"+intToString(j)+"dR2*"] = dVdV.renameVariables(rename2);
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp"+intToString(i)+"_"+intToString(j)+"_", prefix+"functionParams");
}
chainSource << "tempForce -= dV"<< intToString(i) << "dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "1") << ";\n";
chainSource << "tempForce -= dV"<< intToString(i) << "dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "2") << ";\n";
}
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = chainSource.str();
......
......@@ -172,6 +172,34 @@ void testTabulatedFunction(bool interpolating) {
}
}
void testMultipleChainRules() {
OpenCLPlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomGBForce* force = new CustomGBForce();
force->addComputedValue("a", "2*r", CustomGBForce::ParticlePair);
force->addComputedValue("b", "a+1", CustomGBForce::SingleParticle);
force->addComputedValue("c", "2*b+a", CustomGBForce::SingleParticle);
force->addEnergyTerm("0.1*a+1*b+10*c", CustomGBForce::SingleParticle); // 0.1*(2*r) + 2*r+1 + 10*(3*a+2) = 0.2*r + 2*r+1 + 40*r+20+20*r = 62.2*r+21
force->addParticle(vector<double>());
force->addParticle(vector<double>());
system.addForce(force);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 1; i < 5; i++) {
positions[1] = Vec3(i, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(124.4, 0, 0), forces[0], 1e-4);
ASSERT_EQUAL_VEC(Vec3(-124.4, 0, 0), forces[1], 1e-4);
ASSERT_EQUAL_TOL(2*(62.2*i+21), state.getPotentialEnergy(), 0.02);
}
}
int main() {
try {
testOBC(GBSAOBCForce::NoCutoff, CustomGBForce::NoCutoff);
......@@ -179,6 +207,7 @@ int main() {
testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic);
testTabulatedFunction(true);
testTabulatedFunction(false);
testMultipleChainRules();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
......@@ -1093,6 +1093,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
// Parse the expressions for computed values.
valueDerivExpressions.resize(force.getNumComputedValues());
for (int i = 0; i < force.getNumComputedValues(); i++) {
string name, expression;
CustomGBForce::ComputationType type;
......@@ -1101,10 +1102,12 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
valueExpressions.push_back(ex.createProgram());
valueTypes.push_back(type);
valueNames.push_back(name);
if (type == CustomGBForce::SingleParticle)
valueDerivExpressions.push_back(ex.differentiate(valueNames[i-1]).optimize().createProgram());
else
valueDerivExpressions.push_back(ex.differentiate("r").optimize().createProgram());
if (i == 0)
valueDerivExpressions[i].push_back(ex.differentiate("r").optimize().createProgram());
else {
for (int j = 0; j < i; j++)
valueDerivExpressions[i].push_back(ex.differentiate(valueNames[j]).optimize().createProgram());
}
}
// Parse the expressions for energy terms.
......
......@@ -584,7 +584,7 @@ private:
std::vector<std::set<int> > exclusions;
std::vector<std::string> particleParameterNames, globalParameterNames, valueNames;
std::vector<Lepton::ExpressionProgram> valueExpressions;
std::vector<Lepton::ExpressionProgram> valueDerivExpressions;
std::vector<std::vector<Lepton::ExpressionProgram> > valueDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<Lepton::ExpressionProgram> energyExpressions;
std::vector<std::vector<Lepton::ExpressionProgram> > energyDerivExpressions;
......
......@@ -44,7 +44,7 @@ using std::vector;
--------------------------------------------------------------------------------------- */
ReferenceCustomGBIxn::ReferenceCustomGBIxn(const vector<Lepton::ExpressionProgram>& valueExpressions,
const vector<Lepton::ExpressionProgram>& valueDerivExpressions,
const vector<vector<Lepton::ExpressionProgram> > valueDerivExpressions,
const vector<string>& valueNames,
const vector<OpenMM::CustomGBForce::ComputationType>& valueTypes,
const vector<Lepton::ExpressionProgram>& energyExpressions,
......@@ -371,10 +371,11 @@ void ReferenceCustomGBIxn::calculateOnePairChainRule(int atom1, int atom2, RealO
// Evaluate the derivative of each parameter with respect to position and apply forces.
RealOpenMM dVdR = (RealOpenMM) valueDerivExpressions[0].evaluate(variables);
vector<RealOpenMM> dVdR(valueDerivExpressions.size(), 0.0);
dVdR[0] = (RealOpenMM) valueDerivExpressions[0][0].evaluate(variables);
RealOpenMM rinv = 1/r;
for (int i = 0; i < 3; i++) {
RealOpenMM f = dEdV[0][atom1]*dVdR*deltaR[i]*rinv;
RealOpenMM f = dEdV[0][atom1]*dVdR[0]*deltaR[i]*rinv;
forces[atom1][i] -= f;
forces[atom2][i] += f;
}
......@@ -384,9 +385,10 @@ void ReferenceCustomGBIxn::calculateOnePairChainRule(int atom1, int atom2, RealO
variables[valueNames[0]] = values[0][atom1];
for (int i = 1; i < (int) valueNames.size(); i++) {
variables[valueNames[i]] = values[i][atom1];
dVdR *= (RealOpenMM) valueDerivExpressions[i].evaluate(variables);
for (int j = 0; j < i; j++)
dVdR[i] += (RealOpenMM) (valueDerivExpressions[i][j].evaluate(variables)*dVdR[j]);
for (int j = 0; j < 3; j++) {
RealOpenMM f = dEdV[i][atom1]*dVdR*deltaR[j]*rinv;
RealOpenMM f = dEdV[i][atom1]*dVdR[i]*deltaR[j]*rinv;
forces[atom1][j] -= f;
forces[atom2][j] += f;
}
......
......@@ -44,7 +44,7 @@ class ReferenceCustomGBIxn {
RealOpenMM periodicBoxSize[3];
RealOpenMM cutoffDistance;
std::vector<Lepton::ExpressionProgram> valueExpressions;
std::vector<Lepton::ExpressionProgram> valueDerivExpressions;
std::vector<std::vector<Lepton::ExpressionProgram> > valueDerivExpressions;
std::vector<std::string> valueNames;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<Lepton::ExpressionProgram> energyExpressions;
......@@ -221,7 +221,7 @@ class ReferenceCustomGBIxn {
--------------------------------------------------------------------------------------- */
ReferenceCustomGBIxn(const std::vector<Lepton::ExpressionProgram>& valueExpressions,
const std::vector<Lepton::ExpressionProgram>& valueDerivExpressions,
const std::vector<std::vector<Lepton::ExpressionProgram> > valueDerivExpressions,
const std::vector<std::string>& valueNames,
const std::vector<OpenMM::CustomGBForce::ComputationType>& valueTypes,
const std::vector<Lepton::ExpressionProgram>& energyExpressions,
......
......@@ -172,6 +172,34 @@ void testTabulatedFunction(bool interpolating) {
}
}
void testMultipleChainRules() {
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomGBForce* force = new CustomGBForce();
force->addComputedValue("a", "2*r", CustomGBForce::ParticlePair);
force->addComputedValue("b", "a+1", CustomGBForce::SingleParticle);
force->addComputedValue("c", "2*b+a", CustomGBForce::SingleParticle);
force->addEnergyTerm("0.1*a+1*b+10*c", CustomGBForce::SingleParticle); // 0.1*(2*r) + 2*r+1 + 10*(3*a+2) = 0.2*r + 2*r+1 + 40*r+20+20*r = 62.2*r+21
force->addParticle(vector<double>());
force->addParticle(vector<double>());
system.addForce(force);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 1; i < 5; i++) {
positions[1] = Vec3(i, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(124.4, 0, 0), forces[0], 1e-4);
ASSERT_EQUAL_VEC(Vec3(-124.4, 0, 0), forces[1], 1e-4);
ASSERT_EQUAL_TOL(2*(62.2*i+21), state.getPotentialEnergy(), 0.02);
}
}
int main() {
try {
testOBC(GBSAOBCForce::NoCutoff, CustomGBForce::NoCutoff);
......@@ -179,6 +207,7 @@ int main() {
testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic);
testTabulatedFunction(true);
testTabulatedFunction(false);
testMultipleChainRules();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment