Unverified Commit 8292bb3a authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Reduced the cost of updating tabulated functions (#3649)

parent d7da750a
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2014 Stanford University and the Authors. * * Portions copyright (c) 2014-2022 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -57,6 +57,8 @@ namespace OpenMM { ...@@ -57,6 +57,8 @@ namespace OpenMM {
class OPENMM_EXPORT TabulatedFunction { class OPENMM_EXPORT TabulatedFunction {
public: public:
TabulatedFunction() : updateCount(0) {
}
virtual ~TabulatedFunction() { virtual ~TabulatedFunction() {
} }
/** /**
...@@ -67,12 +69,18 @@ public: ...@@ -67,12 +69,18 @@ public:
* Get the periodicity status of the tabulated function. * Get the periodicity status of the tabulated function.
*/ */
bool getPeriodic() const; bool getPeriodic() const;
/**
* Get the value of a counter that is updated every time setFunctionParameters()
* is called. This provides a fast way to detect when a function has changed.
*/
int getUpdateCount() const;
virtual bool operator==(const TabulatedFunction& other) const = 0; virtual bool operator==(const TabulatedFunction& other) const = 0;
virtual bool operator!=(const TabulatedFunction& other) const { virtual bool operator!=(const TabulatedFunction& other) const {
return !(*this == other); return !(*this == other);
} }
protected: protected:
bool periodic; bool periodic;
int updateCount;
}; };
/** /**
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2014-2021 Stanford University and the Authors. * * Portions copyright (c) 2014-2022 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -39,6 +39,10 @@ bool TabulatedFunction::getPeriodic() const { ...@@ -39,6 +39,10 @@ bool TabulatedFunction::getPeriodic() const {
return periodic; return periodic;
} }
int TabulatedFunction::getUpdateCount() const {
return updateCount;
}
Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max, bool periodic) { Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max, bool periodic) {
this->periodic = periodic; this->periodic = periodic;
setFunctionParameters(values, min, max); setFunctionParameters(values, min, max);
...@@ -66,6 +70,7 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d ...@@ -66,6 +70,7 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
this->values = values; this->values = values;
this->min = min; this->min = min;
this->max = max; this->max = max;
updateCount++;
} }
Continuous1DFunction* Continuous1DFunction::Copy() const { Continuous1DFunction* Continuous1DFunction::Copy() const {
...@@ -120,6 +125,7 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec ...@@ -120,6 +125,7 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec
this->xmax = xmax; this->xmax = xmax;
this->ymin = ymin; this->ymin = ymin;
this->ymax = ymax; this->ymax = ymax;
updateCount++;
} }
Continuous2DFunction* Continuous2DFunction::Copy() const { Continuous2DFunction* Continuous2DFunction::Copy() const {
...@@ -186,6 +192,7 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize ...@@ -186,6 +192,7 @@ void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize
this->ymax = ymax; this->ymax = ymax;
this->zmin = zmin; this->zmin = zmin;
this->zmax = zmax; this->zmax = zmax;
updateCount++;
} }
Continuous3DFunction* Continuous3DFunction::Copy() const { Continuous3DFunction* Continuous3DFunction::Copy() const {
...@@ -220,6 +227,7 @@ void Discrete1DFunction::getFunctionParameters(vector<double>& values) const { ...@@ -220,6 +227,7 @@ void Discrete1DFunction::getFunctionParameters(vector<double>& values) const {
void Discrete1DFunction::setFunctionParameters(const vector<double>& values) { void Discrete1DFunction::setFunctionParameters(const vector<double>& values) {
this->values = values; this->values = values;
updateCount++;
} }
Discrete1DFunction* Discrete1DFunction::Copy() const { Discrete1DFunction* Discrete1DFunction::Copy() const {
...@@ -256,6 +264,7 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto ...@@ -256,6 +264,7 @@ void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const vecto
this->xsize = xsize; this->xsize = xsize;
this->ysize = ysize; this->ysize = ysize;
this->values = values; this->values = values;
updateCount++;
} }
Discrete2DFunction* Discrete2DFunction::Copy() const { Discrete2DFunction* Discrete2DFunction::Copy() const {
...@@ -297,6 +306,7 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, ...@@ -297,6 +306,7 @@ void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize,
this->ysize = ysize; this->ysize = ysize;
this->zsize = zsize; this->zsize = zsize;
this->values = values; this->values = values;
updateCount++;
} }
Discrete3DFunction* Discrete3DFunction::Copy() const { Discrete3DFunction* Discrete3DFunction::Copy() const {
......
...@@ -530,7 +530,7 @@ private: ...@@ -530,7 +530,7 @@ private:
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system; const System& system;
}; };
...@@ -579,7 +579,7 @@ private: ...@@ -579,7 +579,7 @@ private:
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
std::vector<void*> groupForcesArgs; std::vector<void*> groupForcesArgs;
ComputeKernel computeCentersKernel, groupForcesKernel, applyForcesKernel; ComputeKernel computeCentersKernel, groupForcesKernel, applyForcesKernel;
const System& system; const System& system;
...@@ -632,7 +632,7 @@ private: ...@@ -632,7 +632,7 @@ private:
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
std::vector<std::string> paramNames, computedValueNames; std::vector<std::string> paramNames, computedValueNames;
std::vector<ComputeParameterInfo> paramBuffers, computedValueBuffers; std::vector<ComputeParameterInfo> paramBuffers, computedValueBuffers;
double longRangeCoefficient; double longRangeCoefficient;
...@@ -735,7 +735,7 @@ private: ...@@ -735,7 +735,7 @@ private:
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
std::vector<bool> pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue; std::vector<bool> pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue;
const System& system; const System& system;
ComputeKernel pairValueKernel, perParticleValueKernel, pairEnergyKernel, perParticleEnergyKernel, gradientChainRuleKernel; ComputeKernel pairValueKernel, perParticleValueKernel, pairEnergyKernel, perParticleEnergyKernel, gradientChainRuleKernel;
...@@ -793,7 +793,7 @@ private: ...@@ -793,7 +793,7 @@ private:
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system; const System& system;
ComputeKernel donorKernel, acceptorKernel; ComputeKernel donorKernel, acceptorKernel;
}; };
...@@ -845,7 +845,7 @@ private: ...@@ -845,7 +845,7 @@ private:
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctionArrays; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
const System& system; const System& system;
ComputeKernel forceKernel, blockBoundsKernel, neighborsKernel, startIndicesKernel, copyPairsKernel; ComputeKernel forceKernel, blockBoundsKernel, neighborsKernel, startIndicesKernel, copyPairsKernel;
ComputeEvent event; ComputeEvent event;
......
...@@ -35,7 +35,6 @@ ...@@ -35,7 +35,6 @@
#include "openmm/internal/CustomCompoundBondForceImpl.h" #include "openmm/internal/CustomCompoundBondForceImpl.h"
#include "openmm/internal/CustomHbondForceImpl.h" #include "openmm/internal/CustomHbondForceImpl.h"
#include "openmm/internal/CustomManyParticleForceImpl.h" #include "openmm/internal/CustomManyParticleForceImpl.h"
#include "openmm/serialization/XmlSerializer.h"
#include "CommonKernelSources.h" #include "CommonKernelSources.h"
#include "lepton/CustomFunction.h" #include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h" #include "lepton/ExpressionTreeNode.h"
...@@ -1294,7 +1293,7 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -1294,7 +1293,7 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
...@@ -1417,8 +1416,8 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp ...@@ -1417,8 +1416,8 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f); tabulatedFunctionArrays[i].upload(f);
...@@ -1553,7 +1552,7 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c ...@@ -1553,7 +1552,7 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = "table"+cc.intToString(i); string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
...@@ -1747,8 +1746,8 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp ...@@ -1747,8 +1746,8 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f); tabulatedFunctionArrays[i].upload(f);
...@@ -1902,7 +1901,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -1902,7 +1901,7 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = prefix+"table"+cc.intToString(i); string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
...@@ -2459,8 +2458,8 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& ...@@ -2459,8 +2458,8 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f); tabulatedFunctionArrays[i].upload(f);
...@@ -2798,7 +2797,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2798,7 +2797,7 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = prefix+"table"+cc.intToString(i); string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
...@@ -3785,8 +3784,8 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context ...@@ -3785,8 +3784,8 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f); tabulatedFunctionArrays[i].upload(f);
...@@ -4012,7 +4011,7 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -4012,7 +4011,7 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = "table"+cc.intToString(i); string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
...@@ -4336,8 +4335,8 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont ...@@ -4336,8 +4335,8 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f); tabulatedFunctionArrays[i].upload(f);
...@@ -4425,7 +4424,7 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c ...@@ -4425,7 +4424,7 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
string arrayName = "table"+cc.intToString(i); string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
...@@ -4855,8 +4854,8 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp ...@@ -4855,8 +4854,8 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f); tabulatedFunctionArrays[i].upload(f);
......
...@@ -334,7 +334,7 @@ private: ...@@ -334,7 +334,7 @@ private:
std::vector<std::string> parameterNames, globalParameterNames, computedValueNames, energyParamDerivNames; std::vector<std::string> parameterNames, globalParameterNames, computedValueNames, energyParamDerivNames;
std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups; std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups;
std::vector<double> longRangeCoefficientDerivs; std::vector<double> longRangeCoefficientDerivs;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
CpuCustomNonbondedForce* nonbonded; CpuCustomNonbondedForce* nonbonded;
}; };
...@@ -424,7 +424,7 @@ private: ...@@ -424,7 +424,7 @@ private:
std::vector<std::string> particleParameterNames, globalParameterNames, energyParamDerivNames, valueNames; std::vector<std::string> particleParameterNames, globalParameterNames, energyParamDerivNames, valueNames;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes; std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes; std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
}; };
...@@ -467,7 +467,7 @@ private: ...@@ -467,7 +467,7 @@ private:
std::vector<std::vector<double> > particleParamArray; std::vector<std::vector<double> > particleParamArray;
CpuCustomManyParticleForce* ixn; CpuCustomManyParticleForce* ixn;
std::vector<std::string> globalParameterNames; std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
}; };
......
...@@ -45,7 +45,6 @@ ...@@ -45,7 +45,6 @@
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h" #include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/vectorize.h" #include "openmm/internal/vectorize.h"
#include "openmm/serialization/XmlSerializer.h"
#include "lepton/CompiledExpression.h" #include "lepton/CompiledExpression.h"
#include "lepton/CustomFunction.h" #include "lepton/CustomFunction.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
...@@ -888,10 +887,10 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C ...@@ -888,10 +887,10 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
switchingDistance = force.getSwitchingDistance(); switchingDistance = force.getSwitchingDistance();
} }
// Record the tabulated functions for future reference. // Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Record information for the long range correction. // Record information for the long range correction.
...@@ -1053,8 +1052,8 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con ...@@ -1053,8 +1052,8 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
bool changed = false; bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true; changed = true;
} }
} }
...@@ -1166,10 +1165,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB ...@@ -1166,10 +1165,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
neighborList = new CpuNeighborList(4); neighborList = new CpuNeighborList(4);
data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic); data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic);
// Record the tabulated functions for future reference. // Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction. // Create the interaction.
...@@ -1319,8 +1318,8 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c ...@@ -1319,8 +1318,8 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c
bool changed = false; bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true; changed = true;
} }
} }
...@@ -1349,10 +1348,10 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons ...@@ -1349,10 +1348,10 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
// Record the tabulated functions for future reference. // Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction. // Create the interaction.
...@@ -1399,8 +1398,8 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl& ...@@ -1399,8 +1398,8 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl&
bool changed = false; bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true; changed = true;
} }
} }
......
...@@ -706,7 +706,7 @@ private: ...@@ -706,7 +706,7 @@ private:
std::vector<std::string> parameterNames, globalParameterNames, computedValueNames, energyParamDerivNames; std::vector<std::string> parameterNames, globalParameterNames, computedValueNames, energyParamDerivNames;
std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups; std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups;
std::vector<double> longRangeCoefficientDerivs; std::vector<double> longRangeCoefficientDerivs;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
NeighborList* neighborList; NeighborList* neighborList;
}; };
...@@ -797,7 +797,7 @@ private: ...@@ -797,7 +797,7 @@ private:
std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions; std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions; std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes; std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
NeighborList* neighborList; NeighborList* neighborList;
}; };
...@@ -884,7 +884,7 @@ private: ...@@ -884,7 +884,7 @@ private:
ReferenceCustomHbondIxn* ixn; ReferenceCustomHbondIxn* ixn;
std::vector<std::set<int> > exclusions; std::vector<std::set<int> > exclusions;
std::vector<std::string> globalParameterNames; std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
}; };
/** /**
...@@ -927,7 +927,7 @@ private: ...@@ -927,7 +927,7 @@ private:
std::vector<std::vector<double> > bondParamArray; std::vector<std::vector<double> > bondParamArray;
ReferenceCustomCentroidBondIxn* ixn; ReferenceCustomCentroidBondIxn* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames; std::vector<std::string> globalParameterNames, energyParamDerivNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
bool usePeriodic; bool usePeriodic;
Vec3* boxVectors; Vec3* boxVectors;
}; };
...@@ -970,7 +970,7 @@ private: ...@@ -970,7 +970,7 @@ private:
std::vector<std::vector<double> > bondParamArray; std::vector<std::vector<double> > bondParamArray;
ReferenceCustomCompoundBondIxn* ixn; ReferenceCustomCompoundBondIxn* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames; std::vector<std::string> globalParameterNames, energyParamDerivNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
bool usePeriodic; bool usePeriodic;
Vec3* boxVectors; Vec3* boxVectors;
}; };
...@@ -1012,7 +1012,7 @@ private: ...@@ -1012,7 +1012,7 @@ private:
std::vector<std::vector<double> > particleParamArray; std::vector<std::vector<double> > particleParamArray;
ReferenceCustomManyParticleIxn* ixn; ReferenceCustomManyParticleIxn* ixn;
std::vector<std::string> globalParameterNames; std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions; std::map<std::string, int> tabulatedFunctionUpdateCount;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
}; };
......
...@@ -80,7 +80,6 @@ ...@@ -80,7 +80,6 @@
#include "openmm/internal/NonbondedForceImpl.h" #include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/Integrator.h" #include "openmm/Integrator.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/serialization/XmlSerializer.h"
#include "SimTKOpenMMUtilities.h" #include "SimTKOpenMMUtilities.h"
#include "lepton/CustomFunction.h" #include "lepton/CustomFunction.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
...@@ -1177,10 +1176,10 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c ...@@ -1177,10 +1176,10 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
switchingDistance = force.getSwitchingDistance(); switchingDistance = force.getSwitchingDistance();
} }
// Record the tabulated functions for future reference. // Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the expressions. // Create the expressions.
...@@ -1349,8 +1348,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp ...@@ -1349,8 +1348,8 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
bool changed = false; bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true; changed = true;
} }
} }
...@@ -1465,10 +1464,10 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1465,10 +1464,10 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
else else
neighborList = new NeighborList(); neighborList = new NeighborList();
// Record the tabulated functions for future reference. // Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the expressions. // Create the expressions.
...@@ -1624,8 +1623,8 @@ void ReferenceCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& cont ...@@ -1624,8 +1623,8 @@ void ReferenceCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& cont
bool changed = false; bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true; changed = true;
} }
} }
...@@ -1750,10 +1749,10 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const ...@@ -1750,10 +1749,10 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
nonbondedCutoff = force.getCutoffDistance(); nonbondedCutoff = force.getCutoffDistance();
// Record the tabulated functions for future reference. // Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction. // Create the interaction.
...@@ -1839,8 +1838,8 @@ void ReferenceCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& c ...@@ -1839,8 +1838,8 @@ void ReferenceCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& c
bool changed = false; bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true; changed = true;
} }
} }
...@@ -1873,10 +1872,10 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system ...@@ -1873,10 +1872,10 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system
for (int i = 0; i < numBonds; ++i) for (int i = 0; i < numBonds; ++i)
force.getBondParameters(i, bondGroups[i], bondParamArray[i]); force.getBondParameters(i, bondGroups[i], bondParamArray[i]);
// Record the tabulated functions for future reference. // Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction. // Create the interaction.
...@@ -1962,8 +1961,8 @@ void ReferenceCalcCustomCentroidBondForceKernel::copyParametersToContext(Context ...@@ -1962,8 +1961,8 @@ void ReferenceCalcCustomCentroidBondForceKernel::copyParametersToContext(Context
bool changed = false; bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true; changed = true;
} }
} }
...@@ -1990,10 +1989,10 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system ...@@ -1990,10 +1989,10 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system
for (int i = 0; i < numBonds; ++i) for (int i = 0; i < numBonds; ++i)
force.getBondParameters(i, bondParticles[i], bondParamArray[i]); force.getBondParameters(i, bondParticles[i], bondParamArray[i]);
// Record the tabulated functions for future reference. // Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction. // Create the interaction.
...@@ -2078,8 +2077,8 @@ void ReferenceCalcCustomCompoundBondForceKernel::copyParametersToContext(Context ...@@ -2078,8 +2077,8 @@ void ReferenceCalcCustomCompoundBondForceKernel::copyParametersToContext(Context
bool changed = false; bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true; changed = true;
} }
} }
...@@ -2107,10 +2106,10 @@ void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system ...@@ -2107,10 +2106,10 @@ void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
// Record the tabulated functions for future reference. // Record the tabulated function update counts for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[force.getTabulatedFunctionName(i)] = force.getTabulatedFunction(i).getUpdateCount();
// Create the interaction. // Create the interaction.
...@@ -2158,8 +2157,8 @@ void ReferenceCalcCustomManyParticleForceKernel::copyParametersToContext(Context ...@@ -2158,8 +2157,8 @@ void ReferenceCalcCustomManyParticleForceKernel::copyParametersToContext(Context
bool changed = false; bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) { if (force.getTabulatedFunction(i).getUpdateCount() != tabulatedFunctionUpdateCount[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i)); tabulatedFunctionUpdateCount[name] = force.getTabulatedFunction(i).getUpdateCount();
changed = true; changed = true;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment