Unverified Commit 8292bb3a authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Reduced the cost of updating tabulated functions (#3649)

parent d7da750a
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014 Stanford University and the Authors. *
* Portions copyright (c) 2014-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -57,6 +57,8 @@ namespace OpenMM {
class OPENMM_EXPORT TabulatedFunction {
public:
TabulatedFunction() : updateCount(0) {
}
virtual ~TabulatedFunction() {
}
/**
......@@ -67,12 +69,18 @@ public:
* Get the periodicity status of the tabulated function.
*/
bool getPeriodic() const;
/**
* Get the value of a counter that is updated every time setFunctionParameters()
* is called. This provides a fast way to detect when a function has changed.
*/
int getUpdateCount() const;
virtual bool operator==(const TabulatedFunction& other) const = 0;
virtual bool operator!=(const TabulatedFunction& other) const {
return !(*this == other);
}
protected:
bool periodic;
int updateCount;
};
/**
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014-2021 Stanford University and the Authors. *
* Portions copyright (c) 2014-2022 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -39,6 +39,10 @@ bool TabulatedFunction::getPeriodic() const {
return periodic;
}
int TabulatedFunction::getUpdateCount() const {
return updateCount;
}
Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max, bool periodic) {
this->periodic = periodic;
setFunctionParameters(values, min, max);
......@@ -66,6 +70,7 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
this->values = values;
this->min = min;
this->max = max;
updateCount++;
}
Continuous1DFunction* Continuous1DFunction::Copy() const {
......@@ -120,6 +125,7 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
updateCount++;
}
Continuous2DFunction* Continuous2DFunction::Copy() const {
......@@ -186,6 +192,7 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize
this->ymax = ymax;
this->zmin = zmin;
this->zmax = zmax;
updateCount++;
}
Continuous3DFunction* Continuous3DFunction::Copy() const {
......@@ -220,6 +227,7 @@ void Discrete1DFunction::getFunctionParameters(vector<double>& values) const {
void Discrete1DFunction::setFunctionParameters(const vector<double>& values) {
this->values = values;
updateCount++;
}
Discrete1DFunction* Discrete1DFunction::Copy() const {
......@@ -256,6 +264,7 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto
this->xsize = xsize;
this->ysize = ysize;
this->values = values;
updateCount++;
}
Discrete2DFunction* Discrete2DFunction::Copy() const {
......@@ -297,6 +306,7 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize,
this->ysize = ysize;
this->zsize = zsize;
this->values = values;
updateCount++;
}
Discrete3DFunction* Discrete3DFunction::Copy() const {
......
......@@ -530,7 +530,7 @@ private:
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system;
};
......@@ -579,7 +579,7 @@ private:
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
std::vector<void*> groupForcesArgs;
ComputeKernel computeCentersKernel, groupForcesKernel, applyForcesKernel;
const System& system;
......@@ -632,7 +632,7 @@ private:
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
std::vector<std::string> paramNames, computedValueNames;
std::vector<ComputeParameterInfo> paramBuffers, computedValueBuffers;
double longRangeCoefficient;
......@@ -735,7 +735,7 @@ private:
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
std::vector<bool> pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue;
const System& system;
ComputeKernel pairValueKernel, perParticleValueKernel, pairEnergyKernel, perParticleEnergyKernel, gradientChainRuleKernel;
......@@ -793,7 +793,7 @@ private:
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system;
ComputeKernel donorKernel, acceptorKernel;
};
......@@ -845,7 +845,7 @@ private:
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system;
ComputeKernel forceKernel, blockBoundsKernel, neighborsKernel, startIndicesKernel, copyPairsKernel;
ComputeEvent event;
......
......@@ -35,7 +35,6 @@
#include "openmm/internal/CustomCompoundBondForceImpl.h"
#include "openmm/internal/CustomHbondForceImpl.h"
#include "openmm/internal/CustomManyParticleForceImpl.h"
#include "openmm/serialization/XmlSerializer.h"
#include "CommonKernelSources.h"
#include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h"
......@@ -1294,7 +1293,7 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
......@@ -1417,8 +1416,8 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
......@@ -1553,7 +1552,7 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
......@@ -1747,8 +1746,8 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
......@@ -1902,7 +1901,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
......@@ -2459,8 +2458,8 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
......@@ -2798,7 +2797,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
......@@ -3785,8 +3784,8 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
......@@ -4012,7 +4011,7 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
......@@ -4336,8 +4335,8 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
......@@ -4425,7 +4424,7 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
......@@ -4855,8 +4854,8 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
......
......@@ -334,7 +334,7 @@ private:
std::vector<std::string> parameterNames, globalParameterNames, computedValueNames, energyParamDerivNames;
std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups;
std::vector<double> longRangeCoefficientDerivs;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod;
CpuCustomNonbondedForce* nonbonded;
};
......@@ -424,7 +424,7 @@ private:
std::vector<std::string> particleParameterNames, globalParameterNames, energyParamDerivNames, valueNames;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod;
};
......@@ -467,7 +467,7 @@ private:
std::vector<std::vector<double> > particleParamArray;
CpuCustomManyParticleForce* ixn;
std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod;
};
......
......@@ -45,7 +45,6 @@
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/vectorize.h"
#include "openmm/serialization/XmlSerializer.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CustomFunction.h"
#include "lepton/Operation.h"
......@@ -888,10 +887,10 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
switchingDistance = force.getSwitchingDistance();
}
// Record the tabulated functions for future reference.
// Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Record information for the long range correction.
......@@ -1053,8 +1052,8 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true;
}
}
......@@ -1166,10 +1165,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
neighborList = new CpuNeighborList(4);
data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic);
// Record the tabulated functions for future reference.
// Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction.
......@@ -1319,8 +1318,8 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true;
}
}
......@@ -1349,10 +1348,10 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
// Record the tabulated functions for future reference.
// Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction.
......@@ -1399,8 +1398,8 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl&
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true;
}
}
......
......@@ -706,7 +706,7 @@ private:
std::vector<std::string> parameterNames, globalParameterNames, computedValueNames, energyParamDerivNames;
std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups;
std::vector<double> longRangeCoefficientDerivs;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod;
NeighborList* neighborList;
};
......@@ -797,7 +797,7 @@ private:
std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod;
NeighborList* neighborList;
};
......@@ -884,7 +884,7 @@ private:
ReferenceCustomHbondIxn* ixn;
std::vector<std::set<int> > exclusions;
std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
};
/**
......@@ -927,7 +927,7 @@ private:
std::vector<std::vector<double> > bondParamArray;
ReferenceCustomCentroidBondIxn* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
bool usePeriodic;
Vec3* boxVectors;
};
......@@ -970,7 +970,7 @@ private:
std::vector<std::vector<double> > bondParamArray;
ReferenceCustomCompoundBondIxn* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
bool usePeriodic;
Vec3* boxVectors;
};
......@@ -1012,7 +1012,7 @@ private:
std::vector<std::vector<double> > particleParamArray;
ReferenceCustomManyParticleIxn* ixn;
std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod;
};
......
......@@ -80,7 +80,6 @@
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/Integrator.h"
#include "openmm/OpenMMException.h"
#include "openmm/serialization/XmlSerializer.h"
#include "SimTKOpenMMUtilities.h"
#include "lepton/CustomFunction.h"
#include "lepton/Operation.h"
......@@ -1177,10 +1176,10 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
switchingDistance = force.getSwitchingDistance();
}
// Record the tabulated functions for future reference.
// Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the expressions.
......@@ -1349,8 +1348,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true;
}
}
......@@ -1465,10 +1464,10 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
else
neighborList = new NeighborList();
// Record the tabulated functions for future reference.
// Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the expressions.
......@@ -1624,8 +1623,8 @@ void ReferenceCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& cont
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true;
}
}
......@@ -1750,10 +1749,10 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
globalParameterNames.push_back(force.getGlobalParameterName(i));
nonbondedCutoff = force.getCutoffDistance();
// Record the tabulated functions for future reference.
// Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction.
......@@ -1839,8 +1838,8 @@ void ReferenceCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& c
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true;
}
}
......@@ -1873,10 +1872,10 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system
for (int i = 0; i < numBonds; ++i)
force.getBondParameters(i, bondGroups[i], bondParamArray[i]);
// Record the tabulated functions for future reference.
// Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction.
......@@ -1962,8 +1961,8 @@ void ReferenceCalcCustomCentroidBondForceKernel::copyParametersToContext(Context
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true;
}
}
......@@ -1990,10 +1989,10 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system
for (int i = 0; i < numBonds; ++i)
force.getBondParameters(i, bondParticles[i], bondParamArray[i]);
// Record the tabulated functions for future reference.
// Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction.
......@@ -2078,8 +2077,8 @@ void ReferenceCalcCustomCompoundBondForceKernel::copyParametersToContext(Context
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true;
}
}
......@@ -2107,10 +2106,10 @@ void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
// Record the tabulated functions for future reference.
// Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction.
......@@ -2158,8 +2157,8 @@ void ReferenceCalcCustomManyParticleForceKernel::copyParametersToContext(Context
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true;
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment