Commit a28c8cf6 authored by Peter Eastman's avatar Peter Eastman
Browse files

Added missing chain rule terms to CustomGBForce

parent 7bf52862
...@@ -1824,6 +1824,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1824,6 +1824,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
// Record derivatives of expressions needed for the chain rule terms. // Record derivatives of expressions needed for the chain rule terms.
vector<vector<Lepton::ParsedExpression> > valueGradientExpressions(force.getNumComputedValues()); vector<vector<Lepton::ParsedExpression> > valueGradientExpressions(force.getNumComputedValues());
vector<vector<Lepton::ParsedExpression> > valueDerivExpressions(force.getNumComputedValues());
needParameterGradient = false; needParameterGradient = false;
for (int i = 1; i < force.getNumComputedValues(); i++) { for (int i = 1; i < force.getNumComputedValues(); i++) {
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize(); Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
...@@ -1832,6 +1833,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1832,6 +1833,8 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
valueGradientExpressions[i].push_back(ex.differentiate("z").optimize()); valueGradientExpressions[i].push_back(ex.differentiate("z").optimize());
if (!isZeroExpression(valueGradientExpressions[i][0]) || !isZeroExpression(valueGradientExpressions[i][1]) || !isZeroExpression(valueGradientExpressions[i][2])) if (!isZeroExpression(valueGradientExpressions[i][0]) || !isZeroExpression(valueGradientExpressions[i][1]) || !isZeroExpression(valueGradientExpressions[i][2]))
needParameterGradient = true; needParameterGradient = true;
for (int j = 0; j < i; j++)
valueDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize());
} }
vector<vector<Lepton::ParsedExpression> > energyDerivExpressions(force.getNumEnergyTerms()); vector<vector<Lepton::ParsedExpression> > energyDerivExpressions(force.getNumEnergyTerms());
for (int i = 0; i < force.getNumEnergyTerms(); i++) { for (int i = 0; i < force.getNumEnergyTerms(); i++) {
...@@ -2160,16 +2163,31 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2160,16 +2163,31 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
variables[force.getGlobalParameterName(i)] = "globals["+intToString(i)+"]"; variables[force.getGlobalParameterName(i)] = "globals["+intToString(i)+"]";
for (int i = 0; i < force.getNumComputedValues(); i++) for (int i = 0; i < force.getNumComputedValues(); i++)
variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]"); variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
map<string, Lepton::ParsedExpression> gradientExpressions;
for (int i = 1; i < force.getNumComputedValues(); i++) { for (int i = 1; i < force.getNumComputedValues(); i++) {
string is = intToString(i);
compute << "float4 dV"<<is<<"dR = (float4) 0;\n";
for (int j = 1; j < i; j++) {
if (!isZeroExpression(valueDerivExpressions[i][j])) {
map<string, Lepton::ParsedExpression> derivExpressions;
string js = intToString(j);
derivExpressions["float dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
compute << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams");
compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n";
}
}
map<string, Lepton::ParsedExpression> gradientExpressions;
if (!isZeroExpression(valueGradientExpressions[i][0])) if (!isZeroExpression(valueGradientExpressions[i][0]))
gradientExpressions["force.x -= deriv"+energyDerivs->getParameterSuffix(i)+"*"] = valueGradientExpressions[i][0]; gradientExpressions["dV"+is+"dR.x += "] = valueGradientExpressions[i][0];
if (!isZeroExpression(valueGradientExpressions[i][1])) if (!isZeroExpression(valueGradientExpressions[i][1]))
gradientExpressions["force.y -= deriv"+energyDerivs->getParameterSuffix(i)+"*"] = valueGradientExpressions[i][1]; gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2])) if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["force.z -= deriv"+energyDerivs->getParameterSuffix(i)+"*"] = valueGradientExpressions[i][2]; gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << OpenCLExpressionUtilities::createExpressions(gradientExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
}
for (int i = 1; i < force.getNumComputedValues(); i++) {
string is = intToString(i);
compute << "force -= deriv"<<energyDerivs->getParameterSuffix(i)<<"*dV"<<is<<"dR;\n";
} }
compute << OpenCLExpressionUtilities::createExpressions(gradientExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
map<string, string> replacements; map<string, string> replacements;
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str(); replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
replacements["COMPUTE_FORCES"] = compute.str(); replacements["COMPUTE_FORCES"] = compute.str();
......
...@@ -357,14 +357,25 @@ void ReferenceCustomGBIxn::calculateChainRuleForces(int numAtoms, RealOpenMM** a ...@@ -357,14 +357,25 @@ void ReferenceCustomGBIxn::calculateChainRuleForces(int numAtoms, RealOpenMM** a
variables["x"] = atomCoordinates[i][0]; variables["x"] = atomCoordinates[i][0];
variables["y"] = atomCoordinates[i][1]; variables["y"] = atomCoordinates[i][1];
variables["z"] = atomCoordinates[i][2]; variables["z"] = atomCoordinates[i][2];
vector<RealOpenMM> dVdX(valueDerivExpressions.size(), 0.0);
vector<RealOpenMM> dVdY(valueDerivExpressions.size(), 0.0);
vector<RealOpenMM> dVdZ(valueDerivExpressions.size(), 0.0);
for (int j = 0; j < (int) paramNames.size(); j++) for (int j = 0; j < (int) paramNames.size(); j++)
variables[paramNames[j]] = atomParameters[i][j]; variables[paramNames[j]] = atomParameters[i][j];
for (int j = 1; j < (int) valueNames.size(); j++) { for (int j = 1; j < (int) valueNames.size(); j++) {
variables[valueNames[j-1]] = values[j-1][i]; variables[valueNames[j-1]] = values[j-1][i];
for (int k = 0; k < 3; k++) { for (int k = 1; k < j; k++) {
RealOpenMM gradient = (RealOpenMM) valueGradientExpressions[j][k].evaluate(variables); RealOpenMM dVdV = (RealOpenMM) valueDerivExpressions[j][k].evaluate(variables);
forces[i][k] -= dEdV[j][i]*gradient; dVdX[j] += dVdV*dVdX[k];
dVdY[j] += dVdV*dVdY[k];
dVdZ[j] += dVdV*dVdZ[k];
} }
dVdX[j] += (RealOpenMM) valueGradientExpressions[j][0].evaluate(variables);
dVdY[j] += (RealOpenMM) valueGradientExpressions[j][1].evaluate(variables);
dVdZ[j] += (RealOpenMM) valueGradientExpressions[j][2].evaluate(variables);
forces[i][0] -= dEdV[j][i]*dVdX[j];
forces[i][1] -= dEdV[j][i]*dVdY[j];
forces[i][2] -= dEdV[j][i]*dVdZ[j];
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment