Commit fc79cc94 authored by peastman's avatar peastman
Browse files

Optimized CustomCVForce.updateParametersInContext()

parent 21a09873
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Portions copyright (c) 2009-2018 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -65,6 +65,14 @@ public:
* Get an Operation in this program.
*/
const Operation& getOperation(int index) const;
/**
* Change an Operation in this program.
*
* The Operation must have been allocated on the heap with the "new" operator.
* The ExpressionProgram assumes ownership of it and will delete it when it
* is no longer needed.
*/
void setOperation(int index, Operation* operation);
/**
* Get the size of the stack needed to execute this program. This is the largest number of elements present
* on the stack at any point during evaluation.
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Portions copyright (c) 2009-2018 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -226,6 +226,11 @@ class LEPTON_EXPORT Operation::Custom : public Operation {
public:
Custom(const std::string& name, CustomFunction* function) : name(name), function(function), isDerivative(false), derivOrder(function->getNumArguments(), 0) {
}
Custom(const std::string& name, CustomFunction* function, const std::vector<int>& derivOrder) : name(name), function(function), isDerivative(false), derivOrder(derivOrder) {
for (int order : derivOrder)
if (order != 0)
isDerivative = true;
}
Custom(const Custom& base, int derivIndex) : name(base.name), function(base.function->clone()), isDerivative(true), derivOrder(base.derivOrder) {
derivOrder[derivIndex]++;
}
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2013 Stanford University and the Authors. *
* Portions copyright (c) 2009-2018 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -84,6 +84,11 @@ const Operation& ExpressionProgram::getOperation(int index) const {
return *operations[index];
}
void ExpressionProgram::setOperation(int index, Operation* operation) {
delete operations[index];
operations[index] = operation;
}
int ExpressionProgram::getStackSize() const {
return stackSize;
}
......
......@@ -1268,10 +1268,8 @@ public:
void copyParametersToContext(ContextImpl& context, const CustomCVForce& force);
private:
class ReorderListener;
void rebuildExpressions(const OpenMM::CustomCVForce& force);
CudaContext& cu;
bool hasInitializedListeners;
std::string energyExpressionText;
Lepton::ExpressionProgram energyExpression;
std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
......
......@@ -93,6 +93,15 @@ static pair<ExpressionTreeNode, string> makeVariable(const string& name, const s
return make_pair(ExpressionTreeNode(new Operation::Variable(name)), value);
}
static void replaceFunctionsInExpression(map<string, CustomFunction*>& functions, ExpressionProgram& expression) {
for (int i = 0; i < expression.getNumOperations(); i++) {
if (expression.getOperation(i).getId() == Operation::CUSTOM) {
const Operation::Custom& op = dynamic_cast<const Operation::Custom&>(expression.getOperation(i));
expression.setOperation(i, new Operation::Custom(op.getName(), functions[op.getName()]->clone(), op.getDerivOrder()));
}
}
}
void CudaCalcForcesAndEnergyKernel::initialize(const System& system) {
}
......@@ -6588,7 +6597,6 @@ void CudaCalcCustomCVForceKernel::initialize(const System& system, const CustomC
int numCVs = force.getNumCollectiveVariables();
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
energyExpressionText = force.getEnergyFunction();
for (int i = 0; i < numCVs; i++)
variableNames.push_back(force.getCollectiveVariableName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
......@@ -6596,7 +6604,28 @@ void CudaCalcCustomCVForceKernel::initialize(const System& system, const CustomC
paramDerivNames.push_back(name);
cu.addEnergyParameterDerivative(name);
}
rebuildExpressions(force);
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions);
energyExpression = energyExpr.createProgram();
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
// Copy parameter derivatives from the inner context.
......@@ -6713,26 +6742,19 @@ void CudaCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& i
}
void CudaCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
rebuildExpressions(force);
}
void CudaCalcCustomCVForceKernel::rebuildExpressions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
map<string, CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
// Replace tabulated functions in the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(energyExpressionText, functions);
energyExpression = energyExpr.createProgram();
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
replaceFunctionsInExpression(functions, energyExpression);
for (auto& expression : variableDerivExpressions)
replaceFunctionsInExpression(functions, expression);
for (auto& expression : paramDerivExpressions)
replaceFunctionsInExpression(functions, expression);
// Delete the custom functions.
......
......@@ -1244,10 +1244,8 @@ public:
void copyParametersToContext(ContextImpl& context, const CustomCVForce& force);
private:
class ReorderListener;
void rebuildExpressions(const OpenMM::CustomCVForce& force);
OpenCLContext& cl;
bool hasInitializedKernels;
std::string energyExpressionText;
Lepton::ExpressionProgram energyExpression;
std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
......
......@@ -117,6 +117,15 @@ static pair<ExpressionTreeNode, string> makeVariable(const string& name, const s
return make_pair(ExpressionTreeNode(new Operation::Variable(name)), value);
}
static void replaceFunctionsInExpression(map<string, CustomFunction*>& functions, ExpressionProgram& expression) {
for (int i = 0; i < expression.getNumOperations(); i++) {
if (expression.getOperation(i).getId() == Operation::CUSTOM) {
const Operation::Custom& op = dynamic_cast<const Operation::Custom&>(expression.getOperation(i));
expression.setOperation(i, new Operation::Custom(op.getName(), functions[op.getName()]->clone(), op.getDerivOrder()));
}
}
}
void OpenCLCalcForcesAndEnergyKernel::initialize(const System& system) {
}
......@@ -6865,7 +6874,6 @@ void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const Custo
cl.addForce(new OpenCLForceInfo(1));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
energyExpressionText = force.getEnergyFunction();
for (int i = 0; i < numCVs; i++)
variableNames.push_back(force.getCollectiveVariableName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
......@@ -6873,7 +6881,28 @@ void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const Custo
paramDerivNames.push_back(name);
cl.addEnergyParameterDerivative(name);
}
rebuildExpressions(force);
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions);
energyExpression = energyExpr.createProgram();
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
// Copy parameter derivatives from the inner context.
......@@ -7004,26 +7033,19 @@ void OpenCLCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl&
}
void OpenCLCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
rebuildExpressions(force);
}
void OpenCLCalcCustomCVForceKernel::rebuildExpressions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
map<string, CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
// Replace tabulated functions in the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(energyExpressionText, functions);
energyExpression = energyExpr.createProgram();
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
replaceFunctionsInExpression(functions, energyExpression);
for (auto& expression : variableDerivExpressions)
replaceFunctionsInExpression(functions, expression);
for (auto& expression : paramDerivExpressions)
replaceFunctionsInExpression(functions, expression);
// Delete the custom functions.
......
......@@ -36,7 +36,6 @@ namespace OpenMM {
class ReferenceCustomCVForce {
private:
std::string energyExpressionText;
Lepton::ExpressionProgram energyExpression;
std::vector<std::string> variableNames, paramDerivNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
......@@ -54,11 +53,10 @@ public:
~ReferenceCustomCVForce();
/**
* Create the ExpressionPrograms. This is called automatically when the object is
* created. It can be called again to rebuild them if the user calls
* Update any tabulated functions used by the force. This is called when the user calls
* updateParametersInContext().
*/
void rebuildExpressions(const OpenMM::CustomCVForce& force);
void updateTabulatedFunctions(const OpenMM::CustomCVForce& force);
/**
* Calculate the interaction.
......
......@@ -1994,7 +1994,7 @@ void ReferenceCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextIm
}
void ReferenceCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
ixn->rebuildExpressions(force);
ixn->updateTabulatedFunctions(force);
}
void ReferenceCalcRMSDForceKernel::initialize(const System& system, const RMSDForce& force) {
......
......@@ -28,29 +28,27 @@
#include "lepton/CustomFunction.h"
#include "lepton/ParsedExpression.h"
#include "lepton/Parser.h"
#include "lepton/Operation.h"
using namespace OpenMM;
using namespace Lepton;
using namespace std;
ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
energyExpressionText = force.getEnergyFunction();
for (int i = 0; i < force.getNumCollectiveVariables(); i++)
variableNames.push_back(force.getCollectiveVariableName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
paramDerivNames.push_back(force.getEnergyParameterDerivativeName(i));
rebuildExpressions(force);
}
void ReferenceCustomCVForce::rebuildExpressions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
map<string, CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(energyExpressionText, functions);
ParsedExpression energyExpr = Parser::parse(force.getEnergyFunction(), functions);
energyExpression = energyExpr.createProgram();
variableDerivExpressions.clear();
for (auto& name : variableNames)
......@@ -65,6 +63,36 @@ void ReferenceCustomCVForce::rebuildExpressions(const OpenMM::CustomCVForce& for
delete function.second;
}
static void replaceFunctionsInExpression(map<string, CustomFunction*>& functions, ExpressionProgram& expression) {
for (int i = 0; i < expression.getNumOperations(); i++) {
if (expression.getOperation(i).getId() == Operation::CUSTOM) {
const Operation::Custom& op = dynamic_cast<const Operation::Custom&>(expression.getOperation(i));
expression.setOperation(i, new Operation::Custom(op.getName(), functions[op.getName()]->clone(), op.getDerivOrder()));
}
}
}
void ReferenceCustomCVForce::updateTabulatedFunctions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Replace tabulated functions in the expressions.
replaceFunctionsInExpression(functions, energyExpression);
for (auto& expression : variableDerivExpressions)
replaceFunctionsInExpression(functions, expression);
for (auto& expression : paramDerivExpressions)
replaceFunctionsInExpression(functions, expression);
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
}
ReferenceCustomCVForce::~ReferenceCustomCVForce() {
}
......
......@@ -182,8 +182,8 @@ void testTabulatedFunction() {
// Now update the tabulated function, call updateParametersInContext(),
// and see if it's still correct.
for (int i = 0; i < table.size(); i++)
table[i] *= 2;
for (int j = 0; j < table.size(); j++)
table[j] *= 2;
dynamic_cast<Continuous2DFunction&>(cv->getTabulatedFunction(0)).setFunctionParameters(xsize, ysize, table, xmin, xmax, ymin, ymax);
cv->updateParametersInContext(context);
scale *= 2.0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment