Commit 2a52e208 authored by peastman's avatar peastman
Browse files

Reference implementation of parameter derivatives for CustomIntegrator

parent 74efa95f
......@@ -48,7 +48,7 @@ public:
virtual ~CustomFunction() {
}
/**
* Get the number of arguments this function exprects.
* Get the number of arguments this function expects.
*/
virtual int getNumArguments() const = 0;
/**
......
......@@ -109,7 +109,7 @@ ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const Ex
for (int i = 0; i < (int) children.size(); i++)
children[i] = precalculateConstantSubexpressions(node.getChildren()[i]);
ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children);
if (node.getOperation().getId() == Operation::VARIABLE)
if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM)
return result;
for (int i = 0; i < (int) children.size(); i++)
if (children[i].getOperation().getId() != Operation::CONSTANT)
......
......@@ -201,6 +201,16 @@ namespace OpenMM {
* only involve global variables, not per-DOF ones. It may use any of the
* following comparison operators: =, <. >, !=, <=, >=. Blocks may be nested
* inside each other.
*
* Another feature of CustomIntegrator is that it can use derivatives of the
* potential energy with respect to context parameters. These derivatives are
* typically computed by custom forces, and are only computed if a Force object
* has been specifically told to compute them by calling addEnergyParameterDerivative()
* on it. CustomIntegrator provides a deriv() function for accessing these
* derivatives in global or per-DOF expressions. For example, "deriv(energy, lambda)"
* is the derivative of the total potentially energy with respect to the parameter
* lambda. You can also restrict it to a single force group by specifying a different
* variable for the first argument, such as "deriv(energy1, lambda)".
*
* An Integrator has one other job in addition to evolving the equations of motion:
* it defines how to compute the kinetic energy of the system. Depending on the
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2015 Stanford University and the Authors. *
* Portions copyright (c) 2015-2016 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -34,8 +34,10 @@
#include "openmm/CustomIntegrator.h"
#include "openmm/internal/ContextImpl.h"
#include "lepton/CustomFunction.h"
#include "lepton/ParsedExpression.h"
#include <map>
#include <string>
#include <vector>
namespace OpenMM {
......@@ -48,6 +50,7 @@ class System;
class OPENMM_EXPORT CustomIntegratorUtilities {
public:
class DerivFunction;
enum Comparison {
EQUAL = 0, LESS_THAN = 1, GREATER_THAN = 2, NOT_EQUAL = 3, LESS_THAN_OR_EQUAL = 4, GREATER_THAN_OR_EQUAL = 5
};
......@@ -82,6 +85,28 @@ private:
const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth);
static void analyzeForceComputationsForPath(std::vector<int>& steps, const std::vector<bool>& needsForces, const std::vector<bool>& needsEnergy,
const std::vector<bool>& invalidatesForces, const std::vector<int>& forceGroup, std::vector<bool>& computeBoth);
static void validateDerivatives(const Lepton::ExpressionTreeNode& node, const std::vector<std::string>& derivNames);
};
/**
* This class is used to implement the deriv() function when it appears in expressions.
*/
class CustomIntegratorUtilities::DerivFunction : public Lepton::CustomFunction {
public:
DerivFunction() {
}
int getNumArguments() const {
return 2;
}
double evaluate(const double* arguments) const {
return 0.0;
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* clone() const {
return new DerivFunction();
}
};
} // namespace OpenMM
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2015 Stanford University and the Authors. *
* Portions copyright (c) 2015-2016 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -34,6 +34,7 @@
#include "openmm/internal/ForceImpl.h"
#include "lepton/Operation.h"
#include "lepton/Parser.h"
#include <algorithm>
#include <set>
#include <sstream>
......@@ -81,6 +82,9 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
forceGroup.resize(numSteps, -2);
vector<CustomIntegrator::ComputationType> stepType(numSteps);
vector<string> stepVariable(numSteps);
map<string, Lepton::CustomFunction*> customFunctions;
DerivFunction derivFunction;
customFunctions["deriv"] = &derivFunction;
// Parse the expressions.
......@@ -92,11 +96,11 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
string lhs, rhs;
parseCondition(expression, lhs, rhs, comparisons[step]);
expressions[step].push_back(Lepton::Parser::parse(lhs).optimize());
expressions[step].push_back(Lepton::Parser::parse(rhs).optimize());
expressions[step].push_back(Lepton::Parser::parse(lhs, customFunctions).optimize());
expressions[step].push_back(Lepton::Parser::parse(rhs, customFunctions).optimize());
}
else if (expression.size() > 0)
expressions[step].push_back(Lepton::Parser::parse(expression).optimize());
expressions[step].push_back(Lepton::Parser::parse(expression, customFunctions).optimize());
}
// Identify which steps invalidate the forces.
......@@ -191,6 +195,14 @@ void CustomIntegratorUtilities::analyzeComputations(const ContextImpl& context,
vector<int> jumps(numSteps, -1);
vector<int> stepsInPath;
enumeratePaths(0, stepsInPath, jumps, blockEnd, stepType, needsForces, needsEnergy, invalidatesForces, forceGroup, computeBoth);
// Make sure calls to deriv() all valid.
vector<string> derivNames = energyGroupName;
derivNames.push_back("energy");
for (int i = 0; i < expressions.size(); i++)
for (int j = 0; j < expressions[i].size(); j++)
validateDerivatives(expressions[i][j].getRootNode(), derivNames);
}
void CustomIntegratorUtilities::enumeratePaths(int firstStep, vector<int> steps, vector<int> jumps, const vector<int>& blockEnd,
......@@ -264,4 +276,19 @@ void CustomIntegratorUtilities::analyzeForceComputationsForPath(vector<int>& ste
currentGroup = forceGroup[step];
}
}
}
\ No newline at end of file
}
void CustomIntegratorUtilities::validateDerivatives(const Lepton::ExpressionTreeNode& node, const vector<string>& derivNames) {
const Lepton::Operation& op = node.getOperation();
if (op.getId() == Lepton::Operation::CUSTOM && op.getName() == "deriv") {
const Lepton::Operation& child = node.getChildren()[0].getOperation();
if (child.getId() != Lepton::Operation::VARIABLE || find(derivNames.begin(), derivNames.end(), child.getName()) == derivNames.end())
throw OpenMMException("The first argument to deriv() must be an energy variable");
if (node.getChildren()[1].getOperation().getId() != Lepton::Operation::VARIABLE)
throw OpenMMException("The second argument to deriv() must be a context parameter");
}
else {
for (int i = 0; i < node.getChildren().size(); i++)
validateDerivatives(node.getChildren()[i], derivNames);
}
}
......@@ -41,6 +41,7 @@ namespace OpenMM {
class ReferenceCustomDynamics : public ReferenceDynamics {
private:
class DerivFunction;
const OpenMM::CustomIntegrator& integrator;
std::vector<RealOpenMM> inverseMasses;
std::vector<OpenMM::RealVec> sumBuffer, oldPos;
......@@ -51,6 +52,7 @@ private:
std::vector<bool> invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy;
std::vector<int> forceGroupFlags, blockEnd;
RealOpenMM energy;
std::map<std::string, double> energyParamDerivs;
Lepton::CompiledExpression kineticEnergyExpression;
bool kineticEnergyNeedsForce;
CompiledExpressionSet expressionSet;
......@@ -59,6 +61,8 @@ private:
void initialize(OpenMM::ContextImpl& context, std::vector<RealOpenMM>& masses, std::map<std::string, RealOpenMM>& globals);
Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context);
void computePerDof(int numberOfAtoms, std::vector<OpenMM::RealVec>& results, const std::vector<OpenMM::RealVec>& atomCoordinates,
const std::vector<OpenMM::RealVec>& velocities, const std::vector<OpenMM::RealVec>& forces, const std::vector<RealOpenMM>& masses,
const std::vector<std::vector<OpenMM::RealVec> >& perDof, const Lepton::CompiledExpression& expression, int forceIndex);
......
......@@ -36,6 +36,28 @@
using namespace std;
using namespace OpenMM;
using namespace Lepton;
class ReferenceCustomDynamics::DerivFunction : public CustomFunction {
public:
DerivFunction(map<string, double>& energyParamDerivs, const string& param) : energyParamDerivs(energyParamDerivs), param(param) {
}
int getNumArguments() const {
return 0;
}
double evaluate(const double* arguments) const {
return energyParamDerivs[param];
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0;
}
CustomFunction* clone() const {
return new DerivFunction(energyParamDerivs, param);
}
private:
map<string, double>& energyParamDerivs;
string param;
};
/**---------------------------------------------------------------------------------------
......@@ -56,7 +78,7 @@ ReferenceCustomDynamics::ReferenceCustomDynamics(int numberOfAtoms, const Custom
string expression;
integrator.getComputationStep(i, stepType[i], stepVariable[i], expression);
}
kineticEnergyExpression = Lepton::Parser::parse(integrator.getKineticEnergyExpression()).optimize().createCompiledExpression();
kineticEnergyExpression = Parser::parse(integrator.getKineticEnergyExpression()).optimize().createCompiledExpression();
expressionSet.registerExpression(kineticEnergyExpression);
kineticEnergyNeedsForce = false;
if (kineticEnergyExpression.getVariables().find("f") != kineticEnergyExpression.getVariables().end())
......@@ -78,13 +100,13 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<RealOpenMM
int numSteps = stepType.size();
vector<int> forceGroup;
vector<vector<Lepton::ParsedExpression> > expressions;
vector<vector<ParsedExpression> > expressions;
CustomIntegratorUtilities::analyzeComputations(context, integrator, expressions, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup);
stepExpressions.resize(expressions.size());
for (int i = 0; i < numSteps; i++) {
stepExpressions[i].resize(expressions[i].size());
for (int j = 0; j < (int) expressions[i].size(); j++) {
stepExpressions[i][j] = expressions[i][j].createCompiledExpression();
stepExpressions[i][j] = ParsedExpression(replaceDerivFunctions(expressions[i][j].getRootNode(), context)).createCompiledExpression();
expressionSet.registerExpression(stepExpressions[i][j]);
}
if (stepType[i] == CustomIntegrator::WhileBlockStart)
......@@ -141,6 +163,22 @@ void ReferenceCustomDynamics::initialize(ContextImpl& context, vector<RealOpenMM
stepVariableIndex.push_back(expressionSet.getVariableIndex(stepVariable[i]));
}
ExpressionTreeNode ReferenceCustomDynamics::replaceDerivFunctions(const ExpressionTreeNode& node, ContextImpl& context) {
const Operation& op = node.getOperation();
if (op.getId() == Operation::CUSTOM && op.getName() == "deriv") {
string param = node.getChildren()[1].getOperation().getName();
if (context.getParameters().find(param) == context.getParameters().end())
throw OpenMMException("The second argument to deriv() must be a context parameter");
return ExpressionTreeNode(new Operation::Custom("deriv", new DerivFunction(energyParamDerivs, param)));
}
else {
vector<ExpressionTreeNode> children;
for (int i = 0; i < (int) node.getChildren().size(); i++)
children.push_back(replaceDerivFunctions(node.getChildren()[i], context));
return ExpressionTreeNode(op.clone(), children);
}
}
/**---------------------------------------------------------------------------------------
Update -- driver routine for performing Custom dynamics update of coordinates
......@@ -178,8 +216,10 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
bool computeEnergy = needsEnergy[step] || computeBothForceAndEnergy[step];
recordChangedParameters(context, globals);
RealOpenMM e = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]);
if (computeEnergy)
if (computeEnergy) {
energy = e;
context.getEnergyParameterDerivatives(energyParamDerivs);
}
forcesAreValid = true;
}
expressionSet.setVariable(energyVariableIndex[step], energy);
......@@ -266,7 +306,7 @@ void ReferenceCustomDynamics::update(ContextImpl& context, int numberOfAtoms, ve
void ReferenceCustomDynamics::computePerDof(int numberOfAtoms, vector<RealVec>& results, const vector<RealVec>& atomCoordinates,
const vector<RealVec>& velocities, const vector<RealVec>& forces, const vector<RealOpenMM>& masses,
const vector<vector<RealVec> >& perDof, const Lepton::CompiledExpression& expression, int forceIndex) {
const vector<vector<RealVec> >& perDof, const CompiledExpression& expression, int forceIndex) {
// Loop over all degrees of freedom.
for (int i = 0; i < numberOfAtoms; i++) {
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -29,15 +29,21 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#ifdef WIN32
#define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/Context.h"
#include "openmm/AndersenThermostat.h"
#include "openmm/CustomAngleForce.h"
#include "openmm/CustomBondForce.h"
#include "openmm/CustomIntegrator.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/NonbondedForce.h"
#include "openmm/System.h"
#include "openmm/CustomIntegrator.h"
#include "SimTKOpenMMRealType.h"
#include "sfmt/SFMT.h"
#include <cmath>
#include <iostream>
#include <vector>
......@@ -770,6 +776,84 @@ void testChangingGlobal() {
}
}
/**
* Test steps that depend on derivatives of the energy with respect to parameters.
*/
void testEnergyParameterDerivatives() {
System system;
for (int i = 0; i < 3; i++)
system.addParticle(1.0);
// Create some custom forces that depend on parameters.
CustomBondForce* bonds = new CustomBondForce("K*(A*r-r0)^2");
system.addForce(bonds);
bonds->addGlobalParameter("K", 2.0);
bonds->addGlobalParameter("A", 1.0);
bonds->addGlobalParameter("r0", 1.5);
bonds->addEnergyParameterDerivative("K");
bonds->addEnergyParameterDerivative("r0");
bonds->addBond(0, 1);
bonds->setForceGroup(0);
CustomAngleForce* angles = new CustomAngleForce("K*(B*theta-theta0)^2");
system.addForce(angles);
angles->addGlobalParameter("K", 2.0);
angles->addGlobalParameter("B", 1.0);
angles->addGlobalParameter("theta0", M_PI/3);
angles->addEnergyParameterDerivative("K");
angles->addEnergyParameterDerivative("theta0");
angles->addAngle(0, 1, 2);
angles->setForceGroup(1);
// Create an integrator that records parameter derivatives.
CustomIntegrator integrator(0.1);
integrator.addGlobalVariable("dEdK", 0.0);
integrator.addGlobalVariable("dEdr0", 0.0);
integrator.addGlobalVariable("dEdtheta0", 0.0);
integrator.addGlobalVariable("dEdK_0", 0.0);
integrator.addGlobalVariable("dEdr0_0", 0.0);
integrator.addGlobalVariable("dEdtheta0_0", 0.0);
integrator.addGlobalVariable("dEdK_1", 0.0);
integrator.addGlobalVariable("dEdr0_1", 0.0);
integrator.addGlobalVariable("dEdtheta0_1", 0.0);
integrator.addComputeGlobal("dEdK", "deriv(energy, K)");
integrator.addComputeGlobal("dEdr0", "deriv(energy, r0)");
integrator.addComputeGlobal("dEdtheta0", "deriv(energy, theta0)");
integrator.addComputeGlobal("dEdK_0", "deriv(energy0, K)");
integrator.addComputeGlobal("dEdr0_0", "deriv(energy0, r0)");
integrator.addComputeGlobal("dEdtheta0_0", "deriv(energy0, theta0)");
integrator.addComputeGlobal("dEdK_1", "deriv(energy1, K)");
integrator.addComputeGlobal("dEdr0_1", "deriv(energy1, r0)");
integrator.addComputeGlobal("dEdtheta0_1", "deriv(energy1, theta0)");
// Create a Context.
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 1, 0);
positions[1] = Vec3(0, 0, 0);
positions[2] = Vec3(1, 0, 0);
context.setPositions(positions);
// Check the results.
integrator.step(1);
double dEdK_0 = (1.0-1.5)*(1.0-1.5);
double dEdK_1 = (M_PI/2-M_PI/3)*(M_PI/2-M_PI/3);
ASSERT_EQUAL_TOL(dEdK_0, integrator.getGlobalVariableByName("dEdK_0"), 1e-5);
ASSERT_EQUAL_TOL(dEdK_1, integrator.getGlobalVariableByName("dEdK_1"), 1e-5);
ASSERT_EQUAL_TOL(dEdK_0+dEdK_1, integrator.getGlobalVariableByName("dEdK"), 1e-5);
double dEdr0 = -2.0*2.0*(1.0-1.5);
ASSERT_EQUAL_TOL(dEdr0, integrator.getGlobalVariableByName("dEdr0_0"), 1e-5);
ASSERT_EQUAL_TOL(0.0, integrator.getGlobalVariableByName("dEdr0_1"), 1e-5);
ASSERT_EQUAL_TOL(dEdr0, integrator.getGlobalVariableByName("dEdr0"), 1e-5);
double dEdtheta0 = -2.0*2.0*(M_PI/2-M_PI/3);
ASSERT_EQUAL_TOL(0.0, integrator.getGlobalVariableByName("dEdtheta0_0"), 1e-5);
ASSERT_EQUAL_TOL(dEdtheta0, integrator.getGlobalVariableByName("dEdtheta0_1"), 1e-5);
ASSERT_EQUAL_TOL(dEdtheta0, integrator.getGlobalVariableByName("dEdtheta0"), 1e-5);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -790,6 +874,7 @@ int main(int argc, char* argv[]) {
testIfBlock();
testWhileBlock();
testChangingGlobal();
testEnergyParameterDerivatives();
runPlatformTests();
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment