Unverified Commit 5387707d authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2122 from peastman/updatetable

Added CustomCVForce.updateParametersInContext()
parents 4885a268 9e0f5f3d
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2018 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -980,6 +980,13 @@ public:
* @param innerContext the context created by the CustomCVForce for computing collective variables
*/
virtual void copyState(ContextImpl& context, ContextImpl& innerContext) = 0;
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the CustomCVForce to copy the parameters from
*/
virtual void copyParametersToContext(ContextImpl& context, const CustomCVForce& force) = 0;
};
/**
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2017 Stanford University and the Authors. *
* Portions copyright (c) 2008-2018 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -255,6 +255,17 @@ public:
* @return the inner Context used to evaluate the collective variables
*/
Context& getInnerContext(Context& context);
/**
* Update the tabulated function parameters in a Context to match those stored in this Force object. This method
* provides an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call getTabulatedFunction(index).setFunctionParameters() to modify this object's parameters, then call
* updateParametersInContext() to copy them over to the Context.
*
* This method is very limited. The only information it updates is the parameters of tabulated functions.
* All other aspects of the Force (the energy expression, the set of collective variables, etc.) are unaffected and can
* only be changed by reinitializing the Context.
*/
void updateParametersInContext(Context& context);
/**
* Returns whether or not this force makes use of periodic boundary
* conditions.
......
......@@ -64,6 +64,7 @@ public:
std::vector<std::string> getKernelNames();
void getCollectiveVariableValues(ContextImpl& context, std::vector<double>& values);
Context& getInnerContext();
void updateParametersInContext(ContextImpl& context);
private:
const CustomCVForce& owner;
Kernel kernel;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2017 Stanford University and the Authors. *
* Portions copyright (c) 2008-2018 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -152,6 +152,10 @@ Context& CustomCVForce::getInnerContext(Context& context) {
return dynamic_cast<CustomCVForceImpl&>(getImplInContext(context)).getInnerContext();
}
void CustomCVForce::updateParametersInContext(Context& context) {
dynamic_cast<CustomCVForceImpl&>(getImplInContext(context)).updateParametersInContext(getContextImpl(context));
}
bool CustomCVForce::usesPeriodicBoundaryConditions() const {
for (auto& variable : variables)
if (variable.variable->usesPeriodicBoundaryConditions())
......
......@@ -111,3 +111,8 @@ void CustomCVForceImpl::getCollectiveVariableValues(ContextImpl& context, vector
Context& CustomCVForceImpl::getInnerContext() {
return *innerContext;
}
void CustomCVForceImpl::updateParametersInContext(ContextImpl& context) {
kernel.getAs<CalcCustomCVForceKernel>().copyParametersToContext(context, owner);
context.systemChanged();
}
......@@ -1259,10 +1259,19 @@ public:
* @param innerContext the context created by the CustomCVForce for computing collective variables
*/
void copyState(ContextImpl& context, ContextImpl& innerContext);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the CustomCVForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const CustomCVForce& force);
private:
class ReorderListener;
void rebuildExpressions(const OpenMM::CustomCVForce& force);
CudaContext& cu;
bool hasInitializedListeners;
std::string energyExpressionText;
Lepton::ExpressionProgram energyExpression;
std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
......
......@@ -6588,33 +6588,15 @@ void CudaCalcCustomCVForceKernel::initialize(const System& system, const CustomC
int numCVs = force.getNumCollectiveVariables();
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions);
energyExpression = energyExpr.createProgram();
for (int i = 0; i < numCVs; i++) {
string name = force.getCollectiveVariableName(i);
variableNames.push_back(name);
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
}
energyExpressionText = force.getEnergyFunction();
for (int i = 0; i < numCVs; i++)
variableNames.push_back(force.getCollectiveVariableName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string name = force.getEnergyParameterDerivativeName(i);
paramDerivNames.push_back(name);
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
cu.addEnergyParameterDerivative(name);
}
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
rebuildExpressions(force);
// Copy parameter derivatives from the inner context.
......@@ -6730,6 +6712,34 @@ void CudaCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& i
innerContext.setParameter(param.first, context.getParameter(param.first));
}
void CudaCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
rebuildExpressions(force);
}
void CudaCalcCustomCVForceKernel::rebuildExpressions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(energyExpressionText, functions);
energyExpression = energyExpr.createProgram();
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
}
class CudaCalcRMSDForceKernel::ForceInfo : public CudaForceInfo {
public:
ForceInfo(const RMSDForce& force) : force(force) {
......
......@@ -1235,10 +1235,19 @@ public:
* @param innerContext the context created by the CustomCVForce for computing collective variables
*/
void copyState(ContextImpl& context, ContextImpl& innerContext);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the CustomCVForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const CustomCVForce& force);
private:
class ReorderListener;
void rebuildExpressions(const OpenMM::CustomCVForce& force);
OpenCLContext& cl;
bool hasInitializedKernels;
std::string energyExpressionText;
Lepton::ExpressionProgram energyExpression;
std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
......
......@@ -6865,34 +6865,16 @@ void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const Custo
cl.addForce(new OpenCLForceInfo(1));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions);
energyExpression = energyExpr.createProgram();
for (int i = 0; i < numCVs; i++) {
string name = force.getCollectiveVariableName(i);
variableNames.push_back(name);
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
}
energyExpressionText = force.getEnergyFunction();
for (int i = 0; i < numCVs; i++)
variableNames.push_back(force.getCollectiveVariableName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string name = force.getEnergyParameterDerivativeName(i);
paramDerivNames.push_back(name);
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
cl.addEnergyParameterDerivative(name);
}
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
rebuildExpressions(force);
// Copy parameter derivatives from the inner context.
OpenCLContext& cl2 = *reinterpret_cast<OpenCLPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
......@@ -7021,6 +7003,34 @@ void OpenCLCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl&
innerContext.setParameter(param.first, context.getParameter(param.first));
}
void OpenCLCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
rebuildExpressions(force);
}
void OpenCLCalcCustomCVForceKernel::rebuildExpressions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(energyExpressionText, functions);
energyExpression = energyExpr.createProgram();
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
}
class OpenCLCalcRMSDForceKernel::ForceInfo : public OpenCLForceInfo {
public:
ForceInfo(const RMSDForce& force) : OpenCLForceInfo(0), force(force) {
......
......@@ -36,6 +36,7 @@ namespace OpenMM {
class ReferenceCustomCVForce {
private:
std::string energyExpressionText;
Lepton::ExpressionProgram energyExpression;
std::vector<std::string> variableNames, paramDerivNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
......@@ -52,6 +53,13 @@ public:
*/
~ReferenceCustomCVForce();
/**
* Create the ExpressionPrograms. This is called automatically when the object is
* created. It can be called again to rebuild them if the user calls
* updateParametersInContext().
*/
void rebuildExpressions(const OpenMM::CustomCVForce& force);
/**
* Calculate the interaction.
*
......
......@@ -1037,6 +1037,13 @@ public:
* @param innerContext the context created by the CustomCVForce for computing collective variables
*/
void copyState(ContextImpl& context, ContextImpl& innerContext);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the CustomCVForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const CustomCVForce& force);
private:
ReferenceCustomCVForce* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames;
......
......@@ -1993,6 +1993,10 @@ void ReferenceCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextIm
innerContext.setParameter(param.first, context.getParameter(param.first));
}
void ReferenceCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
ixn->rebuildExpressions(force);
}
void ReferenceCalcRMSDForceKernel::initialize(const System& system, const RMSDForce& force) {
particles = force.getParticles();
if (particles.size() == 0)
......
......@@ -33,6 +33,15 @@ using namespace OpenMM;
using namespace std;
ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
energyExpressionText = force.getEnergyFunction();
for (int i = 0; i < force.getNumCollectiveVariables(); i++)
variableNames.push_back(force.getCollectiveVariableName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
paramDerivNames.push_back(force.getEnergyParameterDerivativeName(i));
rebuildExpressions(force);
}
void ReferenceCustomCVForce::rebuildExpressions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
......@@ -41,18 +50,14 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions);
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(energyExpressionText, functions);
energyExpression = energyExpr.createProgram();
for (int i = 0; i < force.getNumCollectiveVariables(); i++) {
string name = force.getCollectiveVariableName(i);
variableNames.push_back(name);
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
}
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
string name = force.getEnergyParameterDerivativeName(i);
paramDerivNames.push_back(name);
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
}
// Delete the custom functions.
......
......@@ -159,22 +159,34 @@ void testTabulatedFunction() {
system.addForce(cv);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[0] = Vec3(x, y, 1.5);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = sin(0.25*x)*cos(0.33*y)+1;
force[0] = -0.25*cos(0.25*x)*cos(0.33*y);
force[1] = 0.3*sin(0.25*x)*sin(0.33*y);
double scale = 1.0;
for (int i = 0; i < 2; i++) {
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[0] = Vec3(x, y, 1.5);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = scale*sin(0.25*x)*cos(0.33*y)+1;
force[0] = -scale*0.25*cos(0.25*x)*cos(0.33*y);
force[1] = scale*0.3*sin(0.25*x)*sin(0.33*y);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
// Now update the tabulated function, call updateParametersInContext(),
// and see if it's still correct.
for (int i = 0; i < table.size(); i++)
table[i] *= 2;
dynamic_cast<Continuous2DFunction&>(cv->getTabulatedFunction(0)).setFunctionParameters(xsize, ysize, table, xmin, xmax, ymin, ymax);
cv->updateParametersInContext(context);
scale *= 2.0;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment