Unverified Commit 17b61225 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Use CompiledExpression for CustomCVForce energy expression (#3898)

parent b57a5a63
......@@ -909,6 +909,7 @@ public:
CommonCalcCustomCVForceKernel(std::string name, const Platform& platform, ComputeContext& cc) : CalcCustomCVForceKernel(name, platform),
cc(cc), hasInitializedListeners(false) {
}
~CommonCalcCustomCVForceKernel();
/**
* Initialize the kernel.
*
......@@ -948,13 +949,16 @@ public:
private:
class ForceInfo;
class ReorderListener;
class TabulatedFunctionWrapper;
ComputeContext& cc;
bool hasInitializedListeners;
Lepton::ExpressionProgram energyExpression;
Lepton::CompiledExpression energyExpression;
std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
std::vector<Lepton::ExpressionProgram> paramDerivExpressions;
std::vector<Lepton::CompiledExpression> variableDerivExpressions;
std::vector<Lepton::CompiledExpression> paramDerivExpressions;
std::vector<ComputeArray> cvForces;
std::vector<double> globalValues, cvValues;
std::vector<Lepton::CustomFunction*> tabulatedFunctions;
ComputeArray invAtomOrder;
ComputeArray innerInvAtomOrder;
ComputeKernel copyStateKernel, copyForcesKernel, addForcesKernel;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2022 Stanford University and the Authors. *
* Portions copyright (c) 2008-2023 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -98,15 +98,6 @@ static pair<ExpressionTreeNode, string> makeVariable(const string& name, const s
return make_pair(ExpressionTreeNode(new Operation::Variable(name)), value);
}
static void replaceFunctionsInExpression(map<string, CustomFunction*>& functions, ExpressionProgram& expression) {
for (int i = 0; i < expression.getNumOperations(); i++) {
if (expression.getOperation(i).getId() == Operation::CUSTOM) {
const Operation::Custom& op = dynamic_cast<const Operation::Custom&>(expression.getOperation(i));
expression.setOperation(i, new Operation::Custom(op.getName(), functions[op.getName()]->clone(), op.getDerivOrder()));
}
}
}
void CommonApplyConstraintsKernel::initialize(const System& system) {
}
......@@ -5136,6 +5127,30 @@ private:
ArrayInterface& invAtomOrder;
};
// This class allows us to update tabulated functions without having to recompile expressions
// that use them.
class CommonCalcCustomCVForceKernel::TabulatedFunctionWrapper : public CustomFunction {
public:
TabulatedFunctionWrapper(vector<Lepton::CustomFunction*>& tabulatedFunctions, int index) :
tabulatedFunctions(tabulatedFunctions), index(index) {
}
int getNumArguments() const {
return tabulatedFunctions[index]->getNumArguments();
}
double evaluate(const double* arguments) const {
return tabulatedFunctions[index]->evaluate(arguments);
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return tabulatedFunctions[index]->evaluateDerivative(arguments, derivOrder);
}
CustomFunction* clone() const {
return new TabulatedFunctionWrapper(tabulatedFunctions, index);
}
private:
vector<Lepton::CustomFunction*>& tabulatedFunctions;
int index;
};
void CommonCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext) {
ContextSelector selector(cc);
int numCVs = force.getNumCollectiveVariables();
......@@ -5152,19 +5167,34 @@ void CommonCalcCustomCVForceKernel::initialize(const System& system, const Custo
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
tabulatedFunctions.resize(force.getNumTabulatedFunctions(), NULL);
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
tabulatedFunctions[i] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
functions[force.getTabulatedFunctionName(i)] = new TabulatedFunctionWrapper(tabulatedFunctions, i);
}
// Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions);
energyExpression = energyExpr.createProgram();
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
energyExpression = energyExpr.createCompiledExpression();
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
variableDerivExpressions.push_back(energyExpr.differentiate(name).createCompiledExpression());
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
paramDerivExpressions.push_back(energyExpr.differentiate(name).createCompiledExpression());
globalValues.resize(globalParameterNames.size());
cvValues.resize(numCVs);
map<string, double*> variableLocations;
for (int i = 0; i < globalParameterNames.size(); i++)
variableLocations[globalParameterNames[i]] = &globalValues[i];
for (int i = 0; i < numCVs; i++)
variableLocations[variableNames[i]] = &cvValues[i];
energyExpression.setVariableLocations(variableLocations);
for (CompiledExpression& expr : variableDerivExpressions)
expr.setVariableLocations(variableLocations);
for (CompiledExpression& expr : paramDerivExpressions)
expr.setVariableLocations(variableLocations);
// Delete the custom functions.
......@@ -5229,15 +5259,20 @@ void CommonCalcCustomCVForceKernel::initialize(const System& system, const Custo
cc.addForce(new ForceInfo(*info));
}
CommonCalcCustomCVForceKernel::~CommonCalcCustomCVForceKernel() {
for (int i = 0; i < tabulatedFunctions.size(); i++)
if (tabulatedFunctions[i] != NULL)
delete tabulatedFunctions[i];
}
double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl& innerContext, bool includeForces, bool includeEnergy) {
copyState(context, innerContext);
int numCVs = variableNames.size();
int numAtoms = cc.getNumAtoms();
int paddedNumAtoms = cc.getPaddedNumAtoms();
vector<double> cvValues;
vector<map<string, double> > cvDerivs(numCVs);
for (int i = 0; i < numCVs; i++) {
cvValues.push_back(innerContext.calcForcesAndEnergy(true, true, 1<<i));
cvValues[i] = innerContext.calcForcesAndEnergy(true, true, 1<<i);
ContextSelector selector(cc);
copyForcesKernel->setArg(0, cvForces[i]);
copyForcesKernel->execute(numAtoms);
......@@ -5247,14 +5282,11 @@ double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
// Compute the energy and forces.
ContextSelector selector(cc);
map<string, double> variables;
for (auto& name : globalParameterNames)
variables[name] = context.getParameter(name);
for (int i = 0; i < numCVs; i++)
variables[variableNames[i]] = cvValues[i];
double energy = energyExpression.evaluate(variables);
for (int i = 0; i < globalParameterNames.size(); i++)
globalValues[i] = context.getParameter(globalParameterNames[i]);
double energy = energyExpression.evaluate();
for (int i = 0; i < numCVs; i++) {
double dEdV = variableDerivExpressions[i].evaluate(variables);
double dEdV = variableDerivExpressions[i].evaluate();
addForcesKernel->setArg(2*i+2, cvForces[i]);
if (cc.getUseDoublePrecision())
addForcesKernel->setArg(2*i+3, dEdV);
......@@ -5265,14 +5297,16 @@ double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
// Compute the energy parameter derivatives.
if (paramDerivExpressions.size() > 0) {
map<string, double>& energyParamDerivs = cc.getEnergyParamDerivWorkspace();
for (int i = 0; i < paramDerivExpressions.size(); i++)
energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate(variables);
energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate();
for (int i = 0; i < numCVs; i++) {
double dEdV = variableDerivExpressions[i].evaluate(variables);
double dEdV = variableDerivExpressions[i].evaluate();
for (auto& deriv : cvDerivs[i])
energyParamDerivs[deriv.first] += dEdV*deriv.second;
}
}
return energy;
}
......@@ -5305,22 +5339,13 @@ void CommonCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl&
void CommonCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Replace tabulated functions in the expressions.
replaceFunctionsInExpression(functions, energyExpression);
for (auto& expression : variableDerivExpressions)
replaceFunctionsInExpression(functions, expression);
for (auto& expression : paramDerivExpressions)
replaceFunctionsInExpression(functions, expression);
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
if (tabulatedFunctions[i] != NULL) {
delete tabulatedFunctions[i];
tabulatedFunctions[i] = NULL;
}
tabulatedFunctions[i] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
}
}
void CommonIntegrateVerletStepKernel::initialize(const System& system, const VerletIntegrator& integrator) {
......
/* Portions copyright (c) 2017 Stanford University and Simbios.
/* Portions copyright (c) 2017-2023 Stanford University and Simbios.
* Contributors: Peter Eastman
*
* Permission is hereby granted, free of charge, to any person obtaining
......@@ -27,7 +27,8 @@
#include "openmm/CustomCVForce.h"
#include "openmm/internal/ContextImpl.h"
#include "lepton/ExpressionProgram.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CustomFunction.h"
#include <map>
#include <string>
#include <vector>
......@@ -36,10 +37,13 @@ namespace OpenMM {
class ReferenceCustomCVForce {
private:
Lepton::ExpressionProgram energyExpression;
std::vector<std::string> variableNames, paramDerivNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
std::vector<Lepton::ExpressionProgram> paramDerivExpressions;
class TabulatedFunctionWrapper;
Lepton::CompiledExpression energyExpression;
std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
std::vector<Lepton::CompiledExpression> variableDerivExpressions;
std::vector<Lepton::CompiledExpression> paramDerivExpressions;
std::vector<double> globalValues, cvValues;
std::vector<Lepton::CustomFunction*> tabulatedFunctions;
public:
/**
......@@ -70,7 +74,7 @@ public:
*/
void calculateIxn(ContextImpl& innerContext, std::vector<OpenMM::Vec3>& atomCoordinates,
const std::map<std::string, double>& globalParameters,
std::vector<OpenMM::Vec3>& forces, double* totalEnergy, std::map<std::string, double>& energyParamDerivs) const;
std::vector<OpenMM::Vec3>& forces, double* totalEnergy, std::map<std::string, double>& energyParamDerivs);
};
} // namespace OpenMM
......
/* Portions copyright (c) 2009-2017 Stanford University and Simbios.
/* Portions copyright (c) 2009-2023 Stanford University and Simbios.
* Contributors: Peter Eastman
*
* Permission is hereby granted, free of charge, to any person obtaining
......@@ -34,8 +34,35 @@ using namespace OpenMM;
using namespace Lepton;
using namespace std;
// This class allows us to update tabulated functions without having to recompile expressions
// that use them.
class ReferenceCustomCVForce::TabulatedFunctionWrapper : public CustomFunction {
public:
TabulatedFunctionWrapper(vector<Lepton::CustomFunction*>& tabulatedFunctions, int index) :
tabulatedFunctions(tabulatedFunctions), index(index) {
}
int getNumArguments() const {
return tabulatedFunctions[index]->getNumArguments();
}
double evaluate(const double* arguments) const {
return tabulatedFunctions[index]->evaluate(arguments);
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return tabulatedFunctions[index]->evaluateDerivative(arguments, derivOrder);
}
CustomFunction* clone() const {
return new TabulatedFunctionWrapper(tabulatedFunctions, index);
}
private:
vector<Lepton::CustomFunction*>& tabulatedFunctions;
int index;
};
ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
for (int i = 0; i < force.getNumCollectiveVariables(); i++)
int numCVs = force.getNumCollectiveVariables();
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
for (int i = 0; i < numCVs; i++)
variableNames.push_back(force.getCollectiveVariableName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
paramDerivNames.push_back(force.getEnergyParameterDerivativeName(i));
......@@ -43,19 +70,34 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
tabulatedFunctions.resize(force.getNumTabulatedFunctions(), NULL);
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
tabulatedFunctions[i] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
functions[force.getTabulatedFunctionName(i)] = new TabulatedFunctionWrapper(tabulatedFunctions, i);
}
// Create the expressions.
ParsedExpression energyExpr = Parser::parse(force.getEnergyFunction(), functions);
energyExpression = energyExpr.createProgram();
ParsedExpression energyExpr = Parser::parse(force.getEnergyFunction(), functions).optimize();
energyExpression = energyExpr.createCompiledExpression();
variableDerivExpressions.clear();
for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
variableDerivExpressions.push_back(energyExpr.differentiate(name).createCompiledExpression());
paramDerivExpressions.clear();
for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
paramDerivExpressions.push_back(energyExpr.differentiate(name).createCompiledExpression());
globalValues.resize(variableNames.size());
cvValues.resize(numCVs);
map<string, double*> variableLocations;
for (int i = 0; i < globalParameterNames.size(); i++)
variableLocations[globalParameterNames[i]] = &globalValues[i];
for (int i = 0; i < numCVs; i++)
variableLocations[variableNames[i]] = &cvValues[i];
energyExpression.setVariableLocations(variableLocations);
for (CompiledExpression& expr : variableDerivExpressions)
expr.setVariableLocations(variableLocations);
for (CompiledExpression& expr : paramDerivExpressions)
expr.setVariableLocations(variableLocations);
// Delete the custom functions.
......@@ -63,78 +105,63 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
delete function.second;
}
static void replaceFunctionsInExpression(map<string, CustomFunction*>& functions, ExpressionProgram& expression) {
for (int i = 0; i < expression.getNumOperations(); i++) {
if (expression.getOperation(i).getId() == Operation::CUSTOM) {
const Operation::Custom& op = dynamic_cast<const Operation::Custom&>(expression.getOperation(i));
expression.setOperation(i, new Operation::Custom(op.getName(), functions[op.getName()]->clone(), op.getDerivOrder()));
}
}
}
void ReferenceCustomCVForce::updateTabulatedFunctions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions.
map<string, CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Replace tabulated functions in the expressions.
replaceFunctionsInExpression(functions, energyExpression);
for (auto& expression : variableDerivExpressions)
replaceFunctionsInExpression(functions, expression);
for (auto& expression : paramDerivExpressions)
replaceFunctionsInExpression(functions, expression);
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
if (tabulatedFunctions[i] != NULL) {
delete tabulatedFunctions[i];
tabulatedFunctions[i] = NULL;
}
tabulatedFunctions[i] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
}
}
ReferenceCustomCVForce::~ReferenceCustomCVForce() {
for (int i = 0; i < tabulatedFunctions.size(); i++)
if (tabulatedFunctions[i] != NULL)
delete tabulatedFunctions[i];
}
void ReferenceCustomCVForce::calculateIxn(ContextImpl& innerContext, vector<Vec3>& atomCoordinates,
const map<string, double>& globalParameters, vector<Vec3>& forces,
double* totalEnergy, map<string, double>& energyParamDerivs) const {
double* totalEnergy, map<string, double>& energyParamDerivs) {
// Compute the collective variables, and their derivatives with respect to particle positions.
int numCVs = variableNames.size();
ReferencePlatform::PlatformData* data = reinterpret_cast<ReferencePlatform::PlatformData*>(innerContext.getPlatformData());
vector<Vec3>& innerForces = *((vector<Vec3>*) data->forces);
map<string, double>& innerDerivs = *((map<string, double>*) data->energyParameterDerivatives);
vector<double> cvValues;
vector<vector<Vec3> > cvForces;
vector<map<string, double> > cvDerivs;
for (int i = 0; i < numCVs; i++) {
cvValues.push_back(innerContext.calcForcesAndEnergy(true, true, 1<<i));
cvValues[i] = innerContext.calcForcesAndEnergy(true, true, 1<<i);
cvForces.push_back(innerForces);
cvDerivs.push_back(innerDerivs);
}
// Compute the energy and forces.
for (int i = 0; i < globalParameterNames.size(); i++)
globalValues[i] = globalParameters.at(globalParameterNames[i]);
int numParticles = atomCoordinates.size();
map<string, double> variables = globalParameters;
for (int i = 0; i < numCVs; i++)
variables[variableNames[i]] = cvValues[i];
if (totalEnergy != NULL)
*totalEnergy += energyExpression.evaluate(variables);
*totalEnergy += energyExpression.evaluate();
for (int i = 0; i < numCVs; i++) {
double dEdV = variableDerivExpressions[i].evaluate(variables);
double dEdV = variableDerivExpressions[i].evaluate();
for (int j = 0; j < numParticles; j++)
forces[j] += cvForces[i][j]*dEdV;
}
// Compute the energy parameter derivatives.
if (paramDerivExpressions.size() > 0) {
for (int i = 0; i < paramDerivExpressions.size(); i++)
energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate(variables);
energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate();
for (int i = 0; i < numCVs; i++) {
double dEdV = variableDerivExpressions[i].evaluate(variables);
double dEdV = variableDerivExpressions[i].evaluate();
for (auto& deriv : cvDerivs[i])
energyParamDerivs[deriv.first] += dEdV*deriv.second;
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment