Commit cd566c63 authored by peastman's avatar peastman
Browse files

Beginnings of support for derivatives with respect to parameters

parent 77b9b7ba
......@@ -157,6 +157,12 @@ public:
* @param forces on exit, this contains the forces
*/
void getForces(ContextImpl& context, std::vector<Vec3>& forces);
/**
* Get the current derivatives of the energy with respect to context parameters.
*
* @param derivs on exit, this contains the derivatives
*/
void getEnergyParameterDerivatives(ContextImpl& context, std::map<std::string, double>& derivs);
/**
* Get the current periodic box vectors.
*
......@@ -319,7 +325,8 @@ private:
int **bondIndexArray;
RealOpenMM **bondParamArray;
Lepton::CompiledExpression energyExpression, forceExpression;
std::vector<std::string> parameterNames, globalParameterNames;
std::vector<Lepton::CompiledExpression> energyParamDerivExpressions;
std::vector<std::string> parameterNames, globalParameterNames, energyParamDerivNames;
bool usePeriodic;
};
......
/* Portions copyright (c) 2006 Stanford University and Simbios.
/* Portions copyright (c) 2006-2016 Stanford University and Simbios.
* Contributors: Pande Group
*
* Permission is hereby granted, free of charge, to any person obtaining
......@@ -64,7 +64,7 @@ class OPENMM_EXPORT ReferenceLJCoulomb14 : public ReferenceBondIxn {
void calculateBondIxn(int* atomIndices, std::vector<OpenMM::RealVec>& atomCoordinates,
RealOpenMM* parameters, std::vector<OpenMM::RealVec>& forces,
RealOpenMM* totalEnergy) const;
RealOpenMM* totalEnergy, double* energyParamDerivs);
};
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008 Stanford University and the Authors. *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -68,6 +68,7 @@ public:
void* periodicBoxSize;
void* periodicBoxVectors;
void* constraints;
void* energyParameterDerivatives;
};
} // namespace OpenMM
......
......@@ -80,7 +80,7 @@ class OPENMM_EXPORT ReferenceProperDihedralBond : public ReferenceBondIxn {
void calculateBondIxn(int* atomIndices, std::vector<OpenMM::RealVec>& atomCoordinates,
RealOpenMM* parameters, std::vector<OpenMM::RealVec>& forces,
RealOpenMM* totalEnergy) const;
RealOpenMM* totalEnergy, double* energyParamDerivs);
};
......
......@@ -78,7 +78,7 @@ class OPENMM_EXPORT ReferenceRbDihedralBond : public ReferenceBondIxn {
void calculateBondIxn(int* atomIndices, std::vector<OpenMM::RealVec>& atomCoordinates,
RealOpenMM* parameters, std::vector<OpenMM::RealVec>& forces,
RealOpenMM* totalEnergy) const;
RealOpenMM* totalEnergy, double* energyParamDerivs);
};
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -146,6 +146,11 @@ static ReferenceConstraints& extractConstraints(ContextImpl& context) {
return *(ReferenceConstraints*) data->constraints;
}
static map<string, double>& extractEnergyParameterDerivatives(ContextImpl& context) {
ReferencePlatform::PlatformData* data = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
return *((map<string, double>*) data->energyParameterDerivatives);
}
/**
* Make sure an expression doesn't use any undefined variables.
*/
......@@ -208,6 +213,8 @@ void ReferenceCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context,
}
else
savedForces = forceData;
for (map<string, double>::const_iterator iter = context.getParameters().begin(); iter != context.getParameters().end(); ++iter)
extractEnergyParameterDerivatives(context)[iter->first] = 0;
}
double ReferenceCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups, bool& valid) {
......@@ -273,6 +280,10 @@ void ReferenceUpdateStateDataKernel::getForces(ContextImpl& context, std::vector
forces[i] = Vec3(forceData[i][0], forceData[i][1], forceData[i][2]);
}
void ReferenceUpdateStateDataKernel::getEnergyParameterDerivatives(ContextImpl& context, map<string, double>& derivs) {
derivs = extractEnergyParameterDerivatives(context);
}
void ReferenceUpdateStateDataKernel::getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const {
RealVec* vectors = extractBoxVectors(context);
a = vectors[0];
......@@ -437,6 +448,11 @@ void ReferenceCalcCustomBondForceKernel::initialize(const System& system, const
parameterNames.push_back(force.getPerBondParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string param = force.getEnergyParameterDerivativeName(i);
energyParamDerivNames.push_back(param);
energyParamDerivExpressions.push_back(expression.differentiate(param).createCompiledExpression());
}
set<string> variables;
variables.insert("r");
variables.insert(parameterNames.begin(), parameterNames.end());
......@@ -451,11 +467,15 @@ double ReferenceCalcCustomBondForceKernel::execute(ContextImpl& context, bool in
map<string, double> globalParameters;
for (int i = 0; i < (int) globalParameterNames.size(); i++)
globalParameters[globalParameterNames[i]] = context.getParameter(globalParameterNames[i]);
ReferenceBondForce refBondForce;
ReferenceCustomBondIxn bond(energyExpression, forceExpression, parameterNames, globalParameters);
ReferenceCustomBondIxn bond(energyExpression, forceExpression, parameterNames, globalParameters, energyParamDerivExpressions);
if (usePeriodic)
bond.setPeriodic(extractBoxVectors(context));
refBondForce.calculateForce(numBonds, bondIndexArray, posData, bondParamArray, forceData, includeEnergy ? &energy : NULL, bond);
vector<double> energyParamDerivValues(energyParamDerivNames.size()+1, 0.0);
for (int i = 0; i < numBonds; i++)
bond.calculateBondIxn(bondIndexArray[i], posData, bondParamArray[i], forceData, includeEnergy ? &energy : NULL, &energyParamDerivValues[0]);
map<string, double>& energyParamDerivs = extractEnergyParameterDerivatives(context);
for (int i = 0; i < energyParamDerivNames.size(); i++)
energyParamDerivs[energyParamDerivNames[i]] += energyParamDerivValues[i];
return energy;
}
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008 Stanford University and the Authors. *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -36,6 +36,7 @@
#include "openmm/internal/ContextImpl.h"
#include "SimTKOpenMMRealType.h"
#include "RealVec.h"
#include <map>
#include <vector>
using namespace OpenMM;
......@@ -99,6 +100,7 @@ ReferencePlatform::PlatformData::PlatformData(const System& system) : time(0.0),
periodicBoxSize = new RealVec();
periodicBoxVectors = new RealVec[3];
constraints = new ReferenceConstraints(system);
energyParameterDerivatives = new map<string, double>();
}
ReferencePlatform::PlatformData::~PlatformData() {
......@@ -108,4 +110,5 @@ ReferencePlatform::PlatformData::~PlatformData() {
delete (RealVec*) periodicBoxSize;
delete[] (RealVec*) periodicBoxVectors;
delete (ReferenceConstraints*) constraints;
delete (map<string, double>*) energyParameterDerivatives;
}
/* Portions copyright (c) 2006 Stanford University and Simbios.
/* Portions copyright (c) 2006-2016 Stanford University and Simbios.
* Contributors: Pande Group
*
* Permission is hereby granted, free of charge, to any person obtaining
......@@ -129,7 +129,7 @@ void ReferenceAngleBondIxn::calculateBondIxn(int* atomIndices,
vector<RealVec>& atomCoordinates,
RealOpenMM* parameters,
vector<RealVec>& forces,
RealOpenMM* totalEnergy) const {
RealOpenMM* totalEnergy, double* energyParamDerivs) {
// constants -- reduce Visual Studio warnings regarding conversions between float & double
......
/* Portions copyright (c) 2006 Stanford University and Simbios.
/* Portions copyright (c) 2006-2016 Stanford University and Simbios.
* Contributors: Pande Group
*
* Permission is hereby granted, free of charge, to any person obtaining
......@@ -97,7 +97,7 @@ void ReferenceBondForce::calculateForce(int numberOfBonds, int** atomIndices,
// calculate bond ixn
referenceBondIxn.calculateBondIxn(atomIndices[ii], atomCoordinates, parameters[ii],
forces, totalEnergy);
forces, totalEnergy, NULL);
}
}
/* Portions copyright (c) 2006-2009 Stanford University and Simbios.
/* Portions copyright (c) 2006-2016 Stanford University and Simbios.
* Contributors: Pande Group
*
* Permission is hereby granted, free of charge, to any person obtaining
......@@ -78,7 +78,7 @@ ReferenceBondIxn::~ReferenceBondIxn() {
void ReferenceBondIxn::calculateBondIxn(int* atomIndices, vector<RealVec>& atomCoordinates,
RealOpenMM* parameters, vector<RealVec>& forces,
RealOpenMM* totalEnergy) const {
RealOpenMM* totalEnergy, double* energyParamDerivs) {
// ---------------------------------------------------------------------------------------
// static const std::string methodName = "\nReferenceBondIxn::calculateBondIxn";
......
......@@ -207,5 +207,5 @@ void ReferenceCMAPTorsionIxn::calculateOneIxn(int index, vector<RealVec>& atomCo
--------------------------------------------------------------------------------------- */
void ReferenceCMAPTorsionIxn::calculateBondIxn(int* atomIndices, vector<RealVec>& atomCoordinates,
RealOpenMM* parameters, vector<RealVec>& forces, RealOpenMM* totalEnergy) const {
RealOpenMM* parameters, vector<RealVec>& forces, RealOpenMM* totalEnergy, double* energyParamDerivs) {
}
......@@ -86,7 +86,7 @@ void ReferenceCustomAngleIxn::calculateBondIxn(int* atomIndices,
vector<RealVec>& atomCoordinates,
RealOpenMM* parameters,
vector<RealVec>& forces,
RealOpenMM* totalEnergy) const {
RealOpenMM* totalEnergy, double* energyParamDerivs) {
static const std::string methodName = "\nReferenceCustomAngleIxn::calculateAngleIxn";
......
......@@ -39,19 +39,19 @@ using namespace OpenMM;
--------------------------------------------------------------------------------------- */
ReferenceCustomBondIxn::ReferenceCustomBondIxn(const Lepton::CompiledExpression& energyExpression,
const Lepton::CompiledExpression& forceExpression, const vector<string>& parameterNames, map<string, double> globalParameters) :
energyExpression(energyExpression), forceExpression(forceExpression), usePeriodic(false) {
energyR = ReferenceForce::getVariablePointer(this->energyExpression, "r");
forceR = ReferenceForce::getVariablePointer(this->forceExpression, "r");
const Lepton::CompiledExpression& forceExpression, const vector<string>& parameterNames, map<string, double> globalParameters,
const vector<Lepton::CompiledExpression> energyParamDerivExpressions) :
energyExpression(energyExpression), forceExpression(forceExpression), usePeriodic(false), energyParamDerivExpressions(energyParamDerivExpressions) {
expressionSet.registerExpression(this->energyExpression);
expressionSet.registerExpression(this->forceExpression);
for (int i = 0; i < this->energyParamDerivExpressions.size(); i++)
expressionSet.registerExpression(this->energyParamDerivExpressions[i]);
rIndex = expressionSet.getVariableIndex("r");
numParameters = parameterNames.size();
for (int i = 0; i < (int) numParameters; i++) {
energyParams.push_back(ReferenceForce::getVariablePointer(this->energyExpression, parameterNames[i]));
forceParams.push_back(ReferenceForce::getVariablePointer(this->forceExpression, parameterNames[i]));
}
for (map<string, double>::const_iterator iter = globalParameters.begin(); iter != globalParameters.end(); ++iter) {
ReferenceForce::setVariable(ReferenceForce::getVariablePointer(this->energyExpression, iter->first), iter->second);
ReferenceForce::setVariable(ReferenceForce::getVariablePointer(this->forceExpression, iter->first), iter->second);
}
for (int i = 0; i < (int) numParameters; i++)
bondParamIndex.push_back(expressionSet.getVariableIndex(parameterNames[i]));
for (map<string, double>::const_iterator iter = globalParameters.begin(); iter != globalParameters.end(); ++iter)
expressionSet.setVariable(expressionSet.getVariableIndex(iter->first), iter->second);
}
/**---------------------------------------------------------------------------------------
......@@ -86,21 +86,10 @@ void ReferenceCustomBondIxn::calculateBondIxn(int* atomIndices,
vector<RealVec>& atomCoordinates,
RealOpenMM* parameters,
vector<RealVec>& forces,
RealOpenMM* totalEnergy) const {
static const std::string methodName = "\nReferenceCustomBondIxn::calculateBondIxn";
static const int twoI = 2;
static const RealOpenMM zero = 0.0;
static const RealOpenMM two = 2.0;
static const RealOpenMM half = 0.5;
RealOpenMM* totalEnergy, double* energyParamDerivs) {
RealOpenMM deltaR[ReferenceForce::LastDeltaRIndex];
for (int i = 0; i < numParameters; i++) {
ReferenceForce::setVariable(energyParams[i], parameters[i]);
ReferenceForce::setVariable(forceParams[i], parameters[i]);
}
for (int i = 0; i < numParameters; i++)
expressionSet.setVariable(bondParamIndex[i], parameters[i]);
// ---------------------------------------------------------------------------------------
......@@ -113,10 +102,9 @@ void ReferenceCustomBondIxn::calculateBondIxn(int* atomIndices,
else
ReferenceForce::getDeltaR(atomCoordinates[atomAIndex], atomCoordinates[atomBIndex], deltaR);
ReferenceForce::setVariable(energyR, deltaR[ReferenceForce::RIndex]);
ReferenceForce::setVariable(forceR, deltaR[ReferenceForce::RIndex]);
expressionSet.setVariable(rIndex, deltaR[ReferenceForce::RIndex]);
RealOpenMM dEdR = (RealOpenMM) forceExpression.evaluate();
dEdR = deltaR[ReferenceForce::RIndex] > zero ? (dEdR/deltaR[ReferenceForce::RIndex]) : zero;
dEdR = deltaR[ReferenceForce::RIndex] > 0 ? (dEdR/deltaR[ReferenceForce::RIndex]) : 0;
forces[atomAIndex][0] += dEdR*deltaR[ReferenceForce::XIndex];
forces[atomAIndex][1] += dEdR*deltaR[ReferenceForce::YIndex];
......@@ -126,6 +114,8 @@ void ReferenceCustomBondIxn::calculateBondIxn(int* atomIndices,
forces[atomBIndex][1] -= dEdR*deltaR[ReferenceForce::YIndex];
forces[atomBIndex][2] -= dEdR*deltaR[ReferenceForce::ZIndex];
for (int i = 0; i < energyParamDerivExpressions.size(); i++)
energyParamDerivs[i] += energyParamDerivExpressions[i].evaluate();
if (totalEnergy != NULL)
*totalEnergy += (RealOpenMM) energyExpression.evaluate();
}
......@@ -86,7 +86,7 @@ void ReferenceCustomTorsionIxn::calculateBondIxn(int* atomIndices,
vector<RealVec>& atomCoordinates,
RealOpenMM* parameters,
vector<RealVec>& forces,
RealOpenMM* totalEnergy) const {
RealOpenMM* totalEnergy, double* energyParamDerivs) {
static const std::string methodName = "\nReferenceCustomTorsionIxn::calculateTorsionIxn";
......
......@@ -74,7 +74,7 @@ void ReferenceHarmonicBondIxn::calculateBondIxn(int* atomIndices,
vector<RealVec>& atomCoordinates,
RealOpenMM* parameters,
vector<RealVec>& forces,
RealOpenMM* totalEnergy) const {
RealOpenMM* totalEnergy, double* energyParamDerivs) {
static const std::string methodName = "\nReferenceHarmonicBondIxn::calculateBondIxn";
......
/* Portions copyright (c) 2006 Stanford University and Simbios.
/* Portions copyright (c) 2006-2016 Stanford University and Simbios.
* Contributors: Pande Group
*
* Permission is hereby granted, free of charge, to any person obtaining
......@@ -81,7 +81,7 @@ ReferenceLJCoulomb14::~ReferenceLJCoulomb14() {
void ReferenceLJCoulomb14::calculateBondIxn(int* atomIndices, vector<RealVec>& atomCoordinates,
RealOpenMM* parameters, vector<RealVec>& forces,
RealOpenMM* totalEnergy) const {
RealOpenMM* totalEnergy, double* energyParamDerivs) {
static const std::string methodName = "\nReferenceLJCoulomb14::calculateBondIxn";
......
......@@ -75,7 +75,7 @@ void ReferenceProperDihedralBond::calculateBondIxn(int* atomIndices,
vector<RealVec>& atomCoordinates,
RealOpenMM* parameters,
vector<RealVec>& forces,
RealOpenMM* totalEnergy) const {
RealOpenMM* totalEnergy, double* energyParamDerivs) {
static const std::string methodName = "\nReferenceProperDihedralBond::calculateBondIxn";
......
......@@ -73,7 +73,7 @@ void ReferenceRbDihedralBond::calculateBondIxn(int* atomIndices,
vector<RealVec>& atomCoordinates,
RealOpenMM* parameters,
vector<RealVec>& forces,
RealOpenMM* totalEnergy) const {
RealOpenMM* totalEnergy, double* energyParamDerivs) {
static const std::string methodName = "\nReferenceRbDihedralBond::calculateBondIxn";
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -178,6 +178,40 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(0.5*0.8*0.9*0.9, state.getPotentialEnergy(), TOL);
}
void testEnergyParameterDerivatives() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomBondForce* bonds = new CustomBondForce("k*(r-r0)^2");
bonds->addGlobalParameter("r0", 0.0);
bonds->addGlobalParameter("k", 0.0);
bonds->addEnergyParameterDerivative("r0");
bonds->addEnergyParameterDerivative("k");
vector<double> parameters;
bonds->addBond(0, 1, parameters);
bonds->addBond(1, 2, parameters);
system.addForce(bonds);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 2, 0);
positions[1] = Vec3(0, 0, 0);
positions[2] = Vec3(1, 0, 0);
context.setPositions(positions);
for (int i = 0; i < 10; i++) {
double r0 = 0.1*i;
double k = 10-i;
context.setParameter("r0", r0);
context.setParameter("k", k);
State state = context.getState(State::ParameterDerivatives);
map<string, double> derivs = state.getEnergyParameterDerivatives();
double dEdr0 = -2*k*((2-r0)+(1-r0));
double dEdk = (2-r0)*(2-r0) + (1-r0)*(1-r0);
ASSERT_EQUAL_TOL(dEdr0, derivs["r0"], 1e-5);
ASSERT_EQUAL_TOL(dEdk, derivs["k"], 1e-5);
}
}
void runPlatformTests();
int main(int argc, char* argv[]) {
......@@ -187,6 +221,7 @@ int main(int argc, char* argv[]) {
testManyParameters();
testIllegalVariable();
testPeriodic();
testEnergyParameterDerivatives();
runPlatformTests();
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment