Unverified Commit 17b61225 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Use CompiledExpression for CustomCVForce energy expression (#3898)

parent b57a5a63
...@@ -909,6 +909,7 @@ public: ...@@ -909,6 +909,7 @@ public:
CommonCalcCustomCVForceKernel(std::string name, const Platform& platform, ComputeContext& cc) : CalcCustomCVForceKernel(name, platform), CommonCalcCustomCVForceKernel(std::string name, const Platform& platform, ComputeContext& cc) : CalcCustomCVForceKernel(name, platform),
cc(cc), hasInitializedListeners(false) { cc(cc), hasInitializedListeners(false) {
} }
~CommonCalcCustomCVForceKernel();
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
...@@ -948,13 +949,16 @@ public: ...@@ -948,13 +949,16 @@ public:
private: private:
class ForceInfo; class ForceInfo;
class ReorderListener; class ReorderListener;
class TabulatedFunctionWrapper;
ComputeContext& cc; ComputeContext& cc;
bool hasInitializedListeners; bool hasInitializedListeners;
Lepton::ExpressionProgram energyExpression; Lepton::CompiledExpression energyExpression;
std::vector<std::string> variableNames, paramDerivNames, globalParameterNames; std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions; std::vector<Lepton::CompiledExpression> variableDerivExpressions;
std::vector<Lepton::ExpressionProgram> paramDerivExpressions; std::vector<Lepton::CompiledExpression> paramDerivExpressions;
std::vector<ComputeArray> cvForces; std::vector<ComputeArray> cvForces;
std::vector<double> globalValues, cvValues;
std::vector<Lepton::CustomFunction*> tabulatedFunctions;
ComputeArray invAtomOrder; ComputeArray invAtomOrder;
ComputeArray innerInvAtomOrder; ComputeArray innerInvAtomOrder;
ComputeKernel copyStateKernel, copyForcesKernel, addForcesKernel; ComputeKernel copyStateKernel, copyForcesKernel, addForcesKernel;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2022 Stanford University and the Authors. * * Portions copyright (c) 2008-2023 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -98,15 +98,6 @@ static pair<ExpressionTreeNode, string> makeVariable(const string& name, const s ...@@ -98,15 +98,6 @@ static pair<ExpressionTreeNode, string> makeVariable(const string& name, const s
return make_pair(ExpressionTreeNode(new Operation::Variable(name)), value); return make_pair(ExpressionTreeNode(new Operation::Variable(name)), value);
} }
static void replaceFunctionsInExpression(map<string, CustomFunction*>& functions, ExpressionProgram& expression) {
for (int i = 0; i < expression.getNumOperations(); i++) {
if (expression.getOperation(i).getId() == Operation::CUSTOM) {
const Operation::Custom& op = dynamic_cast<const Operation::Custom&>(expression.getOperation(i));
expression.setOperation(i, new Operation::Custom(op.getName(), functions[op.getName()]->clone(), op.getDerivOrder()));
}
}
}
void CommonApplyConstraintsKernel::initialize(const System& system) { void CommonApplyConstraintsKernel::initialize(const System& system) {
} }
...@@ -5136,6 +5127,30 @@ private: ...@@ -5136,6 +5127,30 @@ private:
ArrayInterface& invAtomOrder; ArrayInterface& invAtomOrder;
}; };
// This class allows us to update tabulated functions without having to recompile expressions
// that use them.
class CommonCalcCustomCVForceKernel::TabulatedFunctionWrapper : public CustomFunction {
public:
TabulatedFunctionWrapper(vector<Lepton::CustomFunction*>& tabulatedFunctions, int index) :
tabulatedFunctions(tabulatedFunctions), index(index) {
}
int getNumArguments() const {
return tabulatedFunctions[index]->getNumArguments();
}
double evaluate(const double* arguments) const {
return tabulatedFunctions[index]->evaluate(arguments);
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return tabulatedFunctions[index]->evaluateDerivative(arguments, derivOrder);
}
CustomFunction* clone() const {
return new TabulatedFunctionWrapper(tabulatedFunctions, index);
}
private:
vector<Lepton::CustomFunction*>& tabulatedFunctions;
int index;
};
void CommonCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext) { void CommonCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext) {
ContextSelector selector(cc); ContextSelector selector(cc);
int numCVs = force.getNumCollectiveVariables(); int numCVs = force.getNumCollectiveVariables();
...@@ -5152,19 +5167,34 @@ void CommonCalcCustomCVForceKernel::initialize(const System& system, const Custo ...@@ -5152,19 +5167,34 @@ void CommonCalcCustomCVForceKernel::initialize(const System& system, const Custo
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++) tabulatedFunctions.resize(force.getNumTabulatedFunctions(), NULL);
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
tabulatedFunctions[i] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
functions[force.getTabulatedFunctionName(i)] = new TabulatedFunctionWrapper(tabulatedFunctions, i);
}
// Create the expressions. // Create the expressions.
Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions); Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
energyExpression = energyExpr.createProgram(); energyExpression = energyExpr.createCompiledExpression();
variableDerivExpressions.clear(); variableDerivExpressions.clear();
for (auto& name : variableNames) for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram()); variableDerivExpressions.push_back(energyExpr.differentiate(name).createCompiledExpression());
paramDerivExpressions.clear(); paramDerivExpressions.clear();
for (auto& name : paramDerivNames) for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram()); paramDerivExpressions.push_back(energyExpr.differentiate(name).createCompiledExpression());
globalValues.resize(globalParameterNames.size());
cvValues.resize(numCVs);
map<string, double*> variableLocations;
for (int i = 0; i < globalParameterNames.size(); i++)
variableLocations[globalParameterNames[i]] = &globalValues[i];
for (int i = 0; i < numCVs; i++)
variableLocations[variableNames[i]] = &cvValues[i];
energyExpression.setVariableLocations(variableLocations);
for (CompiledExpression& expr : variableDerivExpressions)
expr.setVariableLocations(variableLocations);
for (CompiledExpression& expr : paramDerivExpressions)
expr.setVariableLocations(variableLocations);
// Delete the custom functions. // Delete the custom functions.
...@@ -5229,15 +5259,20 @@ void CommonCalcCustomCVForceKernel::initialize(const System& system, const Custo ...@@ -5229,15 +5259,20 @@ void CommonCalcCustomCVForceKernel::initialize(const System& system, const Custo
cc.addForce(new ForceInfo(*info)); cc.addForce(new ForceInfo(*info));
} }
CommonCalcCustomCVForceKernel::~CommonCalcCustomCVForceKernel() {
for (int i = 0; i < tabulatedFunctions.size(); i++)
if (tabulatedFunctions[i] != NULL)
delete tabulatedFunctions[i];
}
double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl& innerContext, bool includeForces, bool includeEnergy) { double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl& innerContext, bool includeForces, bool includeEnergy) {
copyState(context, innerContext); copyState(context, innerContext);
int numCVs = variableNames.size(); int numCVs = variableNames.size();
int numAtoms = cc.getNumAtoms(); int numAtoms = cc.getNumAtoms();
int paddedNumAtoms = cc.getPaddedNumAtoms(); int paddedNumAtoms = cc.getPaddedNumAtoms();
vector<double> cvValues;
vector<map<string, double> > cvDerivs(numCVs); vector<map<string, double> > cvDerivs(numCVs);
for (int i = 0; i < numCVs; i++) { for (int i = 0; i < numCVs; i++) {
cvValues.push_back(innerContext.calcForcesAndEnergy(true, true, 1<<i)); cvValues[i] = innerContext.calcForcesAndEnergy(true, true, 1<<i);
ContextSelector selector(cc); ContextSelector selector(cc);
copyForcesKernel->setArg(0, cvForces[i]); copyForcesKernel->setArg(0, cvForces[i]);
copyForcesKernel->execute(numAtoms); copyForcesKernel->execute(numAtoms);
...@@ -5247,14 +5282,11 @@ double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl& ...@@ -5247,14 +5282,11 @@ double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
// Compute the energy and forces. // Compute the energy and forces.
ContextSelector selector(cc); ContextSelector selector(cc);
map<string, double> variables; for (int i = 0; i < globalParameterNames.size(); i++)
for (auto& name : globalParameterNames) globalValues[i] = context.getParameter(globalParameterNames[i]);
variables[name] = context.getParameter(name); double energy = energyExpression.evaluate();
for (int i = 0; i < numCVs; i++)
variables[variableNames[i]] = cvValues[i];
double energy = energyExpression.evaluate(variables);
for (int i = 0; i < numCVs; i++) { for (int i = 0; i < numCVs; i++) {
double dEdV = variableDerivExpressions[i].evaluate(variables); double dEdV = variableDerivExpressions[i].evaluate();
addForcesKernel->setArg(2*i+2, cvForces[i]); addForcesKernel->setArg(2*i+2, cvForces[i]);
if (cc.getUseDoublePrecision()) if (cc.getUseDoublePrecision())
addForcesKernel->setArg(2*i+3, dEdV); addForcesKernel->setArg(2*i+3, dEdV);
...@@ -5262,16 +5294,18 @@ double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl& ...@@ -5262,16 +5294,18 @@ double CommonCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl&
addForcesKernel->setArg(2*i+3, (float) dEdV); addForcesKernel->setArg(2*i+3, (float) dEdV);
} }
addForcesKernel->execute(numAtoms); addForcesKernel->execute(numAtoms);
// Compute the energy parameter derivatives. // Compute the energy parameter derivatives.
map<string, double>& energyParamDerivs = cc.getEnergyParamDerivWorkspace(); if (paramDerivExpressions.size() > 0) {
for (int i = 0; i < paramDerivExpressions.size(); i++) map<string, double>& energyParamDerivs = cc.getEnergyParamDerivWorkspace();
energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate(variables); for (int i = 0; i < paramDerivExpressions.size(); i++)
for (int i = 0; i < numCVs; i++) { energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate();
double dEdV = variableDerivExpressions[i].evaluate(variables); for (int i = 0; i < numCVs; i++) {
for (auto& deriv : cvDerivs[i]) double dEdV = variableDerivExpressions[i].evaluate();
energyParamDerivs[deriv.first] += dEdV*deriv.second; for (auto& deriv : cvDerivs[i])
energyParamDerivs[deriv.first] += dEdV*deriv.second;
}
} }
return energy; return energy;
} }
...@@ -5305,22 +5339,13 @@ void CommonCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& ...@@ -5305,22 +5339,13 @@ void CommonCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl&
void CommonCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) { void CommonCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, CustomFunction*> functions; for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++) if (tabulatedFunctions[i] != NULL) {
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); delete tabulatedFunctions[i];
tabulatedFunctions[i] = NULL;
// Replace tabulated functions in the expressions. }
tabulatedFunctions[i] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
replaceFunctionsInExpression(functions, energyExpression); }
for (auto& expression : variableDerivExpressions)
replaceFunctionsInExpression(functions, expression);
for (auto& expression : paramDerivExpressions)
replaceFunctionsInExpression(functions, expression);
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
} }
void CommonIntegrateVerletStepKernel::initialize(const System& system, const VerletIntegrator& integrator) { void CommonIntegrateVerletStepKernel::initialize(const System& system, const VerletIntegrator& integrator) {
......
/* Portions copyright (c) 2017 Stanford University and Simbios. /* Portions copyright (c) 2017-2023 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -27,7 +27,8 @@ ...@@ -27,7 +27,8 @@
#include "openmm/CustomCVForce.h" #include "openmm/CustomCVForce.h"
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "lepton/ExpressionProgram.h" #include "lepton/CompiledExpression.h"
#include "lepton/CustomFunction.h"
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -36,10 +37,13 @@ namespace OpenMM { ...@@ -36,10 +37,13 @@ namespace OpenMM {
class ReferenceCustomCVForce { class ReferenceCustomCVForce {
private: private:
Lepton::ExpressionProgram energyExpression; class TabulatedFunctionWrapper;
std::vector<std::string> variableNames, paramDerivNames; Lepton::CompiledExpression energyExpression;
std::vector<Lepton::ExpressionProgram> variableDerivExpressions; std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
std::vector<Lepton::ExpressionProgram> paramDerivExpressions; std::vector<Lepton::CompiledExpression> variableDerivExpressions;
std::vector<Lepton::CompiledExpression> paramDerivExpressions;
std::vector<double> globalValues, cvValues;
std::vector<Lepton::CustomFunction*> tabulatedFunctions;
public: public:
/** /**
...@@ -70,7 +74,7 @@ public: ...@@ -70,7 +74,7 @@ public:
*/ */
void calculateIxn(ContextImpl& innerContext, std::vector<OpenMM::Vec3>& atomCoordinates, void calculateIxn(ContextImpl& innerContext, std::vector<OpenMM::Vec3>& atomCoordinates,
const std::map<std::string, double>& globalParameters, const std::map<std::string, double>& globalParameters,
std::vector<OpenMM::Vec3>& forces, double* totalEnergy, std::map<std::string, double>& energyParamDerivs) const; std::vector<OpenMM::Vec3>& forces, double* totalEnergy, std::map<std::string, double>& energyParamDerivs);
}; };
} // namespace OpenMM } // namespace OpenMM
......
/* Portions copyright (c) 2009-2017 Stanford University and Simbios. /* Portions copyright (c) 2009-2023 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -34,8 +34,35 @@ using namespace OpenMM; ...@@ -34,8 +34,35 @@ using namespace OpenMM;
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
// This class allows us to update tabulated functions without having to recompile expressions
// that use them.
class ReferenceCustomCVForce::TabulatedFunctionWrapper : public CustomFunction {
public:
TabulatedFunctionWrapper(vector<Lepton::CustomFunction*>& tabulatedFunctions, int index) :
tabulatedFunctions(tabulatedFunctions), index(index) {
}
int getNumArguments() const {
return tabulatedFunctions[index]->getNumArguments();
}
double evaluate(const double* arguments) const {
return tabulatedFunctions[index]->evaluate(arguments);
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return tabulatedFunctions[index]->evaluateDerivative(arguments, derivOrder);
}
CustomFunction* clone() const {
return new TabulatedFunctionWrapper(tabulatedFunctions, index);
}
private:
vector<Lepton::CustomFunction*>& tabulatedFunctions;
int index;
};
ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) { ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
for (int i = 0; i < force.getNumCollectiveVariables(); i++) int numCVs = force.getNumCollectiveVariables();
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
for (int i = 0; i < numCVs; i++)
variableNames.push_back(force.getCollectiveVariableName(i)); variableNames.push_back(force.getCollectiveVariableName(i));
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
paramDerivNames.push_back(force.getEnergyParameterDerivativeName(i)); paramDerivNames.push_back(force.getEnergyParameterDerivativeName(i));
...@@ -43,19 +70,34 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) { ...@@ -43,19 +70,34 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, CustomFunction*> functions; map<string, CustomFunction*> functions;
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++) tabulatedFunctions.resize(force.getNumTabulatedFunctions(), NULL);
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
tabulatedFunctions[i] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
functions[force.getTabulatedFunctionName(i)] = new TabulatedFunctionWrapper(tabulatedFunctions, i);
}
// Create the expressions. // Create the expressions.
ParsedExpression energyExpr = Parser::parse(force.getEnergyFunction(), functions); ParsedExpression energyExpr = Parser::parse(force.getEnergyFunction(), functions).optimize();
energyExpression = energyExpr.createProgram(); energyExpression = energyExpr.createCompiledExpression();
variableDerivExpressions.clear(); variableDerivExpressions.clear();
for (auto& name : variableNames) for (auto& name : variableNames)
variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram()); variableDerivExpressions.push_back(energyExpr.differentiate(name).createCompiledExpression());
paramDerivExpressions.clear(); paramDerivExpressions.clear();
for (auto& name : paramDerivNames) for (auto& name : paramDerivNames)
paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram()); paramDerivExpressions.push_back(energyExpr.differentiate(name).createCompiledExpression());
globalValues.resize(variableNames.size());
cvValues.resize(numCVs);
map<string, double*> variableLocations;
for (int i = 0; i < globalParameterNames.size(); i++)
variableLocations[globalParameterNames[i]] = &globalValues[i];
for (int i = 0; i < numCVs; i++)
variableLocations[variableNames[i]] = &cvValues[i];
energyExpression.setVariableLocations(variableLocations);
for (CompiledExpression& expr : variableDerivExpressions)
expr.setVariableLocations(variableLocations);
for (CompiledExpression& expr : paramDerivExpressions)
expr.setVariableLocations(variableLocations);
// Delete the custom functions. // Delete the custom functions.
...@@ -63,78 +105,63 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) { ...@@ -63,78 +105,63 @@ ReferenceCustomCVForce::ReferenceCustomCVForce(const CustomCVForce& force) {
delete function.second; delete function.second;
} }
static void replaceFunctionsInExpression(map<string, CustomFunction*>& functions, ExpressionProgram& expression) {
for (int i = 0; i < expression.getNumOperations(); i++) {
if (expression.getOperation(i).getId() == Operation::CUSTOM) {
const Operation::Custom& op = dynamic_cast<const Operation::Custom&>(expression.getOperation(i));
expression.setOperation(i, new Operation::Custom(op.getName(), functions[op.getName()]->clone(), op.getDerivOrder()));
}
}
}
void ReferenceCustomCVForce::updateTabulatedFunctions(const OpenMM::CustomCVForce& force) { void ReferenceCustomCVForce::updateTabulatedFunctions(const OpenMM::CustomCVForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, CustomFunction*> functions; for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++) if (tabulatedFunctions[i] != NULL) {
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); delete tabulatedFunctions[i];
tabulatedFunctions[i] = NULL;
// Replace tabulated functions in the expressions. }
tabulatedFunctions[i] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
replaceFunctionsInExpression(functions, energyExpression); }
for (auto& expression : variableDerivExpressions)
replaceFunctionsInExpression(functions, expression);
for (auto& expression : paramDerivExpressions)
replaceFunctionsInExpression(functions, expression);
// Delete the custom functions.
for (auto& function : functions)
delete function.second;
} }
ReferenceCustomCVForce::~ReferenceCustomCVForce() { ReferenceCustomCVForce::~ReferenceCustomCVForce() {
for (int i = 0; i < tabulatedFunctions.size(); i++)
if (tabulatedFunctions[i] != NULL)
delete tabulatedFunctions[i];
} }
void ReferenceCustomCVForce::calculateIxn(ContextImpl& innerContext, vector<Vec3>& atomCoordinates, void ReferenceCustomCVForce::calculateIxn(ContextImpl& innerContext, vector<Vec3>& atomCoordinates,
const map<string, double>& globalParameters, vector<Vec3>& forces, const map<string, double>& globalParameters, vector<Vec3>& forces,
double* totalEnergy, map<string, double>& energyParamDerivs) const { double* totalEnergy, map<string, double>& energyParamDerivs) {
// Compute the collective variables, and their derivatives with respect to particle positions. // Compute the collective variables, and their derivatives with respect to particle positions.
int numCVs = variableNames.size(); int numCVs = variableNames.size();
ReferencePlatform::PlatformData* data = reinterpret_cast<ReferencePlatform::PlatformData*>(innerContext.getPlatformData()); ReferencePlatform::PlatformData* data = reinterpret_cast<ReferencePlatform::PlatformData*>(innerContext.getPlatformData());
vector<Vec3>& innerForces = *((vector<Vec3>*) data->forces); vector<Vec3>& innerForces = *((vector<Vec3>*) data->forces);
map<string, double>& innerDerivs = *((map<string, double>*) data->energyParameterDerivatives); map<string, double>& innerDerivs = *((map<string, double>*) data->energyParameterDerivatives);
vector<double> cvValues;
vector<vector<Vec3> > cvForces; vector<vector<Vec3> > cvForces;
vector<map<string, double> > cvDerivs; vector<map<string, double> > cvDerivs;
for (int i = 0; i < numCVs; i++) { for (int i = 0; i < numCVs; i++) {
cvValues.push_back(innerContext.calcForcesAndEnergy(true, true, 1<<i)); cvValues[i] = innerContext.calcForcesAndEnergy(true, true, 1<<i);
cvForces.push_back(innerForces); cvForces.push_back(innerForces);
cvDerivs.push_back(innerDerivs); cvDerivs.push_back(innerDerivs);
} }
// Compute the energy and forces. // Compute the energy and forces.
for (int i = 0; i < globalParameterNames.size(); i++)
globalValues[i] = globalParameters.at(globalParameterNames[i]);
int numParticles = atomCoordinates.size(); int numParticles = atomCoordinates.size();
map<string, double> variables = globalParameters;
for (int i = 0; i < numCVs; i++)
variables[variableNames[i]] = cvValues[i];
if (totalEnergy != NULL) if (totalEnergy != NULL)
*totalEnergy += energyExpression.evaluate(variables); *totalEnergy += energyExpression.evaluate();
for (int i = 0; i < numCVs; i++) { for (int i = 0; i < numCVs; i++) {
double dEdV = variableDerivExpressions[i].evaluate(variables); double dEdV = variableDerivExpressions[i].evaluate();
for (int j = 0; j < numParticles; j++) for (int j = 0; j < numParticles; j++)
forces[j] += cvForces[i][j]*dEdV; forces[j] += cvForces[i][j]*dEdV;
} }
// Compute the energy parameter derivatives. // Compute the energy parameter derivatives.
for (int i = 0; i < paramDerivExpressions.size(); i++) if (paramDerivExpressions.size() > 0) {
energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate(variables); for (int i = 0; i < paramDerivExpressions.size(); i++)
for (int i = 0; i < numCVs; i++) { energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate();
double dEdV = variableDerivExpressions[i].evaluate(variables); for (int i = 0; i < numCVs; i++) {
for (auto& deriv : cvDerivs[i]) double dEdV = variableDerivExpressions[i].evaluate();
energyParamDerivs[deriv.first] += dEdV*deriv.second; for (auto& deriv : cvDerivs[i])
energyParamDerivs[deriv.first] += dEdV*deriv.second;
}
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment