Commit 2a52e208 authored by peastman's avatar peastman
Browse files

Reference implementation of parameter derivatives for CustomIntegrator

parent 74efa95f
...@@ -48,7 +48,7 @@ public: ...@@ -48,7 +48,7 @@ public:
virtual ~CustomFunction() { virtual ~CustomFunction() {
} }
/** /**
* Get the number of arguments this function exprects. * Get the number of arguments this function expects.
*/ */
virtual int getNumArguments() const = 0; virtual int getNumArguments() const = 0;
/** /**
......
...@@ -109,7 +109,7 @@ ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const Ex ...@@ -109,7 +109,7 @@ ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const Ex
for (int i = 0; i < (int) children.size(); i++) for (int i = 0; i < (int) children.size(); i++)
children[i] = precalculateConstantSubexpressions(node.getChildren()[i]); children[i] = precalculateConstantSubexpressions(node.getChildren()[i]);
ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children); ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children);
if (node.getOperation().getId() == Operation::VARIABLE) if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM)
return result; return result;
for (int i = 0; i < (int) children.size(); i++) for (int i = 0; i < (int) children.size(); i++)
if (children[i].getOperation().getId() != Operation::CONSTANT) if (children[i].getOperation().getId() != Operation::CONSTANT)
......
...@@ -202,6 +202,16 @@ namespace OpenMM { ...@@ -202,6 +202,16 @@ namespace OpenMM {
* following comparison operators: =, <. >, !=, <=, >=. Blocks may be nested * following comparison operators: =, <. >, !=, <=, >=. Blocks may be nested
* inside each other. * inside each other.
* *
* Another feature of CustomIntegrator is that it can use derivatives of the
* potential energy with respect to context parameters. These derivatives are
* typically computed by custom forces, and are only computed if a Force object
* has been specifically told to compute them by calling addEnergyParameterDerivative()
* on it. CustomIntegrator provides a deriv() function for accessing these
* derivatives in global or per-DOF expressions. For example, "deriv(energy, lambda)"
* is the derivative of the total potentially energy with respect to the parameter
* lambda. You can also restrict it to a single force group by specifying a different
* variable for the first argument, such as "deriv(energy1, lambda)".
*
* An Integrator has one other job in addition to evolving the equations of motion: * An Integrator has one other job in addition to evolving the equations of motion:
* it defines how to compute the kinetic energy of the system. Depending on the * it defines how to compute the kinetic energy of the system. Depending on the
* integration method used, simply summing mv<sup>2</sup>/2 over all degrees of * integration method used, simply summing mv<sup>2</sup>/2 over all degrees of
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2015 Stanford University and the Authors. * * Portions copyright (c) 2015-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -34,8 +34,10 @@ ...@@ -34,8 +34,10 @@
#include "openmm/CustomIntegrator.h" #include "openmm/CustomIntegrator.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "lepton/CustomFunction.h"
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
#include <map> #include <map>
#include <string>
#include <vector> #include <vector>
namespace OpenMM { namespace OpenMM {
...@@ -48,6 +50,7 @@ class System; ...@@ -48,6 +50,7 @@ class System;
class OPENMM_EXPORT CustomIntegratorUtilities { class OPENMM_EXPORT CustomIntegratorUtilities {
public: public:
class DerivFunction;
enum Comparison { enum Comparison {
EQUAL = 0, LESS_THAN = 1, GREATER_THAN = 2, NOT_EQUAL = 3, LESS_THAN_OR_EQUAL = 4, GREATER_THAN_OR_EQUAL = 5 EQUAL = 0, LESS_THAN = 1, GREATER_THAN = 2, NOT_EQUAL = 3, LESS_THAN_OR_EQUAL = 4, GREATER_THAN_OR_EQUAL = 5
}; };
...@@ -82,6 +85,28 @@ private: ...@@ -82,6 +85,28 @@ private:
const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth); const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth);
static void analyzeForceComputationsForPath(std::vector<int>& steps, const std::vector<bool>& needsForces, const std::vector<bool>& needsEnergy, static void analyzeForceComputationsForPath(std::vector<int>& steps, const std::vector<bool>& needsForces, const std::vector<bool>& needsEnergy,
const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth); const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth);
static void validateDerivatives(const Lepton::ExpressionTreeNode& node, const std::vector<std::string>& derivNames);
};
/**
* This class is used to implement the deriv() function when it appears in expressions.
*/
class CustomIntegratorUtilities::DerivFunction : public Lepton::CustomFunction {
public:
DerivFunction() {
}
int getNumArguments() const {
return 2;
}
double evaluate(const double* arguments) const {
return 0.0;
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* clone() const {
return new DerivFunction();
}
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2015 Stanford University and the Authors. * * Portions copyright (c) 2015-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "openmm/internal/ForceImpl.h" #include "openmm/internal/ForceImpl.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
#include "lepton/Parser.h" #include "lepton/Parser.h"
#include <algorithm>
#include <set> #include <set>
#include <sstream> #include <sstream>
...@@ -81,6 +82,9 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -81,6 +82,9 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
forceGroup.resize(numSteps, -2); forceGroup.resize(numSteps, -2);
vector<CustomIntegrator::ComputationType> stepType(numSteps); vector<CustomIntegrator::ComputationType> stepType(numSteps);
vector<string> stepVariable(numSteps); vector<string> stepVariable(numSteps);
map<string, Lepton::CustomFunction*> customFunctions;
DerivFunction derivFunction;
customFunctions["deriv"] = &derivFunction;
// Parse the expressions. // Parse the expressions.
...@@ -92,11 +96,11 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -92,11 +96,11 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
string lhs, rhs; string lhs, rhs;
parseCondition(expression, lhs, rhs, comparisons[step]); parseCondition(expression, lhs, rhs, comparisons[step]);
expressions[step].push_back(Lepton::Parser::parse(lhs).optimize()); expressions[step].push_back(Lepton::Parser::parse(lhs, customFunctions).optimize());
expressions[step].push_back(Lepton::Parser::parse(rhs).optimize()); expressions[step].push_back(Lepton::Parser::parse(rhs, customFunctions).optimize());
} }
else if (expression.size() > 0) else if (expression.size() > 0)
expressions[step].push_back(Lepton::Parser::parse(expression).optimize()); expressions[step].push_back(Lepton::Parser::parse(expression, customFunctions).optimize());
} }
// Identify which steps invalidate the forces. // Identify which steps invalidate the forces.
...@@ -191,6 +195,14 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context, ...@@ -191,6 +195,14 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
vector<int> jumps(numSteps, -1); vector<int> jumps(numSteps, -1);
vector<int> stepsInPath; vector<int> stepsInPath;
enumeratePaths(0, stepsInPath, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth); enumeratePaths(0, stepsInPath, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth);
// Make sure calls to deriv() all valid.
vector<string> derivNames = energyGroupName;
derivNames.push_back("energy");
for (int i = 0; i < expressions.size(); i++)
for (int j = 0; j < expressions[i].size(); j++)
validateDerivatives(expressions[i][j].getRootNode(), derivNames);
} }
void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, vector<int> jumps, const vector<int>& blockEnd, void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, vector<int> jumps, const vector<int>& blockEnd,
...@@ -265,3 +277,18 @@ void CustomIntegratorUtilities::analyzeForceComputationsForPath(vector<int>& ste ...@@ -265,3 +277,18 @@ void CustomIntegratorUtilities::analyzeForceComputationsForPath(vector<int>& ste
} }
} }
} }
void CustomIntegratorUtilities::validateDerivatives(const Lepton::ExpressionTreeNode& node, const vector<string>& derivNames) {
const Lepton::Operation& op = node.getOperation();
if (op.getId() == Lepton::Operation::CUSTOM && op.getName() == "deriv") {
const Lepton::Operation& child = node.getChildren()[0].getOperation();
if (child.getId() != Lepton::Operation::VARIABLE || find(derivNames.begin(), derivNames.end(), child.getName()) == derivNames.end())
throw OpenMMException("The first argument to deriv() must be an energy variable");
if (node.getChildren()[1].getOperation().getId() != Lepton::Operation::VARIABLE)
throw OpenMMException("The second argument to deriv() must be a context parameter");
}
else {
for (int i = 0; i < node.getChildren().size(); i++)
validateDerivatives(node.getChildren()[i], derivNames);
}
}
...@@ -41,6 +41,7 @@ namespace OpenMM { ...@@ -41,6 +41,7 @@ namespace OpenMM {
class ReferenceCustomDynamics : public ReferenceDynamics { class ReferenceCustomDynamics : public ReferenceDynamics {
private: private:
class DerivFunction;
const OpenMM::CustomIntegrator& integrator; const OpenMM::CustomIntegrator& integrator;
std::vector<RealOpenMM> inverseMasses; std::vector<RealOpenMM> inverseMasses;
std::vector<OpenMM::RealVec> sumBuffer, oldPos; std::vector<OpenMM::RealVec> sumBuffer, oldPos;
...@@ -51,6 +52,7 @@ private: ...@@ -51,6 +52,7 @@ private:
std::vector<bool> invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy; std::vector<bool> invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy;
std::vector<int> forceGroupFlags, blockEnd; std::vector<int> forceGroupFlags, blockEnd;
RealOpenMM energy; RealOpenMM energy;
std::map<std::string, double> energyParamDerivs;
Lepton::CompiledExpression kineticEnergyExpression; Lepton::CompiledExpression kineticEnergyExpression;
bool kineticEnergyNeedsForce; bool kineticEnergyNeedsForce;
CompiledExpressionSet expressionSet; CompiledExpressionSet expressionSet;
...@@ -59,6 +61,8 @@ private: ...@@ -59,6 +61,8 @@ private:
void initialize(OpenMM::ContextImpl& context, std::vector<RealOpenMM>& masses, std::map<std::string, RealOpenMM>& globals); void initialize(OpenMM::ContextImpl& context, std::vector<RealOpenMM>& masses, std::map<std::string, RealOpenMM>& globals);
Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context);
void computePerDof(int numberOfAtoms, std::vector<OpenMM::RealVec>& results, const std::vector<OpenMM::RealVec>& atomCoordinates, void computePerDof(int numberOfAtoms, std::vector<OpenMM::RealVec>& results, const std::vector<OpenMM::RealVec>& atomCoordinates,
const std::vector<OpenMM::RealVec>& velocities, const std::vector<OpenMM::RealVec>& forces, const std::vector<RealOpenMM>& masses, const std::vector<OpenMM::RealVec>& velocities, const std::vector<OpenMM::RealVec>& forces, const std::vector<RealOpenMM>& masses,
const std::vector<std::vector<OpenMM::RealVec> >& perDof, const Lepton::CompiledExpression& expression, int forceIndex); const std::vector<std::vector<OpenMM::RealVec> >& perDof, const Lepton::CompiledExpression& expression, int forceIndex);
......
...@@ -36,6 +36,28 @@ ...@@ -36,6 +36,28 @@
using namespace std; using namespace std;
using namespace OpenMM; using namespace OpenMM;
using namespace Lepton;
class ReferenceCustomDynamics::DerivFunction : public CustomFunction {
public:
DerivFunction(map<string, double>& energyParamDerivs, const string& param) : energyParamDerivs(energyParamDerivs), param(param) {
}
int getNumArguments() const {
return 0;
}
double evaluate(const double* arguments) const {
return energyParamDerivs[param];
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0;
}
CustomFunction* clone() const {
return new DerivFunction(energyParamDerivs, param);
}
private:
map<string, double>& energyParamDerivs;
string param;
};
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -56,7 +78,7 @@ ReferenceCustomDynamics::ReferenceCustomDynamics(int numberOfAtoms, const Custom ...@@ -56,7 +78,7 @@ ReferenceCustomDynamics::ReferenceCustomDynamics(int numberOfAtoms, const Custom
string expression; string expression;
integrator.getComputationStep(i, stepType[i], stepVariable[i], expression); integrator.getComputationStep(i, stepType[i], stepVariable[i], expression);
} }
kineticEnergyExpression = Lepton::Parser::parse(integrator.getKineticEnergyExpression()).optimize().createCompiledExpression(); kineticEnergyExpression = Parser::parse(integrator.getKineticEnergyExpression()).optimize().createCompiledExpression();
expressionSet.registerExpression(kineticEnergyExpression); expressionSet.registerExpression(kineticEnergyExpression);
kineticEnergyNeedsForce = false; kineticEnergyNeedsForce = false;
if (kineticEnergyExpression.getVariables().find("f") != kineticEnergyExpression.getVariables().end()) if (kineticEnergyExpression.getVariables().find("f") != kineticEnergyExpression.getVariables().end())
...@@ -78,13 +100,13 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<RealOpenMM ...@@ -78,13 +100,13 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<RealOpenMM
int numSteps = stepType.size(); int numSteps = stepType.size();
vector<int> forceGroup; vector<int> forceGroup;
vector<vector<Lepton::ParsedExpression> > expressions; vector<vector<ParsedExpression> > expressions;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expressions, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup); CustomIntegratorUtilities::analyzeComputations(context, integrator, expressions, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup);
stepExpressions.resize(expressions.size()); stepExpressions.resize(expressions.size());
for (int i = 0; i < numSteps; i++) { for (int i = 0; i < numSteps; i++) {
stepExpressions[i].resize(expressions[i].size()); stepExpressions[i].resize(expressions[i].size());
for (int j = 0; j < (int) expressions[i].size(); j++) { for (int j = 0; j < (int) expressions[i].size(); j++) {
stepExpressions[i][j] = expressions[i][j].createCompiledExpression(); stepExpressions[i][j] = ParsedExpression(replaceDerivFunctions(expressions[i][j].getRootNode(), context)).createCompiledExpression();
expressionSet.registerExpression(stepExpressions[i][j]); expressionSet.registerExpression(stepExpressions[i][j]);
} }
if (stepType[i] == CustomIntegrator::WhileBlockStart) if (stepType[i] == CustomIntegrator::WhileBlockStart)
...@@ -141,6 +163,22 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<RealOpenMM ...@@ -141,6 +163,22 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<RealOpenMM
stepVariableIndex.push_back(expressionSet.getVariableIndex(stepVariable[i])); stepVariableIndex.push_back(expressionSet.getVariableIndex(stepVariable[i]));
} }
ExpressionTreeNode ReferenceCustomDynamics::replaceDerivFunctions(const ExpressionTreeNode& node, ContextImpl& context) {
const Operation& op = node.getOperation();
if (op.getId() == Operation::CUSTOM && op.getName() == "deriv") {
string param = node.getChildren()[1].getOperation().getName();
if (context.getParameters().find(param) == context.getParameters().end())
throw OpenMMException("The second argument to deriv() must be a context parameter");
return ExpressionTreeNode(new Operation::Custom("deriv", new DerivFunction(energyParamDerivs, param)));
}
else {
vector<ExpressionTreeNode> children;
for (int i = 0; i < (int) node.getChildren().size(); i++)
children.push_back(replaceDerivFunctions(node.getChildren()[i], context));
return ExpressionTreeNode(op.clone(), children);
}
}
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
Update -- driver routine for performing Custom dynamics update of coordinates Update -- driver routine for performing Custom dynamics update of coordinates
...@@ -178,8 +216,10 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -178,8 +216,10 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
bool computeEnergy = needsEnergy[step] || computeBothForceAndEnergy[step]; bool computeEnergy = needsEnergy[step] || computeBothForceAndEnergy[step];
recordChangedParameters(context, globals); recordChangedParameters(context, globals);
RealOpenMM e = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]); RealOpenMM e = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]);
if (computeEnergy) if (computeEnergy) {
energy = e; energy = e;
context.getEnergyParameterDerivatives(energyParamDerivs);
}
forcesAreValid = true; forcesAreValid = true;
} }
expressionSet.setVariable(energyVariableIndex[step], energy); expressionSet.setVariable(energyVariableIndex[step], energy);
...@@ -266,7 +306,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve ...@@ -266,7 +306,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
void ReferenceCustomDynamics::computePerDof(int numberOfAtoms, vector<RealVec>& results, const vector<RealVec>& atomCoordinates, void ReferenceCustomDynamics::computePerDof(int numberOfAtoms, vector<RealVec>& results, const vector<RealVec>& atomCoordinates,
const vector<RealVec>& velocities, const vector<RealVec>& forces, const vector<RealOpenMM>& masses, const vector<RealVec>& velocities, const vector<RealVec>& forces, const vector<RealOpenMM>& masses,
const vector<vector<RealVec> >& perDof, const Lepton::CompiledExpression& expression, int forceIndex) { const vector<vector<RealVec> >& perDof, const CompiledExpression& expression, int forceIndex) {
// Loop over all degrees of freedom. // Loop over all degrees of freedom.
for (int i = 0; i < numberOfAtoms; i++) { for (int i = 0; i < numberOfAtoms; i++) {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. * * Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -29,15 +29,21 @@ ...@@ -29,15 +29,21 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#ifdef WIN32
#define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/AndersenThermostat.h" #include "openmm/AndersenThermostat.h"
#include "openmm/CustomAngleForce.h"
#include "openmm/CustomBondForce.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/HarmonicBondForce.h" #include "openmm/HarmonicBondForce.h"
#include "openmm/NonbondedForce.h" #include "openmm/NonbondedForce.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/CustomIntegrator.h"
#include "SimTKOpenMMRealType.h" #include "SimTKOpenMMRealType.h"
#include "sfmt/SFMT.h" #include "sfmt/SFMT.h"
#include <cmath>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -770,6 +776,84 @@ void testChangingGlobal() { ...@@ -770,6 +776,84 @@ void testChangingGlobal() {
} }
} }
/**
* Test steps that depend on derivatives of the energy with respect to parameters.
*/
void testEnergyParameterDerivatives() {
System system;
for (int i = 0; i < 3; i++)
system.addParticle(1.0);
// Create some custom forces that depend on parameters.
CustomBondForce* bonds = new CustomBondForce("K*(A*r-r0)^2");
system.addForce(bonds);
bonds->addGlobalParameter("K", 2.0);
bonds->addGlobalParameter("A", 1.0);
bonds->addGlobalParameter("r0", 1.5);
bonds->addEnergyParameterDerivative("K");
bonds->addEnergyParameterDerivative("r0");
bonds->addBond(0, 1);
bonds->setForceGroup(0);
CustomAngleForce* angles = new CustomAngleForce("K*(B*theta-theta0)^2");
system.addForce(angles);
angles->addGlobalParameter("K", 2.0);
angles->addGlobalParameter("B", 1.0);
angles->addGlobalParameter("theta0", M_PI/3);
angles->addEnergyParameterDerivative("K");
angles->addEnergyParameterDerivative("theta0");
angles->addAngle(0, 1, 2);
angles->setForceGroup(1);
// Create an integrator that records parameter derivatives.
CustomIntegrator integrator(0.1);
integrator.addGlobalVariable("dEdK", 0.0);
integrator.addGlobalVariable("dEdr0", 0.0);
integrator.addGlobalVariable("dEdtheta0", 0.0);
integrator.addGlobalVariable("dEdK_0", 0.0);
integrator.addGlobalVariable("dEdr0_0", 0.0);
integrator.addGlobalVariable("dEdtheta0_0", 0.0);
integrator.addGlobalVariable("dEdK_1", 0.0);
integrator.addGlobalVariable("dEdr0_1", 0.0);
integrator.addGlobalVariable("dEdtheta0_1", 0.0);
integrator.addComputeGlobal("dEdK", "deriv(energy, K)");
integrator.addComputeGlobal("dEdr0", "deriv(energy, r0)");
integrator.addComputeGlobal("dEdtheta0", "deriv(energy, theta0)");
integrator.addComputeGlobal("dEdK_0", "deriv(energy0, K)");
integrator.addComputeGlobal("dEdr0_0", "deriv(energy0, r0)");
integrator.addComputeGlobal("dEdtheta0_0", "deriv(energy0, theta0)");
integrator.addComputeGlobal("dEdK_1", "deriv(energy1, K)");
integrator.addComputeGlobal("dEdr0_1", "deriv(energy1, r0)");
integrator.addComputeGlobal("dEdtheta0_1", "deriv(energy1, theta0)");
// Create a Context.
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 1, 0);
positions[1] = Vec3(0, 0, 0);
positions[2] = Vec3(1, 0, 0);
context.setPositions(positions);
// Check the results.
integrator.step(1);
double dEdK_0 = (1.0-1.5)*(1.0-1.5);
double dEdK_1 = (M_PI/2-M_PI/3)*(M_PI/2-M_PI/3);
ASSERT_EQUAL_TOL(dEdK_0, integrator.getGlobalVariableByName("dEdK_0"), 1e-5);
ASSERT_EQUAL_TOL(dEdK_1, integrator.getGlobalVariableByName("dEdK_1"), 1e-5);
ASSERT_EQUAL_TOL(dEdK_0+dEdK_1, integrator.getGlobalVariableByName("dEdK"), 1e-5);
double dEdr0 = -2.0*2.0*(1.0-1.5);
ASSERT_EQUAL_TOL(dEdr0, integrator.getGlobalVariableByName("dEdr0_0"), 1e-5);
ASSERT_EQUAL_TOL(0.0, integrator.getGlobalVariableByName("dEdr0_1"), 1e-5);
ASSERT_EQUAL_TOL(dEdr0, integrator.getGlobalVariableByName("dEdr0"), 1e-5);
double dEdtheta0 = -2.0*2.0*(M_PI/2-M_PI/3);
ASSERT_EQUAL_TOL(0.0, integrator.getGlobalVariableByName("dEdtheta0_0"), 1e-5);
ASSERT_EQUAL_TOL(dEdtheta0, integrator.getGlobalVariableByName("dEdtheta0_1"), 1e-5);
ASSERT_EQUAL_TOL(dEdtheta0, integrator.getGlobalVariableByName("dEdtheta0"), 1e-5);
}
void runPlatformTests(); void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -790,6 +874,7 @@ int main(int argc, char* argv[]) { ...@@ -790,6 +874,7 @@ int main(int argc, char* argv[]) {
testIfBlock(); testIfBlock();
testWhileBlock(); testWhileBlock();
testChangingGlobal(); testChangingGlobal();
testEnergyParameterDerivatives();
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment