Commit aeaa6f2e authored by peastman's avatar peastman
Browse files

Reference implementation of parameter derivatives for CustomGBForce

parent 0b1f03db
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. * * Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -127,6 +127,10 @@ namespace OpenMM { ...@@ -127,6 +127,10 @@ namespace OpenMM {
* omitted from calculations. This is most often used for particles that are bonded to each other. Even if you specify exclusions, * omitted from calculations. This is most often used for particles that are bonded to each other. Even if you specify exclusions,
* however, you can use the computation type ParticlePairNoExclusions to indicate that exclusions should not be applied to a * however, you can use the computation type ParticlePairNoExclusions to indicate that exclusions should not be applied to a
* particular piece of the computation. * particular piece of the computation.
*
* This class also has the ability to compute derivatives of the potential energy with respect to global parameters.
* Call addEnergyParameterDerivative() to request that the derivative with respect to a particular parameter be
* computed. You can then query its value in a Context by calling getState() on it.
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta, select. All trigonometric functions * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta, select. All trigonometric functions
...@@ -207,6 +211,13 @@ public: ...@@ -207,6 +211,13 @@ public:
int getNumGlobalParameters() const { int getNumGlobalParameters() const {
return globalParameters.size(); return globalParameters.size();
} }
/**
* Get the number of global parameters with respect to which the derivative of the energy
* should be computed.
*/
int getNumEnergyParameterDerivatives() const {
return energyParameterDerivatives.size();
}
/** /**
* Get the number of tabulated functions that have been defined. * Get the number of tabulated functions that have been defined.
*/ */
...@@ -312,6 +323,21 @@ public: ...@@ -312,6 +323,21 @@ public:
* @param defaultValue the default value of the parameter * @param defaultValue the default value of the parameter
*/ */
void setGlobalParameterDefaultValue(int index, double defaultValue); void setGlobalParameterDefaultValue(int index, double defaultValue);
/**
* Request that this Force compute the derivative of its energy with respect to a global parameter.
* The parameter must have already been added with addGlobalParameter().
*
* @param name the name of the parameter
*/
void addEnergyParameterDerivative(const std::string& name);
/**
* Get the name of a global parameter with respect to which this Force should compute the
* derivative of the energy.
*
* @param index the index of the parameter derivative, between 0 and getNumEnergyParameterDerivatives()
* @return the parameter name
*/
const std::string& getEnergyParameterDerivativeName(int index) const;
/** /**
* Add the nonbonded force parameters for a particle. This should be called once for each particle * Add the nonbonded force parameters for a particle. This should be called once for each particle
* in the System. When it is called for the i'th time, it specifies the parameters for the i'th particle. * in the System. When it is called for the i'th time, it specifies the parameters for the i'th particle.
...@@ -550,6 +576,7 @@ private: ...@@ -550,6 +576,7 @@ private:
std::vector<FunctionInfo> functions; std::vector<FunctionInfo> functions;
std::vector<ComputationInfo> computedValues; std::vector<ComputationInfo> computedValues;
std::vector<ComputationInfo> energyTerms; std::vector<ComputationInfo> energyTerms;
std::vector<int> energyParameterDerivatives;
}; };
/** /**
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. * * Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -111,6 +111,20 @@ void CustomGBForce::setGlobalParameterDefaultValue(int index, double defaultValu ...@@ -111,6 +111,20 @@ void CustomGBForce::setGlobalParameterDefaultValue(int index, double defaultValu
globalParameters[index].defaultValue = defaultValue; globalParameters[index].defaultValue = defaultValue;
} }
void CustomGBForce::addEnergyParameterDerivative(const string& name) {
for (int i = 0; i < globalParameters.size(); i++)
if (name == globalParameters[i].name) {
energyParameterDerivatives.push_back(i);
return;
}
throw OpenMMException(string("addEnergyParameterDerivative: Unknown global parameter '"+name+"'"));
}
const string& CustomGBForce::getEnergyParameterDerivativeName(int index) const {
ASSERT_VALID_INDEX(index, energyParameterDerivatives);
return globalParameters[energyParameterDerivatives[index]].name;
}
int CustomGBForce::addParticle(const vector<double>& parameters) { int CustomGBForce::addParticle(const vector<double>& parameters) {
particles.push_back(ParticleInfo(parameters)); particles.push_back(ParticleInfo(parameters));
return particles.size()-1; return particles.size()-1;
......
/* Portions copyright (c) 2009 Stanford University and Simbios. /* Portions copyright (c) 2009-2016 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -47,16 +47,20 @@ class ReferenceCustomGBIxn { ...@@ -47,16 +47,20 @@ class ReferenceCustomGBIxn {
std::vector<Lepton::CompiledExpression> valueExpressions; std::vector<Lepton::CompiledExpression> valueExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > valueDerivExpressions; std::vector<std::vector<Lepton::CompiledExpression> > valueDerivExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > valueGradientExpressions; std::vector<std::vector<Lepton::CompiledExpression> > valueGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > valueParamDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes; std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<Lepton::CompiledExpression> energyExpressions; std::vector<Lepton::CompiledExpression> energyExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyDerivExpressions; std::vector<std::vector<Lepton::CompiledExpression> > energyDerivExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions; std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes; std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::vector<int> paramIndex; std::vector<int> paramIndex;
std::vector<int> valueIndex; std::vector<int> valueIndex;
std::vector<int> particleParamIndex; std::vector<int> particleParamIndex;
std::vector<int> particleValueIndex; std::vector<int> particleValueIndex;
int rIndex, xIndex, yIndex, zIndex; int rIndex, xIndex, yIndex, zIndex;
std::vector<std::vector<RealOpenMM> > values, dEdV;
std::vector<std::vector<std::vector<RealOpenMM> > > dValuedParam;
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -65,13 +69,11 @@ class ReferenceCustomGBIxn { ...@@ -65,13 +69,11 @@ class ReferenceCustomGBIxn {
@param index the index of the value to compute @param index the index of the value to compute
@param numAtoms number of atoms @param numAtoms number of atoms
@param atomCoordinates atom coordinates @param atomCoordinates atom coordinates
@param values the vector to store computed values into
@param atomParameters atomParameters[atomIndex][paramterIndex] @param atomParameters atomParameters[atomIndex][paramterIndex]
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
void calculateSingleParticleValue(int index, int numAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, std::vector<std::vector<RealOpenMM> >& values, void calculateSingleParticleValue(int index, int numAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters);
RealOpenMM** atomParameters);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -81,14 +83,12 @@ class ReferenceCustomGBIxn { ...@@ -81,14 +83,12 @@ class ReferenceCustomGBIxn {
@param numAtoms number of atoms @param numAtoms number of atoms
@param atomCoordinates atom coordinates @param atomCoordinates atom coordinates
@param atomParameters atomParameters[atomIndex][paramterIndex] @param atomParameters atomParameters[atomIndex][paramterIndex]
@param values the vector to store computed values into
@param exclusions exclusions[i] is the set of excluded indices for atom i @param exclusions exclusions[i] is the set of excluded indices for atom i
@param useExclusions specifies whether to use exclusions @param useExclusions specifies whether to use exclusions
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
void calculateParticlePairValue(int index, int numAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters, void calculateParticlePairValue(int index, int numAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters,
std::vector<std::vector<RealOpenMM> >& values,
const std::vector<std::set<int> >& exclusions, bool useExclusions); const std::vector<std::set<int> >& exclusions, bool useExclusions);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -100,12 +100,10 @@ class ReferenceCustomGBIxn { ...@@ -100,12 +100,10 @@ class ReferenceCustomGBIxn {
@param atom2 the index of the second atom in the pair @param atom2 the index of the second atom in the pair
@param atomCoordinates atom coordinates @param atomCoordinates atom coordinates
@param atomParameters atomParameters[atomIndex][paramterIndex] @param atomParameters atomParameters[atomIndex][paramterIndex]
@param values the vector to store computed values into
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
void calculateOnePairValue(int index, int atom1, int atom2, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters, void calculateOnePairValue(int index, int atom1, int atom2, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters);
std::vector<std::vector<RealOpenMM> >& values);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -114,17 +112,14 @@ class ReferenceCustomGBIxn { ...@@ -114,17 +112,14 @@ class ReferenceCustomGBIxn {
@param index the index of the value to compute @param index the index of the value to compute
@param numAtoms number of atoms @param numAtoms number of atoms
@param atomCoordinates atom coordinates @param atomCoordinates atom coordinates
@param values the vector containing computed values
@param atomParameters atomParameters[atomIndex][paramterIndex] @param atomParameters atomParameters[atomIndex][paramterIndex]
@param forces forces on atoms are added to this @param forces forces on atoms are added to this
@param totalEnergy the energy contribution is added to this @param totalEnergy the energy contribution is added to this
@param dEdV the derivative of energy with respect to computed values is stored in this
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
void calculateSingleParticleEnergyTerm(int index, int numAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, const std::vector<std::vector<RealOpenMM> >& values, void calculateSingleParticleEnergyTerm(int index, int numAtoms, std::vector<OpenMM::RealVec>& atomCoordinates,
RealOpenMM** atomParameters, std::vector<OpenMM::RealVec>& forces, RealOpenMM** atomParameters, std::vector<OpenMM::RealVec>& forces, RealOpenMM* totalEnergy, double* energyParamDerivs);
RealOpenMM* totalEnergy, std::vector<std::vector<RealOpenMM> >& dEdV);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -134,19 +129,16 @@ class ReferenceCustomGBIxn { ...@@ -134,19 +129,16 @@ class ReferenceCustomGBIxn {
@param numAtoms number of atoms @param numAtoms number of atoms
@param atomCoordinates atom coordinates @param atomCoordinates atom coordinates
@param atomParameters atomParameters[atomIndex][paramterIndex] @param atomParameters atomParameters[atomIndex][paramterIndex]
@param values the vector containing computed values
@param exclusions exclusions[i] is the set of excluded indices for atom i @param exclusions exclusions[i] is the set of excluded indices for atom i
@param useExclusions specifies whether to use exclusions @param useExclusions specifies whether to use exclusions
@param forces forces on atoms are added to this @param forces forces on atoms are added to this
@param totalEnergy the energy contribution is added to this @param totalEnergy the energy contribution is added to this
@param dEdV the derivative of energy with respect to computed values is stored in this
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
void calculateParticlePairEnergyTerm(int index, int numAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters, void calculateParticlePairEnergyTerm(int index, int numAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters,
const std::vector<std::vector<RealOpenMM> >& values,
const std::vector<std::set<int> >& exclusions, bool useExclusions, const std::vector<std::set<int> >& exclusions, bool useExclusions,
std::vector<OpenMM::RealVec>& forces, RealOpenMM* totalEnergy, std::vector<std::vector<RealOpenMM> >& dEdV); std::vector<OpenMM::RealVec>& forces, RealOpenMM* totalEnergy, double* energyParamDerivs);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -157,16 +149,13 @@ class ReferenceCustomGBIxn { ...@@ -157,16 +149,13 @@ class ReferenceCustomGBIxn {
@param atom2 the index of the second atom in the pair @param atom2 the index of the second atom in the pair
@param atomCoordinates atom coordinates @param atomCoordinates atom coordinates
@param atomParameters atomParameters[atomIndex][paramterIndex] @param atomParameters atomParameters[atomIndex][paramterIndex]
@param values the vector containing computed values
@param forces forces on atoms are added to this @param forces forces on atoms are added to this
@param totalEnergy the energy contribution is added to this @param totalEnergy the energy contribution is added to this
@param dEdV the derivative of energy with respect to computed values is stored in this
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
void calculateOnePairEnergyTerm(int index, int atom1, int atom2, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters, void calculateOnePairEnergyTerm(int index, int atom1, int atom2, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters,
const std::vector<std::vector<RealOpenMM> >& values, std::vector<OpenMM::RealVec>& forces, RealOpenMM* totalEnergy, double* energyParamDerivs);
std::vector<OpenMM::RealVec>& forces, RealOpenMM* totalEnergy, std::vector<std::vector<RealOpenMM> >& dEdV);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -175,17 +164,13 @@ class ReferenceCustomGBIxn { ...@@ -175,17 +164,13 @@ class ReferenceCustomGBIxn {
@param numAtoms number of atoms @param numAtoms number of atoms
@param atomCoordinates atom coordinates @param atomCoordinates atom coordinates
@param atomParameters atomParameters[atomIndex][paramterIndex] @param atomParameters atomParameters[atomIndex][paramterIndex]
@param values the vector containing computed values
@param exclusions exclusions[i] is the set of excluded indices for atom i @param exclusions exclusions[i] is the set of excluded indices for atom i
@param forces forces on atoms are added to this @param forces forces on atoms are added to this
@param dEdV the derivative of energy with respect to computed values is stored in this
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
void calculateChainRuleForces(int numAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters, void calculateChainRuleForces(int numAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters,
const std::vector<std::vector<RealOpenMM> >& values, const std::vector<std::set<int> >& exclusions, std::vector<OpenMM::RealVec>& forces);
const std::vector<std::set<int> >& exclusions,
std::vector<OpenMM::RealVec>& forces, std::vector<std::vector<RealOpenMM> >& dEdV);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -195,17 +180,13 @@ class ReferenceCustomGBIxn { ...@@ -195,17 +180,13 @@ class ReferenceCustomGBIxn {
@param atom2 the index of the second atom in the pair @param atom2 the index of the second atom in the pair
@param atomCoordinates atom coordinates @param atomCoordinates atom coordinates
@param atomParameters atomParameters[atomIndex][paramterIndex] @param atomParameters atomParameters[atomIndex][paramterIndex]
@param values the vector containing computed values
@param forces forces on atoms are added to this @param forces forces on atoms are added to this
@param dEdV the derivative of energy with respect to computed values is stored in this
@param isExcluded specifies whether this is an excluded pair @param isExcluded specifies whether this is an excluded pair
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
void calculateOnePairChainRule(int atom1, int atom2, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters, void calculateOnePairChainRule(int atom1, int atom2, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters,
const std::vector<std::vector<RealOpenMM> >& values, std::vector<OpenMM::RealVec>& forces, bool isExcluded);
std::vector<OpenMM::RealVec>& forces, std::vector<std::vector<RealOpenMM> >& dEdV,
bool isExcluded);
public: public:
...@@ -218,11 +199,13 @@ class ReferenceCustomGBIxn { ...@@ -218,11 +199,13 @@ class ReferenceCustomGBIxn {
ReferenceCustomGBIxn(const std::vector<Lepton::CompiledExpression>& valueExpressions, ReferenceCustomGBIxn(const std::vector<Lepton::CompiledExpression>& valueExpressions,
const std::vector<std::vector<Lepton::CompiledExpression> > valueDerivExpressions, const std::vector<std::vector<Lepton::CompiledExpression> > valueDerivExpressions,
const std::vector<std::vector<Lepton::CompiledExpression> > valueGradientExpressions, const std::vector<std::vector<Lepton::CompiledExpression> > valueGradientExpressions,
const std::vector<std::vector<Lepton::CompiledExpression> > valueParamDerivExpressions,
const std::vector<std::string>& valueNames, const std::vector<std::string>& valueNames,
const std::vector<OpenMM::CustomGBForce::ComputationType>& valueTypes, const std::vector<OpenMM::CustomGBForce::ComputationType>& valueTypes,
const std::vector<Lepton::CompiledExpression>& energyExpressions, const std::vector<Lepton::CompiledExpression>& energyExpressions,
const std::vector<std::vector<Lepton::CompiledExpression> > energyDerivExpressions, const std::vector<std::vector<Lepton::CompiledExpression> > energyDerivExpressions,
const std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions, const std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions,
const std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions,
const std::vector<OpenMM::CustomGBForce::ComputationType>& energyTypes, const std::vector<OpenMM::CustomGBForce::ComputationType>& energyTypes,
const std::vector<std::string>& parameterNames); const std::vector<std::string>& parameterNames);
...@@ -272,7 +255,7 @@ class ReferenceCustomGBIxn { ...@@ -272,7 +255,7 @@ class ReferenceCustomGBIxn {
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
void calculateIxn(int numberOfAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters, const std::vector<std::set<int> >& exclusions, void calculateIxn(int numberOfAtoms, std::vector<OpenMM::RealVec>& atomCoordinates, RealOpenMM** atomParameters, const std::vector<std::set<int> >& exclusions,
std::map<std::string, double>& globalParameters, std::vector<OpenMM::RealVec>& forces, RealOpenMM* totalEnergy); std::map<std::string, double>& globalParameters, std::vector<OpenMM::RealVec>& forces, RealOpenMM* totalEnergy, double* energyParamDerivs);
// --------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------
......
...@@ -737,14 +737,16 @@ private: ...@@ -737,14 +737,16 @@ private:
RealOpenMM **particleParamArray; RealOpenMM **particleParamArray;
RealOpenMM nonbondedCutoff; RealOpenMM nonbondedCutoff;
std::vector<std::set<int> > exclusions; std::vector<std::set<int> > exclusions;
std::vector<std::string> particleParameterNames, globalParameterNames, valueNames; std::vector<std::string> particleParameterNames, globalParameterNames, energyParamDerivNames, valueNames;
std::vector<Lepton::CompiledExpression> valueExpressions; std::vector<Lepton::CompiledExpression> valueExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > valueDerivExpressions; std::vector<std::vector<Lepton::CompiledExpression> > valueDerivExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > valueGradientExpressions; std::vector<std::vector<Lepton::CompiledExpression> > valueGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > valueParamDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes; std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<Lepton::CompiledExpression> energyExpressions; std::vector<Lepton::CompiledExpression> energyExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyDerivExpressions; std::vector<std::vector<Lepton::CompiledExpression> > energyDerivExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions; std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes; std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
NeighborList* neighborList; NeighborList* neighborList;
......
...@@ -1359,6 +1359,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1359,6 +1359,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
valueDerivExpressions.resize(force.getNumComputedValues()); valueDerivExpressions.resize(force.getNumComputedValues());
valueGradientExpressions.resize(force.getNumComputedValues()); valueGradientExpressions.resize(force.getNumComputedValues());
valueParamDerivExpressions.resize(force.getNumComputedValues());
set<string> particleVariables, pairVariables; set<string> particleVariables, pairVariables;
pairVariables.insert("r"); pairVariables.insert("r");
particleVariables.insert("x"); particleVariables.insert("x");
...@@ -1380,17 +1381,22 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1380,17 +1381,22 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
valueTypes.push_back(type); valueTypes.push_back(type);
valueNames.push_back(name); valueNames.push_back(name);
if (i == 0) { if (i == 0) {
valueDerivExpressions[i].push_back(ex.differentiate("r").optimize().createCompiledExpression()); valueDerivExpressions[i].push_back(ex.differentiate("r").createCompiledExpression());
validateVariables(ex.getRootNode(), pairVariables); validateVariables(ex.getRootNode(), pairVariables);
} }
else { else {
valueGradientExpressions[i].push_back(ex.differentiate("x").optimize().createCompiledExpression()); valueGradientExpressions[i].push_back(ex.differentiate("x").createCompiledExpression());
valueGradientExpressions[i].push_back(ex.differentiate("y").optimize().createCompiledExpression()); valueGradientExpressions[i].push_back(ex.differentiate("y").createCompiledExpression());
valueGradientExpressions[i].push_back(ex.differentiate("z").optimize().createCompiledExpression()); valueGradientExpressions[i].push_back(ex.differentiate("z").createCompiledExpression());
for (int j = 0; j < i; j++) for (int j = 0; j < i; j++)
valueDerivExpressions[i].push_back(ex.differentiate(valueNames[j]).optimize().createCompiledExpression()); valueDerivExpressions[i].push_back(ex.differentiate(valueNames[j]).createCompiledExpression());
validateVariables(ex.getRootNode(), particleVariables); validateVariables(ex.getRootNode(), particleVariables);
} }
for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++) {
string param = force.getEnergyParameterDerivativeName(j);
energyParamDerivNames.push_back(param);
valueParamDerivExpressions[i].push_back(ex.differentiate(param).createCompiledExpression());
}
particleVariables.insert(name); particleVariables.insert(name);
pairVariables.insert(name+"1"); pairVariables.insert(name+"1");
pairVariables.insert(name+"2"); pairVariables.insert(name+"2");
...@@ -1400,6 +1406,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1400,6 +1406,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
energyDerivExpressions.resize(force.getNumEnergyTerms()); energyDerivExpressions.resize(force.getNumEnergyTerms());
energyGradientExpressions.resize(force.getNumEnergyTerms()); energyGradientExpressions.resize(force.getNumEnergyTerms());
energyParamDerivExpressions.resize(force.getNumEnergyTerms());
for (int i = 0; i < force.getNumEnergyTerms(); i++) { for (int i = 0; i < force.getNumEnergyTerms(); i++) {
string expression; string expression;
CustomGBForce::ComputationType type; CustomGBForce::ComputationType type;
...@@ -1408,21 +1415,23 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1408,21 +1415,23 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
energyExpressions.push_back(ex.createCompiledExpression()); energyExpressions.push_back(ex.createCompiledExpression());
energyTypes.push_back(type); energyTypes.push_back(type);
if (type != CustomGBForce::SingleParticle) if (type != CustomGBForce::SingleParticle)
energyDerivExpressions[i].push_back(ex.differentiate("r").optimize().createCompiledExpression()); energyDerivExpressions[i].push_back(ex.differentiate("r").createCompiledExpression());
for (int j = 0; j < force.getNumComputedValues(); j++) { for (int j = 0; j < force.getNumComputedValues(); j++) {
if (type == CustomGBForce::SingleParticle) { if (type == CustomGBForce::SingleParticle) {
energyDerivExpressions[i].push_back(ex.differentiate(valueNames[j]).optimize().createCompiledExpression()); energyDerivExpressions[i].push_back(ex.differentiate(valueNames[j]).createCompiledExpression());
energyGradientExpressions[i].push_back(ex.differentiate("x").optimize().createCompiledExpression()); energyGradientExpressions[i].push_back(ex.differentiate("x").createCompiledExpression());
energyGradientExpressions[i].push_back(ex.differentiate("y").optimize().createCompiledExpression()); energyGradientExpressions[i].push_back(ex.differentiate("y").createCompiledExpression());
energyGradientExpressions[i].push_back(ex.differentiate("z").optimize().createCompiledExpression()); energyGradientExpressions[i].push_back(ex.differentiate("z").createCompiledExpression());
validateVariables(ex.getRootNode(), particleVariables); validateVariables(ex.getRootNode(), particleVariables);
} }
else { else {
energyDerivExpressions[i].push_back(ex.differentiate(valueNames[j]+"1").optimize().createCompiledExpression()); energyDerivExpressions[i].push_back(ex.differentiate(valueNames[j]+"1").createCompiledExpression());
energyDerivExpressions[i].push_back(ex.differentiate(valueNames[j]+"2").optimize().createCompiledExpression()); energyDerivExpressions[i].push_back(ex.differentiate(valueNames[j]+"2").createCompiledExpression());
validateVariables(ex.getRootNode(), pairVariables); validateVariables(ex.getRootNode(), pairVariables);
} }
} }
for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
energyParamDerivExpressions[i].push_back(ex.differentiate(force.getEnergyParameterDerivativeName(j)).createCompiledExpression());
} }
// Delete the custom functions. // Delete the custom functions.
...@@ -1435,8 +1444,8 @@ double ReferenceCalcCustomGBForceKernel::execute(ContextImpl& context, bool incl ...@@ -1435,8 +1444,8 @@ double ReferenceCalcCustomGBForceKernel::execute(ContextImpl& context, bool incl
vector<RealVec>& posData = extractPositions(context); vector<RealVec>& posData = extractPositions(context);
vector<RealVec>& forceData = extractForces(context); vector<RealVec>& forceData = extractForces(context);
RealOpenMM energy = 0; RealOpenMM energy = 0;
ReferenceCustomGBIxn ixn(valueExpressions, valueDerivExpressions, valueGradientExpressions, valueNames, valueTypes, energyExpressions, ReferenceCustomGBIxn ixn(valueExpressions, valueDerivExpressions, valueGradientExpressions, valueParamDerivExpressions, valueNames, valueTypes,
energyDerivExpressions, energyGradientExpressions, energyTypes, particleParameterNames); energyExpressions, energyDerivExpressions, energyGradientExpressions, energyParamDerivExpressions, energyTypes, particleParameterNames);
bool periodic = (nonbondedMethod == CutoffPeriodic); bool periodic = (nonbondedMethod == CutoffPeriodic);
if (periodic) if (periodic)
ixn.setPeriodic(extractBoxVectors(context)); ixn.setPeriodic(extractBoxVectors(context));
...@@ -1447,7 +1456,11 @@ double ReferenceCalcCustomGBForceKernel::execute(ContextImpl& context, bool incl ...@@ -1447,7 +1456,11 @@ double ReferenceCalcCustomGBForceKernel::execute(ContextImpl& context, bool incl
map<string, double> globalParameters; map<string, double> globalParameters;
for (int i = 0; i < (int) globalParameterNames.size(); i++) for (int i = 0; i < (int) globalParameterNames.size(); i++)
globalParameters[globalParameterNames[i]] = context.getParameter(globalParameterNames[i]); globalParameters[globalParameterNames[i]] = context.getParameter(globalParameterNames[i]);
ixn.calculateIxn(numParticles, posData, particleParamArray, exclusions, globalParameters, forceData, includeEnergy ? &energy : NULL); vector<double> energyParamDerivValues(energyParamDerivNames.size()+1, 0.0);
ixn.calculateIxn(numParticles, posData, particleParamArray, exclusions, globalParameters, forceData, includeEnergy ? &energy : NULL, &energyParamDerivValues[0]);
map<string, double>& energyParamDerivs = extractEnergyParameterDerivatives(context);
for (int i = 0; i < energyParamDerivNames.size(); i++)
energyParamDerivs[energyParamDerivNames[i]] += energyParamDerivValues[i];
return energy; return energy;
} }
......
/* Portions copyright (c) 2009 Stanford University and Simbios. /* Portions copyright (c) 2009-2016 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -45,15 +45,17 @@ using namespace OpenMM; ...@@ -45,15 +45,17 @@ using namespace OpenMM;
ReferenceCustomGBIxn::ReferenceCustomGBIxn(const vector<Lepton::CompiledExpression>& valueExpressions, ReferenceCustomGBIxn::ReferenceCustomGBIxn(const vector<Lepton::CompiledExpression>& valueExpressions,
const vector<vector<Lepton::CompiledExpression> > valueDerivExpressions, const vector<vector<Lepton::CompiledExpression> > valueDerivExpressions,
const vector<vector<Lepton::CompiledExpression> > valueGradientExpressions, const vector<vector<Lepton::CompiledExpression> > valueGradientExpressions,
const vector<vector<Lepton::CompiledExpression> > valueParamDerivExpressions,
const vector<string>& valueNames, const vector<string>& valueNames,
const vector<OpenMM::CustomGBForce::ComputationType>& valueTypes, const vector<OpenMM::CustomGBForce::ComputationType>& valueTypes,
const vector<Lepton::CompiledExpression>& energyExpressions, const vector<Lepton::CompiledExpression>& energyExpressions,
const vector<vector<Lepton::CompiledExpression> > energyDerivExpressions, const vector<vector<Lepton::CompiledExpression> > energyDerivExpressions,
const vector<vector<Lepton::CompiledExpression> > energyGradientExpressions, const vector<vector<Lepton::CompiledExpression> > energyGradientExpressions,
const vector<vector<Lepton::CompiledExpression> > energyParamDerivExpressions,
const vector<OpenMM::CustomGBForce::ComputationType>& energyTypes, const vector<OpenMM::CustomGBForce::ComputationType>& energyTypes,
const vector<string>& parameterNames) : const vector<string>& parameterNames) :
cutoff(false), periodic(false), valueExpressions(valueExpressions), valueDerivExpressions(valueDerivExpressions), valueGradientExpressions(valueGradientExpressions), cutoff(false), periodic(false), valueExpressions(valueExpressions), valueDerivExpressions(valueDerivExpressions), valueGradientExpressions(valueGradientExpressions), valueParamDerivExpressions(valueParamDerivExpressions),
valueTypes(valueTypes), energyExpressions(energyExpressions), energyDerivExpressions(energyDerivExpressions), energyGradientExpressions(energyGradientExpressions), valueTypes(valueTypes), energyExpressions(energyExpressions), energyDerivExpressions(energyDerivExpressions), energyGradientExpressions(energyGradientExpressions), energyParamDerivExpressions(energyParamDerivExpressions),
energyTypes(energyTypes) { energyTypes(energyTypes) {
for (int i = 0; i < this->valueExpressions.size(); i++) for (int i = 0; i < this->valueExpressions.size(); i++)
...@@ -64,6 +66,9 @@ ReferenceCustomGBIxn::ReferenceCustomGBIxn(const vector<Lepton::CompiledExpressi ...@@ -64,6 +66,9 @@ ReferenceCustomGBIxn::ReferenceCustomGBIxn(const vector<Lepton::CompiledExpressi
for (int i = 0; i < this->valueGradientExpressions.size(); i++) for (int i = 0; i < this->valueGradientExpressions.size(); i++)
for (int j = 0; j < this->valueGradientExpressions[i].size(); j++) for (int j = 0; j < this->valueGradientExpressions[i].size(); j++)
expressionSet.registerExpression(this->valueGradientExpressions[i][j]); expressionSet.registerExpression(this->valueGradientExpressions[i][j]);
for (int i = 0; i < this->valueParamDerivExpressions.size(); i++)
for (int j = 0; j < this->valueParamDerivExpressions[i].size(); j++)
expressionSet.registerExpression(this->valueParamDerivExpressions[i][j]);
for (int i = 0; i < this->energyExpressions.size(); i++) for (int i = 0; i < this->energyExpressions.size(); i++)
expressionSet.registerExpression(this->energyExpressions[i]); expressionSet.registerExpression(this->energyExpressions[i]);
for (int i = 0; i < this->energyDerivExpressions.size(); i++) for (int i = 0; i < this->energyDerivExpressions.size(); i++)
...@@ -72,6 +77,9 @@ ReferenceCustomGBIxn::ReferenceCustomGBIxn(const vector<Lepton::CompiledExpressi ...@@ -72,6 +77,9 @@ ReferenceCustomGBIxn::ReferenceCustomGBIxn(const vector<Lepton::CompiledExpressi
for (int i = 0; i < this->energyGradientExpressions.size(); i++) for (int i = 0; i < this->energyGradientExpressions.size(); i++)
for (int j = 0; j < this->energyGradientExpressions[i].size(); j++) for (int j = 0; j < this->energyGradientExpressions[i].size(); j++)
expressionSet.registerExpression(this->energyGradientExpressions[i][j]); expressionSet.registerExpression(this->energyGradientExpressions[i][j]);
for (int i = 0; i < this->energyParamDerivExpressions.size(); i++)
for (int j = 0; j < this->energyParamDerivExpressions[i].size(); j++)
expressionSet.registerExpression(this->energyParamDerivExpressions[i][j]);
rIndex = expressionSet.getVariableIndex("r"); rIndex = expressionSet.getVariableIndex("r");
xIndex = expressionSet.getVariableIndex("x"); xIndex = expressionSet.getVariableIndex("x");
yIndex = expressionSet.getVariableIndex("y"); yIndex = expressionSet.getVariableIndex("y");
...@@ -144,42 +152,48 @@ ReferenceCustomGBIxn::~ReferenceCustomGBIxn() { ...@@ -144,42 +152,48 @@ ReferenceCustomGBIxn::~ReferenceCustomGBIxn() {
void ReferenceCustomGBIxn::calculateIxn(int numberOfAtoms, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters, void ReferenceCustomGBIxn::calculateIxn(int numberOfAtoms, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters,
const vector<set<int> >& exclusions, map<string, double>& globalParameters, vector<RealVec>& forces, const vector<set<int> >& exclusions, map<string, double>& globalParameters, vector<RealVec>& forces,
RealOpenMM* totalEnergy) { RealOpenMM* totalEnergy, double* energyParamDerivs) {
for (map<string, double>::const_iterator iter = globalParameters.begin(); iter != globalParameters.end(); ++iter) for (map<string, double>::const_iterator iter = globalParameters.begin(); iter != globalParameters.end(); ++iter)
expressionSet.setVariable(expressionSet.getVariableIndex(iter->first), iter->second); expressionSet.setVariable(expressionSet.getVariableIndex(iter->first), iter->second);
// Initialize arrays for storing values.
int numValues = valueTypes.size();
int numDerivs = valueParamDerivExpressions[0].size();
values.resize(numValues);
dEdV.resize(numValues, vector<RealOpenMM>(numberOfAtoms, 0.0));
dValuedParam.resize(numValues);
for (int i = 0; i < numValues; i++)
dValuedParam[i].resize(numDerivs, vector<RealOpenMM>(numberOfAtoms, 0.0));
// First calculate the computed values. // First calculate the computed values.
int numValues = valueTypes.size();
vector<vector<RealOpenMM> > values(numValues);
for (int valueIndex = 0; valueIndex < numValues; valueIndex++) { for (int valueIndex = 0; valueIndex < numValues; valueIndex++) {
if (valueTypes[valueIndex] == OpenMM::CustomGBForce::SingleParticle) if (valueTypes[valueIndex] == OpenMM::CustomGBForce::SingleParticle)
calculateSingleParticleValue(valueIndex, numberOfAtoms, atomCoordinates, values, atomParameters); calculateSingleParticleValue(valueIndex, numberOfAtoms, atomCoordinates, atomParameters);
else if (valueTypes[valueIndex] == OpenMM::CustomGBForce::ParticlePair) else if (valueTypes[valueIndex] == OpenMM::CustomGBForce::ParticlePair)
calculateParticlePairValue(valueIndex, numberOfAtoms, atomCoordinates, atomParameters, values, exclusions, true); calculateParticlePairValue(valueIndex, numberOfAtoms, atomCoordinates, atomParameters, exclusions, true);
else else
calculateParticlePairValue(valueIndex, numberOfAtoms, atomCoordinates, atomParameters, values, exclusions, false); calculateParticlePairValue(valueIndex, numberOfAtoms, atomCoordinates, atomParameters, exclusions, false);
} }
// Now calculate the energy and its derivates. // Now calculate the energy and its derivates.
vector<vector<RealOpenMM> > dEdV(numValues, vector<RealOpenMM>(numberOfAtoms, (RealOpenMM) 0));
for (int termIndex = 0; termIndex < (int) energyExpressions.size(); termIndex++) { for (int termIndex = 0; termIndex < (int) energyExpressions.size(); termIndex++) {
if (energyTypes[termIndex] == OpenMM::CustomGBForce::SingleParticle) if (energyTypes[termIndex] == OpenMM::CustomGBForce::SingleParticle)
calculateSingleParticleEnergyTerm(termIndex, numberOfAtoms, atomCoordinates, values, atomParameters, forces, totalEnergy, dEdV); calculateSingleParticleEnergyTerm(termIndex, numberOfAtoms, atomCoordinates, atomParameters, forces, totalEnergy, energyParamDerivs);
else if (energyTypes[termIndex] == OpenMM::CustomGBForce::ParticlePair) else if (energyTypes[termIndex] == OpenMM::CustomGBForce::ParticlePair)
calculateParticlePairEnergyTerm(termIndex, numberOfAtoms, atomCoordinates, atomParameters, values, exclusions, true, forces, totalEnergy, dEdV); calculateParticlePairEnergyTerm(termIndex, numberOfAtoms, atomCoordinates, atomParameters, exclusions, true, forces, totalEnergy, energyParamDerivs);
else else
calculateParticlePairEnergyTerm(termIndex, numberOfAtoms, atomCoordinates, atomParameters, values, exclusions, false, forces, totalEnergy, dEdV); calculateParticlePairEnergyTerm(termIndex, numberOfAtoms, atomCoordinates, atomParameters, exclusions, false, forces, totalEnergy, energyParamDerivs);
} }
// Apply the chain rule to evaluate forces. // Apply the chain rule to evaluate forces.
calculateChainRuleForces(numberOfAtoms, atomCoordinates, atomParameters, values, exclusions, forces, dEdV); calculateChainRuleForces(numberOfAtoms, atomCoordinates, atomParameters, exclusions, forces);
} }
void ReferenceCustomGBIxn::calculateSingleParticleValue(int index, int numAtoms, vector<RealVec>& atomCoordinates, vector<vector<RealOpenMM> >& values, void ReferenceCustomGBIxn::calculateSingleParticleValue(int index, int numAtoms, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters) {
RealOpenMM** atomParameters) {
values[index].resize(numAtoms); values[index].resize(numAtoms);
for (int i = 0; i < numAtoms; i++) { for (int i = 0; i < numAtoms; i++) {
expressionSet.setVariable(xIndex, atomCoordinates[i][0]); expressionSet.setVariable(xIndex, atomCoordinates[i][0]);
...@@ -190,11 +204,21 @@ void ReferenceCustomGBIxn::calculateSingleParticleValue(int index, int numAtoms, ...@@ -190,11 +204,21 @@ void ReferenceCustomGBIxn::calculateSingleParticleValue(int index, int numAtoms,
for (int j = 0; j < index; j++) for (int j = 0; j < index; j++)
expressionSet.setVariable(valueIndex[j], values[j][i]); expressionSet.setVariable(valueIndex[j], values[j][i]);
values[index][i] = (RealOpenMM) valueExpressions[index].evaluate(); values[index][i] = (RealOpenMM) valueExpressions[index].evaluate();
for (int j = 0; j < valueParamDerivExpressions[index].size(); j++)
dValuedParam[index][j][i] += valueParamDerivExpressions[index][j].evaluate();
// Calculate derivatives with respect to parameters.
for (int j = 0; j < index; j++) {
RealOpenMM dVdV = valueDerivExpressions[index][j].evaluate();
for (int k = 0; k < valueParamDerivExpressions[index].size(); k++)
dValuedParam[index][k][i] += dVdV*dValuedParam[j][k][i];
}
} }
} }
void ReferenceCustomGBIxn::calculateParticlePairValue(int index, int numAtoms, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters, void ReferenceCustomGBIxn::calculateParticlePairValue(int index, int numAtoms, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters,
vector<vector<RealOpenMM> >& values, const vector<set<int> >& exclusions, bool useExclusions) { const vector<set<int> >& exclusions, bool useExclusions) {
values[index].resize(numAtoms); values[index].resize(numAtoms);
for (int i = 0; i < numAtoms; i++) for (int i = 0; i < numAtoms; i++)
values[index][i] = (RealOpenMM) 0.0; values[index][i] = (RealOpenMM) 0.0;
...@@ -205,8 +229,8 @@ void ReferenceCustomGBIxn::calculateParticlePairValue(int index, int numAtoms, v ...@@ -205,8 +229,8 @@ void ReferenceCustomGBIxn::calculateParticlePairValue(int index, int numAtoms, v
OpenMM::AtomPair pair = (*neighborList)[i]; OpenMM::AtomPair pair = (*neighborList)[i];
if (useExclusions && exclusions[pair.first].find(pair.second) != exclusions[pair.first].end()) if (useExclusions && exclusions[pair.first].find(pair.second) != exclusions[pair.first].end())
continue; continue;
calculateOnePairValue(index, pair.first, pair.second, atomCoordinates, atomParameters, values); calculateOnePairValue(index, pair.first, pair.second, atomCoordinates, atomParameters);
calculateOnePairValue(index, pair.second, pair.first, atomCoordinates, atomParameters, values); calculateOnePairValue(index, pair.second, pair.first, atomCoordinates, atomParameters);
} }
} }
else { else {
...@@ -216,15 +240,14 @@ void ReferenceCustomGBIxn::calculateParticlePairValue(int index, int numAtoms, v ...@@ -216,15 +240,14 @@ void ReferenceCustomGBIxn::calculateParticlePairValue(int index, int numAtoms, v
for (int j = i+1; j < numAtoms; j++) { for (int j = i+1; j < numAtoms; j++) {
if (useExclusions && exclusions[i].find(j) != exclusions[i].end()) if (useExclusions && exclusions[i].find(j) != exclusions[i].end())
continue; continue;
calculateOnePairValue(index, i, j, atomCoordinates, atomParameters, values); calculateOnePairValue(index, i, j, atomCoordinates, atomParameters);
calculateOnePairValue(index, j, i, atomCoordinates, atomParameters, values); calculateOnePairValue(index, j, i, atomCoordinates, atomParameters);
} }
} }
} }
} }
void ReferenceCustomGBIxn::calculateOnePairValue(int index, int atom1, int atom2, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters, void ReferenceCustomGBIxn::calculateOnePairValue(int index, int atom1, int atom2, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters) {
vector<vector<RealOpenMM> >& values) {
RealOpenMM deltaR[ReferenceForce::LastDeltaRIndex]; RealOpenMM deltaR[ReferenceForce::LastDeltaRIndex];
if (periodic) if (periodic)
ReferenceForce::getDeltaRPeriodic(atomCoordinates[atom2], atomCoordinates[atom1], periodicBoxVectors, deltaR); ReferenceForce::getDeltaRPeriodic(atomCoordinates[atom2], atomCoordinates[atom1], periodicBoxVectors, deltaR);
...@@ -243,11 +266,15 @@ void ReferenceCustomGBIxn::calculateOnePairValue(int index, int atom1, int atom2 ...@@ -243,11 +266,15 @@ void ReferenceCustomGBIxn::calculateOnePairValue(int index, int atom1, int atom2
expressionSet.setVariable(particleValueIndex[i*2+1], values[i][atom2]); expressionSet.setVariable(particleValueIndex[i*2+1], values[i][atom2]);
} }
values[index][atom1] += (RealOpenMM) valueExpressions[index].evaluate(); values[index][atom1] += (RealOpenMM) valueExpressions[index].evaluate();
// Calculate derivatives with respect to parameters.
for (int i = 0; i < valueParamDerivExpressions[index].size(); i++)
dValuedParam[index][i][atom1] += valueParamDerivExpressions[index][i].evaluate();
} }
void ReferenceCustomGBIxn::calculateSingleParticleEnergyTerm(int index, int numAtoms, vector<RealVec>& atomCoordinates, const vector<vector<RealOpenMM> >& values, void ReferenceCustomGBIxn::calculateSingleParticleEnergyTerm(int index, int numAtoms, vector<RealVec>& atomCoordinates,
RealOpenMM** atomParameters, vector<RealVec>& forces, RealOpenMM* totalEnergy, RealOpenMM** atomParameters, vector<RealVec>& forces, RealOpenMM* totalEnergy, double* energyParamDerivs) {
vector<vector<RealOpenMM> >& dEdV) {
for (int i = 0; i < numAtoms; i++) { for (int i = 0; i < numAtoms; i++) {
expressionSet.setVariable(xIndex, atomCoordinates[i][0]); expressionSet.setVariable(xIndex, atomCoordinates[i][0]);
expressionSet.setVariable(yIndex, atomCoordinates[i][1]); expressionSet.setVariable(yIndex, atomCoordinates[i][1]);
...@@ -256,6 +283,9 @@ void ReferenceCustomGBIxn::calculateSingleParticleEnergyTerm(int index, int numA ...@@ -256,6 +283,9 @@ void ReferenceCustomGBIxn::calculateSingleParticleEnergyTerm(int index, int numA
expressionSet.setVariable(paramIndex[j], atomParameters[i][j]); expressionSet.setVariable(paramIndex[j], atomParameters[i][j]);
for (int j = 0; j < valueIndex.size(); j++) for (int j = 0; j < valueIndex.size(); j++)
expressionSet.setVariable(valueIndex[j], values[j][i]); expressionSet.setVariable(valueIndex[j], values[j][i]);
// Compute energy and force.
if (totalEnergy != NULL) if (totalEnergy != NULL)
*totalEnergy += (RealOpenMM) energyExpressions[index].evaluate(); *totalEnergy += (RealOpenMM) energyExpressions[index].evaluate();
for (int j = 0; j < (int) valueIndex.size(); j++) for (int j = 0; j < (int) valueIndex.size(); j++)
...@@ -263,12 +293,19 @@ void ReferenceCustomGBIxn::calculateSingleParticleEnergyTerm(int index, int numA ...@@ -263,12 +293,19 @@ void ReferenceCustomGBIxn::calculateSingleParticleEnergyTerm(int index, int numA
forces[i][0] -= (RealOpenMM) energyGradientExpressions[index][0].evaluate(); forces[i][0] -= (RealOpenMM) energyGradientExpressions[index][0].evaluate();
forces[i][1] -= (RealOpenMM) energyGradientExpressions[index][1].evaluate(); forces[i][1] -= (RealOpenMM) energyGradientExpressions[index][1].evaluate();
forces[i][2] -= (RealOpenMM) energyGradientExpressions[index][2].evaluate(); forces[i][2] -= (RealOpenMM) energyGradientExpressions[index][2].evaluate();
// Compute derivatives with respect to parameters.
for (int k = 0; k < energyParamDerivExpressions[index].size(); k++) {
energyParamDerivs[k] += energyParamDerivExpressions[index][k].evaluate();
for (int j = 0; j < (int) valueIndex.size(); j++)
energyParamDerivs[k] += dEdV[j][i]*dValuedParam[j][k][i];
}
} }
} }
void ReferenceCustomGBIxn::calculateParticlePairEnergyTerm(int index, int numAtoms, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters, void ReferenceCustomGBIxn::calculateParticlePairEnergyTerm(int index, int numAtoms, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters,
const vector<vector<RealOpenMM> >& values, const vector<set<int> >& exclusions, bool useExclusions, const vector<set<int> >& exclusions, bool useExclusions, vector<RealVec>& forces, RealOpenMM* totalEnergy, double* energyParamDerivs) {
vector<RealVec>& forces, RealOpenMM* totalEnergy, vector<vector<RealOpenMM> >& dEdV) {
if (cutoff) { if (cutoff) {
// Loop over all pairs in the neighbor list. // Loop over all pairs in the neighbor list.
...@@ -276,7 +313,7 @@ void ReferenceCustomGBIxn::calculateParticlePairEnergyTerm(int index, int numAto ...@@ -276,7 +313,7 @@ void ReferenceCustomGBIxn::calculateParticlePairEnergyTerm(int index, int numAto
OpenMM::AtomPair pair = (*neighborList)[i]; OpenMM::AtomPair pair = (*neighborList)[i];
if (useExclusions && exclusions[pair.first].find(pair.second) != exclusions[pair.first].end()) if (useExclusions && exclusions[pair.first].find(pair.second) != exclusions[pair.first].end())
continue; continue;
calculateOnePairEnergyTerm(index, pair.first, pair.second, atomCoordinates, atomParameters, values, forces, totalEnergy, dEdV); calculateOnePairEnergyTerm(index, pair.first, pair.second, atomCoordinates, atomParameters, forces, totalEnergy, energyParamDerivs);
} }
} }
else { else {
...@@ -286,15 +323,14 @@ void ReferenceCustomGBIxn::calculateParticlePairEnergyTerm(int index, int numAto ...@@ -286,15 +323,14 @@ void ReferenceCustomGBIxn::calculateParticlePairEnergyTerm(int index, int numAto
for (int j = i+1; j < numAtoms; j++) { for (int j = i+1; j < numAtoms; j++) {
if (useExclusions && exclusions[i].find(j) != exclusions[i].end()) if (useExclusions && exclusions[i].find(j) != exclusions[i].end())
continue; continue;
calculateOnePairEnergyTerm(index, i, j, atomCoordinates, atomParameters, values, forces, totalEnergy, dEdV); calculateOnePairEnergyTerm(index, i, j, atomCoordinates, atomParameters, forces, totalEnergy, energyParamDerivs);
} }
} }
} }
} }
void ReferenceCustomGBIxn::calculateOnePairEnergyTerm(int index, int atom1, int atom2, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters, void ReferenceCustomGBIxn::calculateOnePairEnergyTerm(int index, int atom1, int atom2, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters,
const vector<vector<RealOpenMM> >& values, vector<RealVec>& forces, RealOpenMM* totalEnergy, vector<RealVec>& forces, RealOpenMM* totalEnergy, double* energyParamDerivs) {
vector<vector<RealOpenMM> >& dEdV) {
// Compute the displacement. // Compute the displacement.
RealOpenMM deltaR[ReferenceForce::LastDeltaRIndex]; RealOpenMM deltaR[ReferenceForce::LastDeltaRIndex];
...@@ -332,18 +368,23 @@ void ReferenceCustomGBIxn::calculateOnePairEnergyTerm(int index, int atom1, int ...@@ -332,18 +368,23 @@ void ReferenceCustomGBIxn::calculateOnePairEnergyTerm(int index, int atom1, int
dEdV[i][atom1] += (RealOpenMM) energyDerivExpressions[index][2*i+1].evaluate(); dEdV[i][atom1] += (RealOpenMM) energyDerivExpressions[index][2*i+1].evaluate();
dEdV[i][atom2] += (RealOpenMM) energyDerivExpressions[index][2*i+2].evaluate(); dEdV[i][atom2] += (RealOpenMM) energyDerivExpressions[index][2*i+2].evaluate();
} }
// Compute derivatives with respect to parameters.
for (int i = 0; i < energyParamDerivExpressions[index].size(); i++)
energyParamDerivs[i] += energyParamDerivExpressions[index][i].evaluate();
} }
void ReferenceCustomGBIxn::calculateChainRuleForces(int numAtoms, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters, void ReferenceCustomGBIxn::calculateChainRuleForces(int numAtoms, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters,
const vector<vector<RealOpenMM> >& values, const vector<set<int> >& exclusions, vector<RealVec>& forces, vector<vector<RealOpenMM> >& dEdV) { const vector<set<int> >& exclusions, vector<RealVec>& forces) {
if (cutoff) { if (cutoff) {
// Loop over all pairs in the neighbor list. // Loop over all pairs in the neighbor list.
for (int i = 0; i < (int) neighborList->size(); i++) { for (int i = 0; i < (int) neighborList->size(); i++) {
OpenMM::AtomPair pair = (*neighborList)[i]; OpenMM::AtomPair pair = (*neighborList)[i];
bool isExcluded = (exclusions[pair.first].find(pair.second) != exclusions[pair.first].end()); bool isExcluded = (exclusions[pair.first].find(pair.second) != exclusions[pair.first].end());
calculateOnePairChainRule(pair.first, pair.second, atomCoordinates, atomParameters, values, forces, dEdV, isExcluded); calculateOnePairChainRule(pair.first, pair.second, atomCoordinates, atomParameters, forces, isExcluded);
calculateOnePairChainRule(pair.second, pair.first, atomCoordinates, atomParameters, values, forces, dEdV, isExcluded); calculateOnePairChainRule(pair.second, pair.first, atomCoordinates, atomParameters, forces, isExcluded);
} }
} }
else { else {
...@@ -352,8 +393,8 @@ void ReferenceCustomGBIxn::calculateChainRuleForces(int numAtoms, vector<RealVec ...@@ -352,8 +393,8 @@ void ReferenceCustomGBIxn::calculateChainRuleForces(int numAtoms, vector<RealVec
for (int i = 0; i < numAtoms; i++) { for (int i = 0; i < numAtoms; i++) {
for (int j = i+1; j < numAtoms; j++) { for (int j = i+1; j < numAtoms; j++) {
bool isExcluded = (exclusions[i].find(j) != exclusions[i].end()); bool isExcluded = (exclusions[i].find(j) != exclusions[i].end());
calculateOnePairChainRule(i, j, atomCoordinates, atomParameters, values, forces, dEdV, isExcluded); calculateOnePairChainRule(i, j, atomCoordinates, atomParameters, forces, isExcluded);
calculateOnePairChainRule(j, i, atomCoordinates, atomParameters, values, forces, dEdV, isExcluded); calculateOnePairChainRule(j, i, atomCoordinates, atomParameters, forces, isExcluded);
} }
} }
} }
...@@ -388,8 +429,7 @@ void ReferenceCustomGBIxn::calculateChainRuleForces(int numAtoms, vector<RealVec ...@@ -388,8 +429,7 @@ void ReferenceCustomGBIxn::calculateChainRuleForces(int numAtoms, vector<RealVec
} }
void ReferenceCustomGBIxn::calculateOnePairChainRule(int atom1, int atom2, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters, void ReferenceCustomGBIxn::calculateOnePairChainRule(int atom1, int atom2, vector<RealVec>& atomCoordinates, RealOpenMM** atomParameters,
const vector<vector<RealOpenMM> >& values, vector<RealVec>& forces, vector<RealVec>& forces, bool isExcluded) {
vector<vector<RealOpenMM> >& dEdV, bool isExcluded) {
// Compute the displacement. // Compute the displacement.
RealOpenMM deltaR[ReferenceForce::LastDeltaRIndex]; RealOpenMM deltaR[ReferenceForce::LastDeltaRIndex];
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. * * Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -488,6 +488,54 @@ void testIllegalVariable() { ...@@ -488,6 +488,54 @@ void testIllegalVariable() {
ASSERT(threwException); ASSERT(threwException);
} }
void testEnergyParameterDerivatives() {
// Create a box of particles.
const int numParticles = 30;
const int numParameters = 4;
const double boxSize = 2.0;
const double delta = 1e-3;
const string paramNames[] = {"A", "B", "C", "D"};
const double paramValues[] = {0.8, 2.1, 3.2, 1.3};
System system;
system.setDefaultPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
CustomGBForce* force = new CustomGBForce();
system.addForce(force);
force->addComputedValue("a", "0.5*(r-A)^2", CustomGBForce::ParticlePair);
force->addComputedValue("b", "a+B", CustomGBForce::SingleParticle);
force->addEnergyTerm("C*(a1+b1+a2+b2+r)^0.8", CustomGBForce::ParticlePair);
force->addEnergyTerm("(D-B)*b", CustomGBForce::SingleParticle);
for (int i = 0; i < numParameters; i++) {
force->addGlobalParameter(paramNames[i], paramValues[i]);
force->addEnergyParameterDerivative(paramNames[i]);
}
force->setNonbondedMethod(CustomGBForce::CutoffPeriodic);
force->setCutoffDistance(1.0);
vector<Vec3> positions;
vector<double> parameters;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
force->addParticle(parameters);
positions.push_back(Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt))*boxSize);
}
// Compute the energy derivative and compare it to a finite difference approximation.
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
map<string, double> derivs = context.getState(State::ParameterDerivatives).getEnergyParameterDerivatives();
for (int i = 0; i < numParameters; i++) {
context.setParameter(paramNames[i], paramValues[i]+delta);
double energy1 = context.getState(State::Energy).getPotentialEnergy();
context.setParameter(paramNames[i], paramValues[i]-delta);
double energy2 = context.getState(State::Energy).getPotentialEnergy();
ASSERT_EQUAL_TOL((energy1-energy2)/(2*delta), derivs[paramNames[i]], 5e-3);
}
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -502,6 +550,7 @@ int main(int argc, char* argv[]) { ...@@ -502,6 +550,7 @@ int main(int argc, char* argv[]) {
testPositionDependence(); testPositionDependence();
testExclusions(); testExclusions();
testIllegalVariable(); testIllegalVariable();
testEnergyParameterDerivatives();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment