Commit 9e0f5f3d authored by peastman's avatar peastman
Browse files

Added CustomCVForce.updateParametersInContext()

parent 4885a268
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2018 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -980,6 +980,13 @@ public: ...@@ -980,6 +980,13 @@ public:
* @param innerContext the context created by the CustomCVForce for computing collective variables * @param innerContext the context created by the CustomCVForce for computing collective variables
*/ */
virtual void copyState(ContextImpl& context, ContextImpl& innerContext) = 0; virtual void copyState(ContextImpl& context, ContextImpl& innerContext) = 0;
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the CustomCVForce to copy the parameters from
*/
virtual void copyParametersToContext(ContextImpl& context, const CustomCVForce& force) = 0;
}; };
/** /**
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2017 Stanford University and the Authors. * * Portions copyright (c) 2008-2018 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -255,6 +255,17 @@ public: ...@@ -255,6 +255,17 @@ public:
* @return the inner Context used to evaluate the collective variables * @return the inner Context used to evaluate the collective variables
*/ */
Context& getInnerContext(Context& context); Context& getInnerContext(Context& context);
/**
* Update the tabulated function parameters in a Context to match those stored in this Force object. This method
* provides an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call getTabulatedFunction(index).setFunctionParameters() to modify this object's parameters, then call
* updateParametersInContext() to copy them over to the Context.
*
* This method is very limited. The only information it updates is the parameters of tabulated functions.
* All other aspects of the Force (the energy expression, the set of collective variables, etc.) are unaffected and can
* only be changed by reinitializing the Context.
*/
void updateParametersInContext(Context& context);
/** /**
* Returns whether or not this force makes use of periodic boundary * Returns whether or not this force makes use of periodic boundary
* conditions. * conditions.
......
...@@ -64,6 +64,7 @@ public: ...@@ -64,6 +64,7 @@ public:
std::vector<std::string> getKernelNames(); std::vector<std::string> getKernelNames();
void getCollectiveVariableValues(ContextImpl& context, std::vector<double>& values); void getCollectiveVariableValues(ContextImpl& context, std::vector<double>& values);
Context& getInnerContext(); Context& getInnerContext();
void updateParametersInContext(ContextImpl& context);
private: private:
const CustomCVForce& owner; const CustomCVForce& owner;
Kernel kernel; Kernel kernel;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2017 Stanford University and the Authors. * * Portions copyright (c) 2008-2018 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -152,6 +152,10 @@ Context& CustomCVForce::getInnerContext(Context& context) { ...@@ -152,6 +152,10 @@ Context& CustomCVForce::getInnerContext(Context& context) {
return dynamic_cast<CustomCVForceImpl&>(getImplInContext(context)).getInnerContext(); return dynamic_cast<CustomCVForceImpl&>(getImplInContext(context)).getInnerContext();
} }
void CustomCVForce::updateParametersInContext(Context& context) {
dynamic_cast<CustomCVForceImpl&>(getImplInContext(context)).updateParametersInContext(getContextImpl(context));
}
bool CustomCVForce::usesPeriodicBoundaryConditions() const { bool CustomCVForce::usesPeriodicBoundaryConditions() const {
for (auto& variable : variables) for (auto& variable : variables)
if (variable.variable->usesPeriodicBoundaryConditions()) if (variable.variable->usesPeriodicBoundaryConditions())
......
...@@ -111,3 +111,8 @@ void CustomCVForceImpl::getCollectiveVariableValues(ContextImpl& context, vector ...@@ -111,3 +111,8 @@ void CustomCVForceImpl::getCollectiveVariableValues(ContextImpl& context, vector
Context& CustomCVForceImpl::getInnerContext() { Context& CustomCVForceImpl::getInnerContext() {
return *innerContext; return *innerContext;
} }
void CustomCVForceImpl::updateParametersInContext(ContextImpl& context) {
kernel.getAs<CalcCustomCVForceKernel>().copyParametersToContext(context, owner);
context.systemChanged();
}
...@@ -1259,10 +1259,19 @@ public: ...@@ -1259,10 +1259,19 @@ public:
* @param innerContext the context created by the CustomCVForce for computing collective variables * @param innerContext the context created by the CustomCVForce for computing collective variables
*/ */
void copyState(ContextImpl& context, ContextImpl& innerContext); void copyState(ContextImpl& context, ContextImpl& innerContext);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the CustomCVForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const CustomCVForce& force);
private: private:
class ReorderListener; class ReorderListener;
void rebuildExpressions(const OpenMM::CustomCVForce& force);
CudaContext& cu; CudaContext& cu;
bool hasInitializedListeners; bool hasInitializedListeners;
std::string energyExpressionText;
Lepton::ExpressionProgram energyExpression; Lepton::ExpressionProgram energyExpression;
std::vector<std::string> variableNames, paramDerivNames, globalParameterNames; std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions; std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
......
...@@ -6588,33 +6588,15 @@ void CudaCalcCustomCVForceKernel::initialize(const System& system, const CustomC ...@@ -6588,33 +6588,15 @@ void CudaCalcCustomCVForceKernel::initialize(const System& system, const CustomC
int numCVs = force.getNumCollectiveVariables(); int numCVs = force.getNumCollectiveVariables();
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
energyExpressionText = force.getEnergyFunction();
// Create custom functions for the tabulated functions. for (int i = 0; i < numCVs; i++)
variableNames.push_back(force.getCollectiveVariableName(i));
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions);
energyExpression = energyExpr.createProgram();
for (int i = 0; i < numCVs; i++) {
string name = force.getCollectiveVariableName(i);
variableNames.push_back(name);
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
}
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) { for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string name = force.getEnergyParameterDerivativeName(i); string name = force.getEnergyParameterDerivativeName(i);
paramDerivNames.push_back(name); paramDerivNames.push_back(name);
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
cu.addEnergyParameterDerivative(name); cu.addEnergyParameterDerivative(name);
} }
rebuildExpressions(force);
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
// Copy parameter derivatives from the inner context. // Copy parameter derivatives from the inner context.
...@@ -6730,6 +6712,34 @@ void CudaCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& i ...@@ -6730,6 +6712,34 @@ void CudaCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& i
innerContext.setParameter(param.first, context.getParameter(param.first)); innerContext.setParameter(param.first, context.getParameter(param.first));
} }
void CudaCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
rebuildExpressions(force);
}
void CudaCalcCustomCVForceKernel::rebuildExpressions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(energyExpressionText, functions);
energyExpression = energyExpr.createProgram();
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
}
class CudaCalcRMSDForceKernel::ForceInfo : public CudaForceInfo { class CudaCalcRMSDForceKernel::ForceInfo : public CudaForceInfo {
public: public:
ForceInfo(const RMSDForce& force) : force(force) { ForceInfo(const RMSDForce& force) : force(force) {
......
...@@ -1235,10 +1235,19 @@ public: ...@@ -1235,10 +1235,19 @@ public:
* @param innerContext the context created by the CustomCVForce for computing collective variables * @param innerContext the context created by the CustomCVForce for computing collective variables
*/ */
void copyState(ContextImpl& context, ContextImpl& innerContext); void copyState(ContextImpl& context, ContextImpl& innerContext);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the CustomCVForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const CustomCVForce& force);
private: private:
class ReorderListener; class ReorderListener;
void rebuildExpressions(const OpenMM::CustomCVForce& force);
OpenCLContext& cl; OpenCLContext& cl;
bool hasInitializedKernels; bool hasInitializedKernels;
std::string energyExpressionText;
Lepton::ExpressionProgram energyExpression; Lepton::ExpressionProgram energyExpression;
std::vector<std::string> variableNames, paramDerivNames, globalParameterNames; std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions; std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
......
...@@ -6865,33 +6865,15 @@ void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const Custo ...@@ -6865,33 +6865,15 @@ void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const Custo
cl.addForce(new OpenCLForceInfo(1)); cl.addForce(new OpenCLForceInfo(1));
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
energyExpressionText = force.getEnergyFunction();
// Create custom functions for the tabulated functions. for (int i = 0; i < numCVs; i++)
variableNames.push_back(force.getCollectiveVariableName(i));
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions);
energyExpression = energyExpr.createProgram();
for (int i = 0; i < numCVs; i++) {
string name = force.getCollectiveVariableName(i);
variableNames.push_back(name);
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
}
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) { for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string name = force.getEnergyParameterDerivativeName(i); string name = force.getEnergyParameterDerivativeName(i);
paramDerivNames.push_back(name); paramDerivNames.push_back(name);
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
cl.addEnergyParameterDerivative(name); cl.addEnergyParameterDerivative(name);
} }
rebuildExpressions(force);
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
// Copy parameter derivatives from the inner context. // Copy parameter derivatives from the inner context.
...@@ -7021,6 +7003,34 @@ void OpenCLCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& ...@@ -7021,6 +7003,34 @@ void OpenCLCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl&
innerContext.setParameter(param.first, context.getParameter(param.first)); innerContext.setParameter(param.first, context.getParameter(param.first));
} }
void OpenCLCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
rebuildExpressions(force);
}
void OpenCLCalcCustomCVForceKernel::rebuildExpressions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(energyExpressionText, functions);
energyExpression = energyExpr.createProgram();
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
}
class OpenCLCalcRMSDForceKernel::ForceInfo : public OpenCLForceInfo { class OpenCLCalcRMSDForceKernel::ForceInfo : public OpenCLForceInfo {
public: public:
ForceInfo(const RMSDForce& force) : OpenCLForceInfo(0), force(force) { ForceInfo(const RMSDForce& force) : OpenCLForceInfo(0), force(force) {
......
...@@ -36,6 +36,7 @@ namespace OpenMM { ...@@ -36,6 +36,7 @@ namespace OpenMM {
class ReferenceCustomCVForce { class ReferenceCustomCVForce {
private: private:
std::string energyExpressionText;
Lepton::ExpressionProgram energyExpression; Lepton::ExpressionProgram energyExpression;
std::vector<std::string> variableNames, paramDerivNames; std::vector<std::string> variableNames, paramDerivNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions; std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
...@@ -52,6 +53,13 @@ public: ...@@ -52,6 +53,13 @@ public:
*/ */
~ReferenceCustomCVForce(); ~ReferenceCustomCVForce();
/**
* Create the ExpressionPrograms. This is called automatically when the object is
* created. It can be called again to rebuild them if the user calls
* updateParametersInContext().
*/
void rebuildExpressions(const OpenMM::CustomCVForce& force);
/** /**
* Calculate the interaction. * Calculate the interaction.
* *
......
...@@ -1037,6 +1037,13 @@ public: ...@@ -1037,6 +1037,13 @@ public:
* @param innerContext the context created by the CustomCVForce for computing collective variables * @param innerContext the context created by the CustomCVForce for computing collective variables
*/ */
void copyState(ContextImpl& context, ContextImpl& innerContext); void copyState(ContextImpl& context, ContextImpl& innerContext);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the CustomCVForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const CustomCVForce& force);
private: private:
ReferenceCustomCVForce* ixn; ReferenceCustomCVForce* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames; std::vector<std::string> globalParameterNames, energyParamDerivNames;
......
...@@ -1993,6 +1993,10 @@ void ReferenceCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextIm ...@@ -1993,6 +1993,10 @@ void ReferenceCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextIm
innerContext.setParameter(param.first, context.getParameter(param.first)); innerContext.setParameter(param.first, context.getParameter(param.first));
} }
void ReferenceCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
ixn->rebuildExpressions(force);
}
void ReferenceCalcRMSDForceKernel::initialize(const System& system, const RMSDForce& force) { void ReferenceCalcRMSDForceKernel::initialize(const System& system, const RMSDForce& force) {
particles = force.getParticles(); particles = force.getParticles();
if (particles.size() == 0) if (particles.size() == 0)
......
...@@ -33,6 +33,15 @@ using namespace OpenMM; ...@@ -33,6 +33,15 @@ using namespace OpenMM;
using namespace std; using namespace std;
ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) { ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
energyExpressionText = force.getEnergyFunction();
for (int i = 0; i < force.getNumCollectiveVariables(); i++)
variableNames.push_back(force.getCollectiveVariableName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
paramDerivNames.push_back(force.getEnergyParameterDerivativeName(i));
rebuildExpressions(force);
}
void ReferenceCustomCVForce::rebuildExpressions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
...@@ -41,18 +50,14 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) { ...@@ -41,18 +50,14 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
// Create the expressions. // Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions); Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(energyExpressionText, functions);
energyExpression = energyExpr.createProgram(); energyExpression = energyExpr.createProgram();
for (int i = 0; i < force.getNumCollectiveVariables(); i++) { variableDerivExpressions.clear();
string name = force.getCollectiveVariableName(i); for (auto& name : variableNames)
variableNames.push_back(name);
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram()); variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
} paramDerivExpressions.clear();
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) { for (auto& name : paramDerivNames)
string name = force.getEnergyParameterDerivativeName(i);
paramDerivNames.push_back(name);
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram()); paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
}
// Delete the custom functions. // Delete the custom functions.
......
...@@ -159,6 +159,8 @@ void testTabulatedFunction() { ...@@ -159,6 +159,8 @@ void testTabulatedFunction() {
system.addForce(cv); system.addForce(cv);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(1); vector<Vec3> positions(1);
double scale = 1.0;
for (int i = 0; i < 2; i++) {
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) { for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) { for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[0] = Vec3(x, y, 1.5); positions[0] = Vec3(x, y, 1.5);
...@@ -168,14 +170,24 @@ void testTabulatedFunction() { ...@@ -168,14 +170,24 @@ void testTabulatedFunction() {
double energy = 1; double energy = 1;
Vec3 force(0, 0, 0); Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) { if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = sin(0.25*x)*cos(0.33*y)+1; energy = scale*sin(0.25*x)*cos(0.33*y)+1;
force[0] = -0.25*cos(0.25*x)*cos(0.33*y); force[0] = -scale*0.25*cos(0.25*x)*cos(0.33*y);
force[1] = 0.3*sin(0.25*x)*sin(0.33*y); force[1] = scale*0.3*sin(0.25*x)*sin(0.33*y);
} }
ASSERT_EQUAL_VEC(force, forces[0], 0.1); ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
} }
} }
// Now update the tabulated function, call updateParametersInContext(),
// and see if it's still correct.
for (int i = 0; i < table.size(); i++)
table[i] *= 2;
dynamic_cast<Continuous2DFunction&>(cv->getTabulatedFunction(0)).setFunctionParameters(xsize, ysize, table, xmin, xmax, ymin, ymax);
cv->updateParametersInContext(context);
scale *= 2.0;
}
} }
void testReordering() { void testReordering() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment