Unverified Commit 27bcb657 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

updateParametersInContext() can change tabulated functions (#3307)

* updateParametersInContext() can change tabulated functions

* Fixed error in building C wrappers

* updateParametersInContext() can change tabulated functions for CustomCentroidBondForce

* CustomNonbondedForce can update tabulated functions

* CustomGBForce can update tabulated functions

* CustomManyParticleForce can update tabulated functions

* CustomHbondForce can update tabulated functions
parent d83c2724
...@@ -358,15 +358,16 @@ public: ...@@ -358,15 +358,16 @@ public:
*/ */
const std::string& getTabulatedFunctionName(int index) const; const std::string& getTabulatedFunctionName(int index) const;
/** /**
* Update the per-bond parameters in a Context to match those stored in this Force object. This method provides * Update the per-bond parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setBondParameters() to modify this object's parameters, then call updateParametersInContext() * Simply call setBondParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context. * to copy them over to the Context.
* *
* This method has several limitations. The only information it updates is the values of per-bond parameters. * This method has several limitations. The only information it updates is the values of per-bond parameters and tabulated
* All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing * functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. Neither the definitions of groups nor the set of groups involved in a bond can be changed, nor can new * the Context. Neither the definitions of groups nor the set of groups involved in a bond can be changed, nor can new
* bonds be added. * bonds be added. Also, while the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/ */
void updateParametersInContext(Context& context); void updateParametersInContext(Context& context);
/** /**
......
...@@ -338,14 +338,15 @@ public: ...@@ -338,14 +338,15 @@ public:
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Update the per-bond parameters in a Context to match those stored in this Force object. This method provides * Update the per-bond parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setBondParameters() to modify this object's parameters, then call updateParametersInContext() * Simply call setBondParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context. * to copy them over to the Context.
* *
* This method has several limitations. The only information it updates is the values of per-bond parameters. * This method has several limitations. The only information it updates is the values of per-bond parameters and tabulated
* All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing * functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. The set of particles involved in a bond cannot be changed, nor can new bonds be added. * the Context. The set of particles involved in a bond cannot be changed, nor can new bonds be added. Also, while the
* tabulated values of a function can change, everything else about it (its dimensions, the data range) must not be changed.
*/ */
void updateParametersInContext(Context& context); void updateParametersInContext(Context& context);
/** /**
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -544,14 +544,15 @@ public: ...@@ -544,14 +544,15 @@ public:
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Update the per-particle parameters in a Context to match those stored in this Force object. This method provides * Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext() * Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context. * to copy them over to the Context.
* *
* This method has several limitations. The only information it updates is the values of per-particle parameters. * This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated
* All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing * functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. Also, this method cannot be used to add new particles, only to change the parameters of existing ones. * the Context. Also, this method cannot be used to add new particles, only to change the parameters of existing ones. While
* the tabulated values of a function can change, everything else about it (its dimensions, the data range) must not be changed.
*/ */
void updateParametersInContext(Context& context); void updateParametersInContext(Context& context);
/** /**
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. * * Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -443,15 +443,16 @@ public: ...@@ -443,15 +443,16 @@ public:
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Update the per-donor and per-acceptor parameters in a Context to match those stored in this Force object. This method * Update the per-donor and per-acceptor parameters and tabulated functions in a Context to match those stored in this Force object. This method
* provides an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * provides an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setDonorParameters() and setAcceptorParameters() to modify this object's parameters, then call * Simply call setDonorParameters() and setAcceptorParameters() to modify this object's parameters, then call
* updateParametersInContext() to copy them over to the Context. * updateParametersInContext() to copy them over to the Context.
* *
* This method has several limitations. The only information it updates is the values of per-donor and per-acceptor parameters. * This method has several limitations. The only information it updates is the values of per-donor and per-acceptor parameters and tabulated
* All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can only * functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can only
* be changed by reinitializing the Context. The set of particles involved in a donor or acceptor cannot be changed, nor can * be changed by reinitializing the Context. The set of particles involved in a donor or acceptor cannot be changed, nor can
* new donors or acceptors be added. * new donors or acceptors be added. While the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/ */
void updateParametersInContext(Context& context); void updateParametersInContext(Context& context);
/** /**
......
...@@ -480,15 +480,16 @@ public: ...@@ -480,15 +480,16 @@ public:
*/ */
const std::string& getTabulatedFunctionName(int index) const; const std::string& getTabulatedFunctionName(int index) const;
/** /**
* Update the per-particle parameters in a Context to match those stored in this Force object. This method provides * Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext() * Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context. * to copy them over to the Context.
* *
* This method has several limitations. The only information it updates is the values of per-particle parameters. * This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated
* All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can * functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can
* only be changed by reinitializing the Context. Also, this method cannot be used to add new particles, only to change * only be changed by reinitializing the Context. Also, this method cannot be used to add new particles, only to change
* the parameters of existing ones. * the parameters of existing ones. While the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/ */
void updateParametersInContext(Context& context); void updateParametersInContext(Context& context);
/** /**
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -490,15 +490,16 @@ public: ...@@ -490,15 +490,16 @@ public:
*/ */
void setInteractionGroupParameters(int index, const std::set<int>& set1, const std::set<int>& set2); void setInteractionGroupParameters(int index, const std::set<int>& set1, const std::set<int>& set2);
/** /**
* Update the per-particle parameters in a Context to match those stored in this Force object. This method provides * Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it. * an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext() * Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context. * to copy them over to the Context.
* *
* This method has several limitations. The only information it updates is the values of per-particle parameters. * This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated
* All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can * functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can
* only be changed by reinitializing the Context. Also, this method cannot be used to add new particles, only to change * only be changed by reinitializing the Context. Also, this method cannot be used to add new particles, only to change
* the parameters of existing ones. * the parameters of existing ones. While the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/ */
void updateParametersInContext(Context& context); void updateParametersInContext(Context& context);
/** /**
......
...@@ -65,9 +65,12 @@ public: ...@@ -65,9 +65,12 @@ public:
virtual TabulatedFunction* Copy() const = 0; virtual TabulatedFunction* Copy() const = 0;
/** /**
* Get the periodicity status of the tabulated function. * Get the periodicity status of the tabulated function.
*
*/ */
bool getPeriodic() const; bool getPeriodic() const;
virtual bool operator==(const TabulatedFunction& other) const = 0;
virtual bool operator!=(const TabulatedFunction& other) const {
return !(*this == other);
}
protected: protected:
bool periodic; bool periodic;
}; };
...@@ -114,6 +117,7 @@ public: ...@@ -114,6 +117,7 @@ public:
* @deprecated This will be removed in a future release. * @deprecated This will be removed in a future release.
*/ */
Continuous1DFunction* Copy() const; Continuous1DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private: private:
std::vector<double> values; std::vector<double> values;
double min, max; double min, max;
...@@ -176,6 +180,7 @@ public: ...@@ -176,6 +180,7 @@ public:
* @deprecated This will be removed in a future release. * @deprecated This will be removed in a future release.
*/ */
Continuous2DFunction* Copy() const; Continuous2DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private: private:
std::vector<double> values; std::vector<double> values;
int xsize, ysize; int xsize, ysize;
...@@ -254,6 +259,7 @@ public: ...@@ -254,6 +259,7 @@ public:
* @deprecated This will be removed in a future release. * @deprecated This will be removed in a future release.
*/ */
Continuous3DFunction* Copy() const; Continuous3DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private: private:
std::vector<double> values; std::vector<double> values;
int xsize, ysize, zsize; int xsize, ysize, zsize;
...@@ -291,6 +297,7 @@ public: ...@@ -291,6 +297,7 @@ public:
* @deprecated This will be removed in a future release. * @deprecated This will be removed in a future release.
*/ */
Discrete1DFunction* Copy() const; Discrete1DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private: private:
std::vector<double> values; std::vector<double> values;
}; };
...@@ -335,6 +342,7 @@ public: ...@@ -335,6 +342,7 @@ public:
* @deprecated This will be removed in a future release. * @deprecated This will be removed in a future release.
*/ */
Discrete2DFunction* Copy() const; Discrete2DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private: private:
int xsize, ysize; int xsize, ysize;
std::vector<double> values; std::vector<double> values;
...@@ -383,6 +391,7 @@ public: ...@@ -383,6 +391,7 @@ public:
* @deprecated This will be removed in a future release. * @deprecated This will be removed in a future release.
*/ */
Discrete3DFunction* Copy() const; Discrete3DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private: private:
int xsize, ysize, zsize; int xsize, ysize, zsize;
std::vector<double> values; std::vector<double> values;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2014 Stanford University and the Authors. * * Portions copyright (c) 2014-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -75,6 +75,15 @@ Continuous1DFunction* Continuous1DFunction::Copy() const { ...@@ -75,6 +75,15 @@ Continuous1DFunction* Continuous1DFunction::Copy() const {
return new Continuous1DFunction(new_vec, min, max); return new Continuous1DFunction(new_vec, min, max);
} }
bool Continuous1DFunction::operator==(const TabulatedFunction& other) const {
const Continuous1DFunction* fn = dynamic_cast<const Continuous1DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->min != min || fn->max != max)
return false;
return (fn->values == values);
}
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, bool periodic) { Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, bool periodic) {
this->periodic = periodic; this->periodic = periodic;
setFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax); setFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
...@@ -120,6 +129,19 @@ Continuous2DFunction* Continuous2DFunction::Copy() const { ...@@ -120,6 +129,19 @@ Continuous2DFunction* Continuous2DFunction::Copy() const {
return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax); return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax);
} }
bool Continuous2DFunction::operator==(const TabulatedFunction& other) const {
const Continuous2DFunction* fn = dynamic_cast<const Continuous2DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize)
return false;
if (fn->xmin != xmin || fn->xmax != xmax)
return false;
if (fn->ymin != ymin || fn->ymax != ymax)
return false;
return (fn->values == values);
}
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic) { Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic) {
this->periodic = periodic; this->periodic = periodic;
setFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax); setFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
...@@ -173,6 +195,20 @@ Continuous3DFunction* Continuous3DFunction::Copy() const { ...@@ -173,6 +195,20 @@ Continuous3DFunction* Continuous3DFunction::Copy() const {
return new Continuous3DFunction(xsize, ysize, zsize, new_vec, xmin, xmax, ymin, ymax, zmin, zmax); return new Continuous3DFunction(xsize, ysize, zsize, new_vec, xmin, xmax, ymin, ymax, zmin, zmax);
} }
bool Continuous3DFunction::operator==(const TabulatedFunction& other) const {
const Continuous3DFunction* fn = dynamic_cast<const Continuous3DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize || fn->zsize != zsize)
return false;
if (fn->xmin != xmin || fn->xmax != xmax)
return false;
if (fn->ymin != ymin || fn->ymax != ymax)
return false;
if (fn->zmin != zmin || fn->zmax != zmax)
return false;
return (fn->values == values);
}
Discrete1DFunction::Discrete1DFunction(const vector<double>& values) { Discrete1DFunction::Discrete1DFunction(const vector<double>& values) {
this->values = values; this->values = values;
...@@ -193,6 +229,13 @@ Discrete1DFunction* Discrete1DFunction::Copy() const { ...@@ -193,6 +229,13 @@ Discrete1DFunction* Discrete1DFunction::Copy() const {
return new Discrete1DFunction(new_vec); return new Discrete1DFunction(new_vec);
} }
bool Discrete1DFunction::operator==(const TabulatedFunction& other) const {
const Discrete1DFunction* fn = dynamic_cast<const Discrete1DFunction*>(&other);
if (fn == NULL)
return false;
return (fn->values == values);
}
Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) { Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) {
if (values.size() != xsize*ysize) if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values"); throw OpenMMException("Discrete2DFunction: incorrect number of values");
...@@ -222,6 +265,15 @@ Discrete2DFunction* Discrete2DFunction::Copy() const { ...@@ -222,6 +265,15 @@ Discrete2DFunction* Discrete2DFunction::Copy() const {
return new Discrete2DFunction(xsize, ysize, new_vec); return new Discrete2DFunction(xsize, ysize, new_vec);
} }
bool Discrete2DFunction::operator==(const TabulatedFunction& other) const {
const Discrete2DFunction* fn = dynamic_cast<const Discrete2DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize)
return false;
return (fn->values == values);
}
Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) { Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) {
if (values.size() != xsize*ysize*zsize) if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values"); throw OpenMMException("Discrete3DFunction: incorrect number of values");
...@@ -253,3 +305,12 @@ Discrete3DFunction* Discrete3DFunction::Copy() const { ...@@ -253,3 +305,12 @@ Discrete3DFunction* Discrete3DFunction::Copy() const {
new_vec[i] = values[i]; new_vec[i] = values[i];
return new Discrete3DFunction(xsize, ysize, zsize, new_vec); return new Discrete3DFunction(xsize, ysize, zsize, new_vec);
} }
bool Discrete3DFunction::operator==(const TabulatedFunction& other) const {
const Discrete3DFunction* fn = dynamic_cast<const Discrete3DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize || fn->zsize != zsize)
return false;
return (fn->values == values);
}
...@@ -529,7 +529,8 @@ private: ...@@ -529,7 +529,8 @@ private:
ComputeArray globals; ComputeArray globals;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
const System& system; const System& system;
}; };
...@@ -577,7 +578,8 @@ private: ...@@ -577,7 +578,8 @@ private:
ComputeArray groupForces, bondGroups, centerPositions; ComputeArray groupForces, bondGroups, centerPositions;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::vector<void*> groupForcesArgs; std::vector<void*> groupForcesArgs;
ComputeKernel computeCentersKernel, groupForcesKernel, applyForcesKernel; ComputeKernel computeCentersKernel, groupForcesKernel, applyForcesKernel;
const System& system; const System& system;
...@@ -628,7 +630,8 @@ private: ...@@ -628,7 +630,8 @@ private:
std::vector<void*> interactionGroupArgs; std::vector<void*> interactionGroupArgs;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
double longRangeCoefficient; double longRangeCoefficient;
std::vector<double> longRangeCoefficientDerivs; std::vector<double> longRangeCoefficientDerivs;
bool hasInitializedLongRangeCorrection, hasInitializedKernel, hasParamDerivs, useNeighborList; bool hasInitializedLongRangeCorrection, hasInitializedKernel, hasParamDerivs, useNeighborList;
...@@ -728,7 +731,8 @@ private: ...@@ -728,7 +731,8 @@ private:
ComputeArray longEnergyDerivs, globals, valueBuffers; ComputeArray longEnergyDerivs, globals, valueBuffers;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::vector<bool> pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue; std::vector<bool> pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue;
const System& system; const System& system;
ComputeKernel pairValueKernel, perParticleValueKernel, pairEnergyKernel, perParticleEnergyKernel, gradientChainRuleKernel; ComputeKernel pairValueKernel, perParticleValueKernel, pairEnergyKernel, perParticleEnergyKernel, gradientChainRuleKernel;
...@@ -785,7 +789,8 @@ private: ...@@ -785,7 +789,8 @@ private:
ComputeArray acceptorExclusions; ComputeArray acceptorExclusions;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
const System& system; const System& system;
ComputeKernel donorKernel, acceptorKernel; ComputeKernel donorKernel, acceptorKernel;
}; };
...@@ -836,7 +841,8 @@ private: ...@@ -836,7 +841,8 @@ private:
ComputeArray neighborPairs, numNeighborPairs, neighborStartIndex, numNeighborsForAtom, neighbors; ComputeArray neighborPairs, numNeighborPairs, neighborStartIndex, numNeighborsForAtom, neighbors;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions; std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
const System& system; const System& system;
ComputeKernel forceKernel, blockBoundsKernel, neighborsKernel, startIndicesKernel, copyPairsKernel; ComputeKernel forceKernel, blockBoundsKernel, neighborsKernel, startIndicesKernel, copyPairsKernel;
ComputeEvent event; ComputeEvent event;
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "openmm/internal/CustomCompoundBondForceImpl.h" #include "openmm/internal/CustomCompoundBondForceImpl.h"
#include "openmm/internal/CustomHbondForceImpl.h" #include "openmm/internal/CustomHbondForceImpl.h"
#include "openmm/internal/CustomManyParticleForceImpl.h" #include "openmm/internal/CustomManyParticleForceImpl.h"
#include "openmm/serialization/XmlSerializer.h"
#include "CommonKernelSources.h" #include "CommonKernelSources.h"
#include "lepton/CustomFunction.h" #include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h" #include "lepton/ExpressionTreeNode.h"
...@@ -1289,16 +1290,17 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -1289,16 +1290,17 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
tabulatedFunctions.resize(force.getNumTabulatedFunctions()); tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction"); tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f); tabulatedFunctionArrays[i].upload(f);
string arrayName = cc.getBondedUtilities().addArgument(tabulatedFunctions[i], width == 1 ? "float" : "float"+cc.intToString(width)); string arrayName = cc.getBondedUtilities().addArgument(tabulatedFunctionArrays[i], width == 1 ? "float" : "float"+cc.intToString(width));
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
} }
...@@ -1411,6 +1413,18 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp ...@@ -1411,6 +1413,18 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp
} }
params->setParameterValues(paramVector); params->setParameterValues(paramVector);
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid. // Mark that the current reordering may be invalid.
cc.invalidateMolecules(info); cc.invalidateMolecules(info);
...@@ -1535,17 +1549,18 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c ...@@ -1535,17 +1549,18 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
stringstream extraArgs; stringstream extraArgs;
tabulatedFunctions.resize(force.getNumTabulatedFunctions()); tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
string arrayName = "table"+cc.intToString(i); string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction"); tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f); tabulatedFunctionArrays[i].upload(f);
extraArgs << ", GLOBAL const float"; extraArgs << ", GLOBAL const float";
if (width > 1) if (width > 1)
extraArgs << width; extraArgs << width;
...@@ -1667,7 +1682,7 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c ...@@ -1667,7 +1682,7 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
groupForcesKernel->addArg(); // Periodic box information will be set just before it is executed. groupForcesKernel->addArg(); // Periodic box information will be set just before it is executed.
if (needEnergyParamDerivs) if (needEnergyParamDerivs)
groupForcesKernel->addArg(); // Deriv buffer hasn't been created yet. groupForcesKernel->addArg(); // Deriv buffer hasn't been created yet.
for (auto& function : tabulatedFunctions) for (auto& function : tabulatedFunctionArrays)
groupForcesKernel->addArg(function); groupForcesKernel->addArg(function);
if (globals.isInitialized()) if (globals.isInitialized())
groupForcesKernel->addArg(globals); groupForcesKernel->addArg(globals);
...@@ -1728,6 +1743,18 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp ...@@ -1728,6 +1743,18 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp
} }
params->setParameterValues(paramVector); params->setParameterValues(paramVector);
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid. // Mark that the current reordering may be invalid.
cc.invalidateMolecules(info); cc.invalidateMolecules(info);
...@@ -1868,18 +1895,19 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -1868,18 +1895,19 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
vector<string> tableTypes; vector<string> tableTypes;
tabulatedFunctions.resize(force.getNumTabulatedFunctions()); tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
string arrayName = prefix+"table"+cc.intToString(i); string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction"); tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f); tabulatedFunctionArrays[i].upload(f);
cc.getNonbondedUtilities().addArgument(ComputeParameterInfo(tabulatedFunctions[i], arrayName, "float", width)); cc.getNonbondedUtilities().addArgument(ComputeParameterInfo(tabulatedFunctionArrays[i], arrayName, "float", width));
if (width == 1) if (width == 1)
tableTypes.push_back("float"); tableTypes.push_back("float");
else else
...@@ -2166,7 +2194,7 @@ void CommonCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNon ...@@ -2166,7 +2194,7 @@ void CommonCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNon
stringstream args; stringstream args;
for (int i = 0; i < (int) buffers.size(); i++) for (int i = 0; i < (int) buffers.size(); i++)
args<<", GLOBAL const "<<buffers[i].getType()<<"* RESTRICT global_params"<<(i+1); args<<", GLOBAL const "<<buffers[i].getType()<<"* RESTRICT global_params"<<(i+1);
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctionArrays.size(); i++)
args << ", GLOBAL const " << tableTypes[i]<< "* RESTRICT table" << i; args << ", GLOBAL const " << tableTypes[i]<< "* RESTRICT table" << i;
if (globals.isInitialized()) if (globals.isInitialized())
args<<", GLOBAL const float* RESTRICT globals"; args<<", GLOBAL const float* RESTRICT globals";
...@@ -2289,7 +2317,7 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool ...@@ -2289,7 +2317,7 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool
interactionGroupKernel->addArg(); // Periodic box information will be set just before it is executed. interactionGroupKernel->addArg(); // Periodic box information will be set just before it is executed.
for (auto& parameter : params->getParameterInfos()) for (auto& parameter : params->getParameterInfos())
interactionGroupKernel->addArg(parameter.getArray()); interactionGroupKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctions) for (auto& function : tabulatedFunctionArrays)
interactionGroupKernel->addArg(function); interactionGroupKernel->addArg(function);
if (globals.isInitialized()) if (globals.isInitialized())
interactionGroupKernel->addArg(globals); interactionGroupKernel->addArg(globals);
...@@ -2352,6 +2380,18 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& ...@@ -2352,6 +2380,18 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
*forceCopy = force; *forceCopy = force;
} }
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid. // Mark that the current reordering may be invalid.
cc.invalidateMolecules(info); cc.invalidateMolecules(info);
...@@ -2679,18 +2719,19 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2679,18 +2719,19 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
tabulatedFunctions.resize(force.getNumTabulatedFunctions()); tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
string arrayName = prefix+"table"+cc.intToString(i); string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction"); tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f); tabulatedFunctionArrays[i].upload(f);
nb.addArgument(ComputeParameterInfo(tabulatedFunctions[i], arrayName, "float", width)); nb.addArgument(ComputeParameterInfo(tabulatedFunctionArrays[i], arrayName, "float", width));
tableArgs << ", GLOBAL const float"; tableArgs << ", GLOBAL const float";
if (width > 1) if (width > 1)
tableArgs << width; tableArgs << width;
...@@ -3510,7 +3551,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3510,7 +3551,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
} }
for (auto& d : dValue0dParam) for (auto& d : dValue0dParam)
pairValueKernel->addArg(d); pairValueKernel->addArg(d);
for (auto& function : tabulatedFunctions) for (auto& function : tabulatedFunctionArrays)
pairValueKernel->addArg(function); pairValueKernel->addArg(function);
perParticleValueKernel->addArg(cc.getPosq()); perParticleValueKernel->addArg(cc.getPosq());
perParticleValueKernel->addArg(valueBuffers); perParticleValueKernel->addArg(valueBuffers);
...@@ -3529,7 +3570,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3529,7 +3570,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
for (int j = 0; j < dValuedParam[i]->getParameterInfos().size(); j++) for (int j = 0; j < dValuedParam[i]->getParameterInfos().size(); j++)
perParticleValueKernel->addArg(dValuedParam[i]->getParameterInfos()[j].getArray()); perParticleValueKernel->addArg(dValuedParam[i]->getParameterInfos()[j].getArray());
} }
for (auto& function : tabulatedFunctions) for (auto& function : tabulatedFunctionArrays)
perParticleValueKernel->addArg(function); perParticleValueKernel->addArg(function);
pairEnergyKernel->addArg(useLong ? cc.getLongForceBuffer() : cc.getForceBuffers()); pairEnergyKernel->addArg(useLong ? cc.getLongForceBuffer() : cc.getForceBuffers());
pairEnergyKernel->addArg(cc.getEnergyBuffer()); pairEnergyKernel->addArg(cc.getEnergyBuffer());
...@@ -3570,7 +3611,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3570,7 +3611,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
pairEnergyKernel->addArg(buffer.getArray()); pairEnergyKernel->addArg(buffer.getArray());
if (needEnergyParamDerivs) if (needEnergyParamDerivs)
pairEnergyKernel->addArg(cc.getEnergyParamDerivBuffer()); pairEnergyKernel->addArg(cc.getEnergyParamDerivBuffer());
for (auto& function : tabulatedFunctions) for (auto& function : tabulatedFunctionArrays)
pairEnergyKernel->addArg(function); pairEnergyKernel->addArg(function);
perParticleEnergyKernel->addArg(cc.getEnergyBuffer()); perParticleEnergyKernel->addArg(cc.getEnergyBuffer());
perParticleEnergyKernel->addArg(cc.getPosq()); perParticleEnergyKernel->addArg(cc.getPosq());
...@@ -3595,7 +3636,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3595,7 +3636,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
perParticleEnergyKernel->addArg(longEnergyDerivs); perParticleEnergyKernel->addArg(longEnergyDerivs);
if (needEnergyParamDerivs) if (needEnergyParamDerivs)
perParticleEnergyKernel->addArg(cc.getEnergyParamDerivBuffer()); perParticleEnergyKernel->addArg(cc.getEnergyParamDerivBuffer());
for (auto& function : tabulatedFunctions) for (auto& function : tabulatedFunctionArrays)
perParticleEnergyKernel->addArg(function); perParticleEnergyKernel->addArg(function);
if (needParameterGradient || needEnergyParamDerivs) { if (needParameterGradient || needEnergyParamDerivs) {
gradientChainRuleKernel->addArg(cc.getPosq()); gradientChainRuleKernel->addArg(cc.getPosq());
...@@ -3614,7 +3655,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include ...@@ -3614,7 +3655,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
for (auto& buffer : d->getParameterInfos()) for (auto& buffer : d->getParameterInfos())
gradientChainRuleKernel->addArg(buffer.getArray()); gradientChainRuleKernel->addArg(buffer.getArray());
} }
for (auto& function : tabulatedFunctions) for (auto& function : tabulatedFunctionArrays)
gradientChainRuleKernel->addArg(function); gradientChainRuleKernel->addArg(function);
} }
} }
...@@ -3665,6 +3706,18 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context ...@@ -3665,6 +3706,18 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context
} }
params->setParameterValues(paramVector); params->setParameterValues(paramVector);
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid. // Mark that the current reordering may be invalid.
cc.invalidateMolecules(info); cc.invalidateMolecules(info);
...@@ -3880,17 +3933,18 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -3880,17 +3933,18 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
tabulatedFunctions.resize(force.getNumTabulatedFunctions()); tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
string arrayName = "table"+cc.intToString(i); string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction"); tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f); tabulatedFunctionArrays[i].upload(f);
tableArgs << ", GLOBAL const float"; tableArgs << ", GLOBAL const float";
if (width > 1) if (width > 1)
tableArgs << width; tableArgs << width;
...@@ -4132,7 +4186,7 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl ...@@ -4132,7 +4186,7 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl
donorKernel->addArg(parameter.getArray()); donorKernel->addArg(parameter.getArray());
for (auto& parameter : acceptorParams->getParameterInfos()) for (auto& parameter : acceptorParams->getParameterInfos())
donorKernel->addArg(parameter.getArray()); donorKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctions) for (auto& function : tabulatedFunctionArrays)
donorKernel->addArg(function); donorKernel->addArg(function);
if (cc.getSupports64BitGlobalAtomics()) if (cc.getSupports64BitGlobalAtomics())
acceptorKernel->addArg(cc.getLongForceBuffer()); acceptorKernel->addArg(cc.getLongForceBuffer());
...@@ -4153,7 +4207,7 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl ...@@ -4153,7 +4207,7 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl
acceptorKernel->addArg(parameter.getArray()); acceptorKernel->addArg(parameter.getArray());
for (auto& parameter : acceptorParams->getParameterInfos()) for (auto& parameter : acceptorParams->getParameterInfos())
acceptorKernel->addArg(parameter.getArray()); acceptorKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctions) for (auto& function : tabulatedFunctionArrays)
acceptorKernel->addArg(function); acceptorKernel->addArg(function);
} }
setPeriodicBoxArgs(cc, donorKernel, cc.getSupports64BitGlobalAtomics() ? 6 : 7); setPeriodicBoxArgs(cc, donorKernel, cc.getSupports64BitGlobalAtomics() ? 6 : 7);
...@@ -4203,6 +4257,18 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont ...@@ -4203,6 +4257,18 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont
acceptorParams->setParameterValues(acceptorParamVector); acceptorParams->setParameterValues(acceptorParamVector);
} }
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid. // Mark that the current reordering may be invalid.
cc.invalidateMolecules(info); cc.invalidateMolecules(info);
...@@ -4280,17 +4346,18 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c ...@@ -4280,17 +4346,18 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
stringstream tableArgs; stringstream tableArgs;
tabulatedFunctions.resize(force.getNumTabulatedFunctions()); tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
string arrayName = "table"+cc.intToString(i); string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i)); functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width; int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width); vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction"); tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f); tabulatedFunctionArrays[i].upload(f);
tableArgs << ", GLOBAL const float"; tableArgs << ", GLOBAL const float";
if (width > 1) if (width > 1)
tableArgs << width; tableArgs << width;
...@@ -4593,7 +4660,7 @@ double CommonCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bo ...@@ -4593,7 +4660,7 @@ double CommonCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bo
forceKernel->addArg(globals); forceKernel->addArg(globals);
for (auto& parameter : params->getParameterInfos()) for (auto& parameter : params->getParameterInfos())
forceKernel->addArg(parameter.getArray()); forceKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctions) for (auto& function : tabulatedFunctionArrays)
forceKernel->addArg(function); forceKernel->addArg(function);
if (nonbondedMethod != NoCutoff) { if (nonbondedMethod != NoCutoff) {
...@@ -4709,6 +4776,18 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp ...@@ -4709,6 +4776,18 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp
} }
params->setParameterValues(paramVector); params->setParameterValues(paramVector);
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid. // Mark that the current reordering may be invalid.
cc.invalidateMolecules(info); cc.invalidateMolecules(info);
......
...@@ -321,6 +321,7 @@ public: ...@@ -321,6 +321,7 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force); void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force);
private: private:
void createInteraction(const CustomNonbondedForce& force);
CpuPlatform::PlatformData& data; CpuPlatform::PlatformData& data;
int numParticles; int numParticles;
std::vector<std::vector<double> > particleParamArray; std::vector<std::vector<double> > particleParamArray;
...@@ -333,6 +334,7 @@ private: ...@@ -333,6 +334,7 @@ private:
std::vector<std::string> parameterNames, globalParameterNames, energyParamDerivNames; std::vector<std::string> parameterNames, globalParameterNames, energyParamDerivNames;
std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups; std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups;
std::vector<double> longRangeCoefficientDerivs; std::vector<double> longRangeCoefficientDerivs;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
CpuCustomNonbondedForce* nonbonded; CpuCustomNonbondedForce* nonbonded;
}; };
...@@ -410,6 +412,7 @@ public: ...@@ -410,6 +412,7 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const CustomGBForce& force); void copyParametersToContext(ContextImpl& context, const CustomGBForce& force);
private: private:
void createInteraction(const CustomGBForce& force);
CpuPlatform::PlatformData& data; CpuPlatform::PlatformData& data;
int numParticles; int numParticles;
bool isPeriodic; bool isPeriodic;
...@@ -421,6 +424,7 @@ private: ...@@ -421,6 +424,7 @@ private:
std::vector<std::string> particleParameterNames, globalParameterNames, energyParamDerivNames, valueNames; std::vector<std::string> particleParameterNames, globalParameterNames, energyParamDerivNames, valueNames;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes; std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes; std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
}; };
...@@ -463,6 +467,7 @@ private: ...@@ -463,6 +467,7 @@ private:
std::vector<std::vector<double> > particleParamArray; std::vector<std::vector<double> > particleParamArray;
CpuCustomManyParticleForce* ixn; CpuCustomManyParticleForce* ixn;
std::vector<std::string> globalParameterNames; std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
}; };
......
...@@ -45,6 +45,7 @@ ...@@ -45,6 +45,7 @@
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h" #include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/vectorize.h" #include "openmm/internal/vectorize.h"
#include "openmm/serialization/XmlSerializer.h"
#include "lepton/CompiledExpression.h" #include "lepton/CompiledExpression.h"
#include "lepton/CustomFunction.h" #include "lepton/CustomFunction.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
...@@ -868,7 +869,6 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C ...@@ -868,7 +869,6 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
// Build the arrays. // Build the arrays.
int numParameters = force.getNumPerParticleParameters();
particleParamArray.resize(numParticles); particleParamArray.resize(numParticles);
for (int i = 0; i < numParticles; ++i) for (int i = 0; i < numParticles; ++i)
force.getParticleParameters(i, particleParamArray[i]); force.getParticleParameters(i, particleParamArray[i]);
...@@ -882,10 +882,41 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C ...@@ -882,10 +882,41 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
switchingDistance = force.getSwitchingDistance(); switchingDistance = force.getSwitchingDistance();
} }
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Record information for the long range correction.
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
forceCopy = new CustomNonbondedForce(force);
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
// Record the interaction groups.
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
set<int> set1, set2;
force.getInteractionGroupParameters(i, set1, set2);
interactionGroups.push_back(make_pair(set1, set2));
}
data.isPeriodic |= (nonbondedMethod == CutoffPeriodic);
// Create the interaction.
createInteraction(force);
}
void CpuCalcCustomNonbondedForceKernel::createInteraction(const CustomNonbondedForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the various expressions used to calculate the force. // Parse the various expressions used to calculate the force.
...@@ -893,7 +924,7 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C ...@@ -893,7 +924,7 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize(); Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
Lepton::CompiledExpression energyExpression = expression.createCompiledExpression(); Lepton::CompiledExpression energyExpression = expression.createCompiledExpression();
Lepton::CompiledExpression forceExpression = expression.differentiate("r").createCompiledExpression(); Lepton::CompiledExpression forceExpression = expression.differentiate("r").createCompiledExpression();
for (int i = 0; i < numParameters; i++) for (int i = 0; i < force.getNumPerParticleParameters(); i++)
parameterNames.push_back(force.getPerParticleParameterName(i)); parameterNames.push_back(force.getPerParticleParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
...@@ -907,7 +938,7 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C ...@@ -907,7 +938,7 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
} }
set<string> variables; set<string> variables;
variables.insert("r"); variables.insert("r");
for (int i = 0; i < numParameters; i++) { for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
variables.insert(parameterNames[i]+"1"); variables.insert(parameterNames[i]+"1");
variables.insert(parameterNames[i]+"2"); variables.insert(parameterNames[i]+"2");
} }
...@@ -919,25 +950,8 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C ...@@ -919,25 +950,8 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
for (auto& function : functions) for (auto& function : functions)
delete function.second; delete function.second;
// Record information for the long range correction. // Create the object that computes the interaction.
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
forceCopy = new CustomNonbondedForce(force);
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
// Record the interaction groups.
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
set<int> set1, set2;
force.getInteractionGroupParameters(i, set1, set2);
interactionGroups.push_back(make_pair(set1, set2));
}
data.isPeriodic |= (nonbondedMethod == CutoffPeriodic);
nonbonded = new CpuCustomNonbondedForce(energyExpression, forceExpression, parameterNames, exclusions, energyParamDerivExpressions, data.threads); nonbonded = new CpuCustomNonbondedForce(energyExpression, forceExpression, parameterNames, exclusions, energyParamDerivExpressions, data.threads);
if (interactionGroups.size() > 0) if (interactionGroups.size() > 0)
nonbonded->setInteractionGroups(interactionGroups); nonbonded->setInteractionGroups(interactionGroups);
...@@ -1011,6 +1025,22 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con ...@@ -1011,6 +1025,22 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
hasInitializedLongRangeCorrection = true; hasInitializedLongRangeCorrection = true;
*forceCopy = force; *forceCopy = force;
} }
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete nonbonded;
nonbonded = NULL;
createInteraction(force);
}
} }
CpuCalcGBSAOBCForceKernel::~CpuCalcGBSAOBCForceKernel() { CpuCalcGBSAOBCForceKernel::~CpuCalcGBSAOBCForceKernel() {
...@@ -1101,11 +1131,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB ...@@ -1101,11 +1131,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
// Build the arrays. // Build the arrays.
int numPerParticleParameters = force.getNumPerParticleParameters();
particleParamArray.resize(numParticles); particleParamArray.resize(numParticles);
for (int i = 0; i < numParticles; ++i) for (int i = 0; i < numParticles; ++i)
force.getParticleParameters(i, particleParamArray[i]); force.getParticleParameters(i, particleParamArray[i]);
for (int i = 0; i < numPerParticleParameters; i++) for (int i = 0; i < force.getNumPerParticleParameters(); i++)
particleParameterNames.push_back(force.getPerParticleParameterName(i)); particleParameterNames.push_back(force.getPerParticleParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
...@@ -1113,15 +1142,30 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB ...@@ -1113,15 +1142,30 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
nonbondedCutoff = force.getCutoffDistance(); nonbondedCutoff = force.getCutoffDistance();
if (nonbondedMethod != NoCutoff) if (nonbondedMethod != NoCutoff)
neighborList = new CpuNeighborList(4); neighborList = new CpuNeighborList(4);
data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic);
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
createInteraction(force);
}
void CpuCalcCustomGBForceKernel::createInteraction(const CustomGBForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the expressions for computed values. // Parse the expressions for computed values.
valueTypes.clear();
valueNames.clear();
energyParamDerivNames.clear();
vector<vector<Lepton::CompiledExpression> > valueDerivExpressions(force.getNumComputedValues()); vector<vector<Lepton::CompiledExpression> > valueDerivExpressions(force.getNumComputedValues());
vector<vector<Lepton::CompiledExpression> > valueGradientExpressions(force.getNumComputedValues()); vector<vector<Lepton::CompiledExpression> > valueGradientExpressions(force.getNumComputedValues());
vector<vector<Lepton::CompiledExpression> > valueParamDerivExpressions(force.getNumComputedValues()); vector<vector<Lepton::CompiledExpression> > valueParamDerivExpressions(force.getNumComputedValues());
...@@ -1132,7 +1176,7 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB ...@@ -1132,7 +1176,7 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
particleVariables.insert("x"); particleVariables.insert("x");
particleVariables.insert("y"); particleVariables.insert("y");
particleVariables.insert("z"); particleVariables.insert("z");
for (int i = 0; i < numPerParticleParameters; i++) { for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
particleVariables.insert(particleParameterNames[i]); particleVariables.insert(particleParameterNames[i]);
pairVariables.insert(particleParameterNames[i]+"1"); pairVariables.insert(particleParameterNames[i]+"1");
pairVariables.insert(particleParameterNames[i]+"2"); pairVariables.insert(particleParameterNames[i]+"2");
...@@ -1171,6 +1215,7 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB ...@@ -1171,6 +1215,7 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
// Parse the expressions for energy terms. // Parse the expressions for energy terms.
energyTypes.clear();
vector<vector<Lepton::CompiledExpression> > energyDerivExpressions(force.getNumEnergyTerms()); vector<vector<Lepton::CompiledExpression> > energyDerivExpressions(force.getNumEnergyTerms());
vector<vector<Lepton::CompiledExpression> > energyGradientExpressions(force.getNumEnergyTerms()); vector<vector<Lepton::CompiledExpression> > energyGradientExpressions(force.getNumEnergyTerms());
vector<vector<Lepton::CompiledExpression> > energyParamDerivExpressions(force.getNumEnergyTerms()); vector<vector<Lepton::CompiledExpression> > energyParamDerivExpressions(force.getNumEnergyTerms());
...@@ -1208,7 +1253,6 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB ...@@ -1208,7 +1253,6 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
ixn = new CpuCustomGBForce(numParticles, exclusions, valueExpressions, valueDerivExpressions, valueGradientExpressions, valueParamDerivExpressions, ixn = new CpuCustomGBForce(numParticles, exclusions, valueExpressions, valueDerivExpressions, valueGradientExpressions, valueParamDerivExpressions,
valueNames, valueTypes, energyExpressions, energyDerivExpressions, energyGradientExpressions, energyParamDerivExpressions, energyTypes, valueNames, valueTypes, energyExpressions, energyDerivExpressions, energyGradientExpressions, energyParamDerivExpressions, energyTypes,
particleParameterNames, data.threads); particleParameterNames, data.threads);
data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic);
} }
double CpuCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CpuCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -1247,6 +1291,22 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c ...@@ -1247,6 +1291,22 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c
for (int j = 0; j < numParameters; j++) for (int j = 0; j < numParameters; j++)
particleParamArray[i][j] = static_cast<double>(parameters[j]); particleParamArray[i][j] = static_cast<double>(parameters[j]);
} }
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
createInteraction(force);
}
} }
CpuCalcCustomManyParticleForceKernel::~CpuCalcCustomManyParticleForceKernel() { CpuCalcCustomManyParticleForceKernel::~CpuCalcCustomManyParticleForceKernel() {
...@@ -1266,6 +1326,14 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons ...@@ -1266,6 +1326,14 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons
} }
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
ixn = new CpuCustomManyParticleForce(force, data.threads); ixn = new CpuCustomManyParticleForce(force, data.threads);
nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod()); nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod());
cutoffDistance = force.getCutoffDistance(); cutoffDistance = force.getCutoffDistance();
...@@ -1303,6 +1371,22 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl& ...@@ -1303,6 +1371,22 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl&
for (int j = 0; j < numParameters; j++) for (int j = 0; j < numParameters; j++)
particleParamArray[i][j] = static_cast<double>(parameters[j]); particleParamArray[i][j] = static_cast<double>(parameters[j]);
} }
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
ixn = new CpuCustomManyParticleForce(force, data.threads);
}
} }
CpuCalcGayBerneForceKernel::~CpuCalcGayBerneForceKernel() { CpuCalcGayBerneForceKernel::~CpuCalcGayBerneForceKernel() {
......
...@@ -682,6 +682,7 @@ public: ...@@ -682,6 +682,7 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force); void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force);
private: private:
void createExpressions(const CustomNonbondedForce& force);
int numParticles; int numParticles;
std::vector<std::vector<double> > particleParamArray; std::vector<std::vector<double> > particleParamArray;
double nonbondedCutoff, switchingDistance, periodicBoxSize[3], longRangeCoefficient; double nonbondedCutoff, switchingDistance, periodicBoxSize[3], longRangeCoefficient;
...@@ -695,6 +696,7 @@ private: ...@@ -695,6 +696,7 @@ private:
std::vector<std::string> parameterNames, globalParameterNames, energyParamDerivNames; std::vector<std::string> parameterNames, globalParameterNames, energyParamDerivNames;
std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups; std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups;
std::vector<double> longRangeCoefficientDerivs; std::vector<double> longRangeCoefficientDerivs;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
NeighborList* neighborList; NeighborList* neighborList;
}; };
...@@ -768,6 +770,7 @@ public: ...@@ -768,6 +770,7 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const CustomGBForce& force); void copyParametersToContext(ContextImpl& context, const CustomGBForce& force);
private: private:
void createExpressions(const CustomGBForce& force);
int numParticles; int numParticles;
bool isPeriodic; bool isPeriodic;
std::vector<std::vector<double> > particleParamArray; std::vector<std::vector<double> > particleParamArray;
...@@ -784,6 +787,7 @@ private: ...@@ -784,6 +787,7 @@ private:
std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions; std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions; std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes; std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
NeighborList* neighborList; NeighborList* neighborList;
}; };
...@@ -861,13 +865,16 @@ public: ...@@ -861,13 +865,16 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const CustomHbondForce& force); void copyParametersToContext(ContextImpl& context, const CustomHbondForce& force);
private: private:
void createInteraction(const CustomHbondForce& force);
int numDonors, numAcceptors, numParticles; int numDonors, numAcceptors, numParticles;
bool isPeriodic; bool isPeriodic;
std::vector<std::vector<int> > donorParticles, acceptorParticles;
std::vector<std::vector<double> > donorParamArray, acceptorParamArray; std::vector<std::vector<double> > donorParamArray, acceptorParamArray;
double nonbondedCutoff; double nonbondedCutoff;
ReferenceCustomHbondIxn* ixn; ReferenceCustomHbondIxn* ixn;
std::vector<std::set<int> > exclusions; std::vector<std::set<int> > exclusions;
std::vector<std::string> globalParameterNames; std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
}; };
/** /**
...@@ -902,10 +909,15 @@ public: ...@@ -902,10 +909,15 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const CustomCentroidBondForce& force); void copyParametersToContext(ContextImpl& context, const CustomCentroidBondForce& force);
private: private:
void createInteraction(const CustomCentroidBondForce& force);
int numBonds, numParticles; int numBonds, numParticles;
std::vector<std::vector<int> > bondGroups;
std::vector<std::vector<int> > groupAtoms;
std::vector<std::vector<double> > normalizedWeights;
std::vector<std::vector<double> > bondParamArray; std::vector<std::vector<double> > bondParamArray;
ReferenceCustomCentroidBondIxn* ixn; ReferenceCustomCentroidBondIxn* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames; std::vector<std::string> globalParameterNames, energyParamDerivNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
bool usePeriodic; bool usePeriodic;
Vec3* boxVectors; Vec3* boxVectors;
}; };
...@@ -942,10 +954,13 @@ public: ...@@ -942,10 +954,13 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const CustomCompoundBondForce& force); void copyParametersToContext(ContextImpl& context, const CustomCompoundBondForce& force);
private: private:
void createInteraction(const CustomCompoundBondForce& force);
int numBonds; int numBonds;
std::vector<std::vector<int> > bondParticles;
std::vector<std::vector<double> > bondParamArray; std::vector<std::vector<double> > bondParamArray;
ReferenceCustomCompoundBondIxn* ixn; ReferenceCustomCompoundBondIxn* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames; std::vector<std::string> globalParameterNames, energyParamDerivNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
bool usePeriodic; bool usePeriodic;
Vec3* boxVectors; Vec3* boxVectors;
}; };
...@@ -987,6 +1002,7 @@ private: ...@@ -987,6 +1002,7 @@ private:
std::vector<std::vector<double> > particleParamArray; std::vector<std::vector<double> > particleParamArray;
ReferenceCustomManyParticleIxn* ixn; ReferenceCustomManyParticleIxn* ixn;
std::vector<std::string> globalParameterNames; std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
}; };
......
...@@ -80,6 +80,7 @@ ...@@ -80,6 +80,7 @@
#include "openmm/internal/NonbondedForceImpl.h" #include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/Integrator.h" #include "openmm/Integrator.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/serialization/XmlSerializer.h"
#include "SimTKOpenMMUtilities.h" #include "SimTKOpenMMUtilities.h"
#include "lepton/CustomFunction.h" #include "lepton/CustomFunction.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
...@@ -1151,7 +1152,6 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c ...@@ -1151,7 +1152,6 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
// Build the arrays. // Build the arrays.
int numParameters = force.getNumPerParticleParameters();
particleParamArray.resize(numParticles); particleParamArray.resize(numParticles);
for (int i = 0; i < numParticles; ++i) for (int i = 0; i < numParticles; ++i)
force.getParticleParameters(i, particleParamArray[i]); force.getParticleParameters(i, particleParamArray[i]);
...@@ -1167,10 +1167,40 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c ...@@ -1167,10 +1167,40 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
switchingDistance = force.getSwitchingDistance(); switchingDistance = force.getSwitchingDistance();
} }
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the expressions.
createExpressions(force);
// Record information for the long range correction.
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
forceCopy = new CustomNonbondedForce(force);
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
// Record the interaction groups.
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
set<int> set1, set2;
force.getInteractionGroupParameters(i, set1, set2);
interactionGroups.push_back(make_pair(set1, set2));
}
}
void ReferenceCalcCustomNonbondedForceKernel::createExpressions(const CustomNonbondedForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the various expressions used to calculate the force. // Parse the various expressions used to calculate the force.
...@@ -1178,7 +1208,12 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c ...@@ -1178,7 +1208,12 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize(); Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
energyExpression = expression.createCompiledExpression(); energyExpression = expression.createCompiledExpression();
forceExpression = expression.differentiate("r").createCompiledExpression(); forceExpression = expression.differentiate("r").createCompiledExpression();
for (int i = 0; i < numParameters; i++) parameterNames.clear();
globalParameterNames.clear();
globalParamValues.clear();
energyParamDerivNames.clear();
energyParamDerivExpressions.clear();
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
parameterNames.push_back(force.getPerParticleParameterName(i)); parameterNames.push_back(force.getPerParticleParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++) { for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
...@@ -1191,7 +1226,7 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c ...@@ -1191,7 +1226,7 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
} }
set<string> variables; set<string> variables;
variables.insert("r"); variables.insert("r");
for (int i = 0; i < numParameters; i++) { for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
variables.insert(parameterNames[i]+"1"); variables.insert(parameterNames[i]+"1");
variables.insert(parameterNames[i]+"2"); variables.insert(parameterNames[i]+"2");
} }
...@@ -1202,25 +1237,6 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c ...@@ -1202,25 +1237,6 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
for (auto& function : functions) for (auto& function : functions)
delete function.second; delete function.second;
// Record information for the long range correction.
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
forceCopy = new CustomNonbondedForce(force);
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
// Record the interaction groups.
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
set<int> set1, set2;
force.getInteractionGroupParameters(i, set1, set2);
interactionGroups.push_back(make_pair(set1, set2));
}
} }
double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -1300,6 +1316,19 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp ...@@ -1300,6 +1316,19 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
hasInitializedLongRangeCorrection = true; hasInitializedLongRangeCorrection = true;
*forceCopy = force; *forceCopy = force;
} }
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed)
createExpressions(force);
} }
ReferenceCalcGBSAOBCForceKernel::~ReferenceCalcGBSAOBCForceKernel() { ReferenceCalcGBSAOBCForceKernel::~ReferenceCalcGBSAOBCForceKernel() {
...@@ -1395,11 +1424,10 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1395,11 +1424,10 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
// Build the arrays. // Build the arrays.
int numPerParticleParameters = force.getNumPerParticleParameters();
particleParamArray.resize(numParticles); particleParamArray.resize(numParticles);
for (int i = 0; i < numParticles; ++i) for (int i = 0; i < numParticles; ++i)
force.getParticleParameters(i, particleParamArray[i]); force.getParticleParameters(i, particleParamArray[i]);
for (int i = 0; i < numPerParticleParameters; i++) for (int i = 0; i < force.getNumPerParticleParameters(); i++)
particleParameterNames.push_back(force.getPerParticleParameterName(i)); particleParameterNames.push_back(force.getPerParticleParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
...@@ -1410,14 +1438,32 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1410,14 +1438,32 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
else else
neighborList = new NeighborList(); neighborList = new NeighborList();
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the expressions.
createExpressions(force);
}
void ReferenceCalcCustomGBForceKernel::createExpressions(const CustomGBForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the expressions for computed values. // Parse the expressions for computed values.
valueExpressions.clear();
valueTypes.clear();
valueNames.clear();
energyParamDerivNames.clear();
valueDerivExpressions.clear();
valueGradientExpressions.clear();
valueParamDerivExpressions.clear();
valueDerivExpressions.resize(force.getNumComputedValues()); valueDerivExpressions.resize(force.getNumComputedValues());
valueGradientExpressions.resize(force.getNumComputedValues()); valueGradientExpressions.resize(force.getNumComputedValues());
valueParamDerivExpressions.resize(force.getNumComputedValues()); valueParamDerivExpressions.resize(force.getNumComputedValues());
...@@ -1426,7 +1472,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1426,7 +1472,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
particleVariables.insert("x"); particleVariables.insert("x");
particleVariables.insert("y"); particleVariables.insert("y");
particleVariables.insert("z"); particleVariables.insert("z");
for (int i = 0; i < numPerParticleParameters; i++) { for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
particleVariables.insert(particleParameterNames[i]); particleVariables.insert(particleParameterNames[i]);
pairVariables.insert(particleParameterNames[i]+"1"); pairVariables.insert(particleParameterNames[i]+"1");
pairVariables.insert(particleParameterNames[i]+"2"); pairVariables.insert(particleParameterNames[i]+"2");
...@@ -1465,6 +1511,11 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1465,6 +1511,11 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
// Parse the expressions for energy terms. // Parse the expressions for energy terms.
energyExpressions.clear();
energyTypes.clear();
energyDerivExpressions.clear();
energyGradientExpressions.clear();
energyParamDerivExpressions.clear();
energyDerivExpressions.resize(force.getNumEnergyTerms()); energyDerivExpressions.resize(force.getNumEnergyTerms());
energyGradientExpressions.resize(force.getNumEnergyTerms()); energyGradientExpressions.resize(force.getNumEnergyTerms());
energyParamDerivExpressions.resize(force.getNumEnergyTerms()); energyParamDerivExpressions.resize(force.getNumEnergyTerms());
...@@ -1540,6 +1591,19 @@ void ReferenceCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& cont ...@@ -1540,6 +1591,19 @@ void ReferenceCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& cont
for (int j = 0; j < numParameters; j++) for (int j = 0; j < numParameters; j++)
particleParamArray[i][j] = parameters[j]; particleParamArray[i][j] = parameters[j];
} }
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed)
createExpressions(force);
} }
ReferenceCalcCustomExternalForceKernel::~ReferenceCalcCustomExternalForceKernel() { ReferenceCalcCustomExternalForceKernel::~ReferenceCalcCustomExternalForceKernel() {
...@@ -1637,8 +1701,7 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const ...@@ -1637,8 +1701,7 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
// Build the arrays. // Build the arrays.
vector<vector<int> > donorParticles(numDonors); donorParticles.resize(numDonors);
int numDonorParameters = force.getNumPerDonorParameters();
donorParamArray.resize(numDonors); donorParamArray.resize(numDonors);
for (int i = 0; i < numDonors; ++i) { for (int i = 0; i < numDonors; ++i) {
int d1, d2, d3; int d1, d2, d3;
...@@ -1647,8 +1710,7 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const ...@@ -1647,8 +1710,7 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
donorParticles[i].push_back(d2); donorParticles[i].push_back(d2);
donorParticles[i].push_back(d3); donorParticles[i].push_back(d3);
} }
vector<vector<int> > acceptorParticles(numAcceptors); acceptorParticles.resize(numAcceptors);
int numAcceptorParameters = force.getNumPerAcceptorParameters();
acceptorParamArray.resize(numAcceptors); acceptorParamArray.resize(numAcceptors);
for (int i = 0; i < numAcceptors; ++i) { for (int i = 0; i < numAcceptors; ++i) {
int a1, a2, a3; int a1, a2, a3;
...@@ -1657,13 +1719,25 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const ...@@ -1657,13 +1719,25 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
acceptorParticles[i].push_back(a2); acceptorParticles[i].push_back(a2);
acceptorParticles[i].push_back(a3); acceptorParticles[i].push_back(a3);
} }
NonbondedMethod nonbondedMethod = CalcCustomHbondForceKernel::NonbondedMethod(force.getNonbondedMethod()); for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
nonbondedCutoff = force.getCutoffDistance(); nonbondedCutoff = force.getCutoffDistance();
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
createInteraction(force);
}
void ReferenceCalcCustomHbondForceKernel::createInteraction(const CustomHbondForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the expression and create the object used to calculate the interaction. // Parse the expression and create the object used to calculate the interaction.
...@@ -1674,13 +1748,12 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const ...@@ -1674,13 +1748,12 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
Lepton::ParsedExpression energyExpression = CustomHbondForceImpl::prepareExpression(force, functions, distances, angles, dihedrals); Lepton::ParsedExpression energyExpression = CustomHbondForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
vector<string> donorParameterNames; vector<string> donorParameterNames;
vector<string> acceptorParameterNames; vector<string> acceptorParameterNames;
for (int i = 0; i < numDonorParameters; i++) for (int i = 0; i < force.getNumPerDonorParameters(); i++)
donorParameterNames.push_back(force.getPerDonorParameterName(i)); donorParameterNames.push_back(force.getPerDonorParameterName(i));
for (int i = 0; i < numAcceptorParameters; i++) for (int i = 0; i < force.getNumPerAcceptorParameters(); i++)
acceptorParameterNames.push_back(force.getPerAcceptorParameterName(i)); acceptorParameterNames.push_back(force.getPerAcceptorParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
ixn = new ReferenceCustomHbondIxn(donorParticles, acceptorParticles, energyExpression, donorParameterNames, acceptorParameterNames, distances, angles, dihedrals); ixn = new ReferenceCustomHbondIxn(donorParticles, acceptorParticles, energyExpression, donorParameterNames, acceptorParameterNames, distances, angles, dihedrals);
NonbondedMethod nonbondedMethod = CalcCustomHbondForceKernel::NonbondedMethod(force.getNonbondedMethod());
isPeriodic = (nonbondedMethod == CutoffPeriodic); isPeriodic = (nonbondedMethod == CutoffPeriodic);
if (nonbondedMethod != NoCutoff) if (nonbondedMethod != NoCutoff)
ixn->setUseCutoff(nonbondedCutoff); ixn->setUseCutoff(nonbondedCutoff);
...@@ -1733,6 +1806,22 @@ void ReferenceCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& c ...@@ -1733,6 +1806,22 @@ void ReferenceCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& c
for (int j = 0; j < numAcceptorParameters; j++) for (int j = 0; j < numAcceptorParameters; j++)
acceptorParamArray[i][j] = parameters[j]; acceptorParamArray[i][j] = parameters[j];
} }
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
createInteraction(force);
}
} }
ReferenceCalcCustomCentroidBondForceKernel::~ReferenceCalcCustomCentroidBondForceKernel() { ReferenceCalcCustomCentroidBondForceKernel::~ReferenceCalcCustomCentroidBondForceKernel() {
...@@ -1746,23 +1835,32 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system ...@@ -1746,23 +1835,32 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system
// Build the arrays. // Build the arrays.
int numGroups = force.getNumGroups(); int numGroups = force.getNumGroups();
vector<vector<int> > groupAtoms(numGroups); groupAtoms.resize(numGroups);
vector<double> ignored; vector<double> ignored;
for (int i = 0; i < numGroups; i++) for (int i = 0; i < numGroups; i++)
force.getGroupParameters(i, groupAtoms[i], ignored); force.getGroupParameters(i, groupAtoms[i], ignored);
vector<vector<double> > normalizedWeights;
CustomCentroidBondForceImpl::computeNormalizedWeights(force, system, normalizedWeights); CustomCentroidBondForceImpl::computeNormalizedWeights(force, system, normalizedWeights);
numBonds = force.getNumBonds(); numBonds = force.getNumBonds();
vector<vector<int> > bondGroups(numBonds); bondGroups.resize(numBonds);
int numBondParameters = force.getNumPerBondParameters();
bondParamArray.resize(numBonds); bondParamArray.resize(numBonds);
for (int i = 0; i < numBonds; ++i) for (int i = 0; i < numBonds; ++i)
force.getBondParameters(i, bondGroups[i], bondParamArray[i]); force.getBondParameters(i, bondGroups[i], bondParamArray[i]);
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
createInteraction(force);
}
void ReferenceCalcCustomCentroidBondForceKernel::createInteraction(const CustomCentroidBondForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create implementations of point functions. // Create implementations of point functions.
...@@ -1773,9 +1871,10 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system ...@@ -1773,9 +1871,10 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system
// Parse the expression and create the object used to calculate the interaction. // Parse the expression and create the object used to calculate the interaction.
int numGroups = force.getNumGroups();
Lepton::ParsedExpression energyExpression = CustomCentroidBondForceImpl::prepareExpression(force, functions); Lepton::ParsedExpression energyExpression = CustomCentroidBondForceImpl::prepareExpression(force, functions);
vector<string> bondParameterNames; vector<string> bondParameterNames;
for (int i = 0; i < numBondParameters; i++) for (int i = 0; i < force.getNumPerBondParameters(); i++)
bondParameterNames.push_back(force.getPerBondParameterName(i)); bondParameterNames.push_back(force.getPerBondParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
...@@ -1830,6 +1929,22 @@ void ReferenceCalcCustomCentroidBondForceKernel::copyParametersToContext(Context ...@@ -1830,6 +1929,22 @@ void ReferenceCalcCustomCentroidBondForceKernel::copyParametersToContext(Context
for (int j = 0; j < numParameters; j++) for (int j = 0; j < numParameters; j++)
bondParamArray[i][j] = params[j]; bondParamArray[i][j] = params[j];
} }
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
createInteraction(force);
}
} }
ReferenceCalcCustomCompoundBondForceKernel::~ReferenceCalcCustomCompoundBondForceKernel() { ReferenceCalcCustomCompoundBondForceKernel::~ReferenceCalcCustomCompoundBondForceKernel() {
...@@ -1843,16 +1958,26 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system ...@@ -1843,16 +1958,26 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system
// Build the arrays. // Build the arrays.
numBonds = force.getNumBonds(); numBonds = force.getNumBonds();
vector<vector<int> > bondParticles(numBonds); bondParticles.resize(numBonds);
int numBondParameters = force.getNumPerBondParameters();
bondParamArray.resize(numBonds); bondParamArray.resize(numBonds);
for (int i = 0; i < numBonds; ++i) for (int i = 0; i < numBonds; ++i)
force.getBondParameters(i, bondParticles[i], bondParamArray[i]); force.getBondParameters(i, bondParticles[i], bondParamArray[i]);
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
createInteraction(force);
}
void ReferenceCalcCustomCompoundBondForceKernel::createInteraction(const CustomCompoundBondForce& force) {
// Create custom functions for the tabulated functions. // Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i)); functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create implementations of point functions. // Create implementations of point functions.
...@@ -1865,7 +1990,7 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system ...@@ -1865,7 +1990,7 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system
Lepton::ParsedExpression energyExpression = CustomCompoundBondForceImpl::prepareExpression(force, functions); Lepton::ParsedExpression energyExpression = CustomCompoundBondForceImpl::prepareExpression(force, functions);
vector<string> bondParameterNames; vector<string> bondParameterNames;
for (int i = 0; i < numBondParameters; i++) for (int i = 0; i < force.getNumPerBondParameters(); i++)
bondParameterNames.push_back(force.getPerBondParameterName(i)); bondParameterNames.push_back(force.getPerBondParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
...@@ -1920,6 +2045,22 @@ void ReferenceCalcCustomCompoundBondForceKernel::copyParametersToContext(Context ...@@ -1920,6 +2045,22 @@ void ReferenceCalcCustomCompoundBondForceKernel::copyParametersToContext(Context
for (int j = 0; j < numParameters; j++) for (int j = 0; j < numParameters; j++)
bondParamArray[i][j] = params[j]; bondParamArray[i][j] = params[j];
} }
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
createInteraction(force);
}
} }
ReferenceCalcCustomManyParticleForceKernel::~ReferenceCalcCustomManyParticleForceKernel() { ReferenceCalcCustomManyParticleForceKernel::~ReferenceCalcCustomManyParticleForceKernel() {
...@@ -1928,7 +2069,6 @@ ReferenceCalcCustomManyParticleForceKernel::~ReferenceCalcCustomManyParticleForc ...@@ -1928,7 +2069,6 @@ ReferenceCalcCustomManyParticleForceKernel::~ReferenceCalcCustomManyParticleForc
} }
void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system, const CustomManyParticleForce& force) { void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system, const CustomManyParticleForce& force) {
// Build the arrays. // Build the arrays.
numParticles = system.getNumParticles(); numParticles = system.getNumParticles();
...@@ -1939,6 +2079,14 @@ void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system ...@@ -1939,6 +2079,14 @@ void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system
} }
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
ixn = new ReferenceCustomManyParticleIxn(force); ixn = new ReferenceCustomManyParticleIxn(force);
nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod()); nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod());
cutoffDistance = force.getCutoffDistance(); cutoffDistance = force.getCutoffDistance();
...@@ -1977,6 +2125,22 @@ void ReferenceCalcCustomManyParticleForceKernel::copyParametersToContext(Context ...@@ -1977,6 +2125,22 @@ void ReferenceCalcCustomManyParticleForceKernel::copyParametersToContext(Context
for (int j = 0; j < numParameters; j++) for (int j = 0; j < numParameters; j++)
particleParamArray[i][j] = parameters[j]; particleParamArray[i][j] = parameters[j];
} }
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
ixn = new ReferenceCustomManyParticleIxn(force);
}
} }
ReferenceCalcGayBerneForceKernel::~ReferenceCalcGayBerneForceKernel() { ReferenceCalcGayBerneForceKernel::~ReferenceCalcGayBerneForceKernel() {
......
...@@ -205,6 +205,18 @@ void testComplexFunction(bool byGroups) { ...@@ -205,6 +205,18 @@ void testComplexFunction(bool byGroups) {
for (int i = 0; i < numParticles; i++) for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], TOL); ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], TOL);
} }
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(compound->getTabulatedFunction(0)).setFunctionParameters(table, -1, 10);
dynamic_cast<Continuous1DFunction&>(centroid->getTabulatedFunction(0)).setFunctionParameters(table, -1, 10);
compound->updateParametersInContext(context);
centroid->updateParametersInContext(context);
State state1 = context.getState(State::Energy, false, 1<<0);
State state2 = context.getState(State::Energy, false, 1<<1);
ASSERT_EQUAL_TOL(state1.getPotentialEnergy(), state2.getPotentialEnergy(), TOL);
} }
void testCustomWeights() { void testCustomWeights() {
......
...@@ -212,6 +212,31 @@ void testContinuous2DFunction() { ...@@ -212,6 +212,31 @@ void testContinuous2DFunction() {
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
} }
} }
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
Continuous2DFunction& fn = dynamic_cast<Continuous2DFunction&>(forceField->getTabulatedFunction(0));
fn.setFunctionParameters(xsize, ysize, table, xmin, xmax, ymin, ymax);
forceField->updateParametersInContext(context);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[0] = Vec3(x, y, 1.5);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = 0.5*sin(0.25*x)*cos(0.33*y)+1;
force[0] = 0.5*(-0.25*cos(0.25*x)*cos(0.33*y));
force[1] = 0.5*0.3*sin(0.25*x)*sin(0.33*y);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
} }
void testContinuous3DFunction() { void testContinuous3DFunction() {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -254,7 +254,7 @@ void testMembrane() { ...@@ -254,7 +254,7 @@ void testMembrane() {
double norm = 0.0; double norm = 0.0;
for (int i = 0; i < (int) forces.size(); ++i) for (int i = 0; i < (int) forces.size(); ++i)
norm += forces[i].dot(forces[i]); norm += forces[i].dot(forces[i]);
norm = std::sqrt(norm); norm = sqrt(norm);
const double stepSize = 1e-2; const double stepSize = 1e-2;
double step = 0.5*stepSize/norm; double step = 0.5*stepSize/norm;
vector<Vec3> positions2(numParticles), positions3(numParticles); vector<Vec3> positions2(numParticles), positions3(numParticles);
...@@ -283,7 +283,7 @@ void testTabulatedFunction() { ...@@ -283,7 +283,7 @@ void testTabulatedFunction() {
force->addParticle(vector<double>()); force->addParticle(vector<double>());
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(sin(0.25*i));
force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0)); force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(force); system.addForce(force);
Context context(system, integrator, platform); Context context(system, integrator, platform);
...@@ -296,8 +296,8 @@ void testTabulatedFunction() { ...@@ -296,8 +296,8 @@ void testTabulatedFunction() {
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy); State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces(); const vector<Vec3>& forces = state.getForces();
double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0)); double force = (x < 1.0 || x > 6.0 ? 0.0 : -cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1); ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
...@@ -308,7 +308,22 @@ void testTabulatedFunction() { ...@@ -308,7 +308,22 @@ void testTabulatedFunction() {
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Energy); State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(force->getTabulatedFunction(0)).setFunctionParameters(table, 1.0, 6.0);
force->updateParametersInContext(context);
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : 0.5*sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
} }
} }
...@@ -385,7 +400,7 @@ void testPositionDependence() { ...@@ -385,7 +400,7 @@ void testPositionDependence() {
double norm = 0.0; double norm = 0.0;
for (int i = 0; i < (int) forces.size(); ++i) for (int i = 0; i < (int) forces.size(); ++i)
norm += forces[i].dot(forces[i]); norm += forces[i].dot(forces[i]);
norm = std::sqrt(norm); norm = sqrt(norm);
const double stepSize = 1e-3; const double stepSize = 1e-3;
double step = 0.5*stepSize/norm; double step = 0.5*stepSize/norm;
vector<Vec3> positions2(2), positions3(2); vector<Vec3> positions2(2), positions3(2);
...@@ -455,7 +470,7 @@ void testExclusions() { ...@@ -455,7 +470,7 @@ void testExclusions() {
double norm = 0.0; double norm = 0.0;
for (int i = 0; i < (int) forces.size(); ++i) for (int i = 0; i < (int) forces.size(); ++i)
norm += forces[i].dot(forces[i]); norm += forces[i].dot(forces[i]);
norm = std::sqrt(norm); norm = sqrt(norm);
if (norm > 0) { if (norm > 0) {
const double stepSize = 1e-3; const double stepSize = 1e-3;
double step = stepSize/norm; double step = stepSize/norm;
......
...@@ -223,6 +223,15 @@ void testCustomFunctions() { ...@@ -223,6 +223,15 @@ void testCustomFunctions() {
ASSERT_EQUAL_VEC(Vec3(0, -0.1, 0), forces[1], TOL); ASSERT_EQUAL_VEC(Vec3(0, -0.1, 0), forces[1], TOL);
ASSERT_EQUAL_VEC(Vec3(-0.1, 0, 0), forces[2], TOL); ASSERT_EQUAL_VEC(Vec3(-0.1, 0, 0), forces[2], TOL);
ASSERT_EQUAL_TOL(0.1*2+0.1*2, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(0.1*2+0.1*2, state.getPotentialEnergy(), TOL);
// Try updating the tabulated function.
for (int i = 0; i < function.size(); i++)
function[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(custom->getTabulatedFunction(0)).setFunctionParameters(function, 0, 10);
custom->updateParametersInContext(context);
state = context.getState(State::Energy);
ASSERT_EQUAL_TOL(0.5*(0.1*2+0.1*2), state.getPotentialEnergy(), TOL);
} }
void test2DFunction() { void test2DFunction() {
......
...@@ -516,6 +516,15 @@ void testTabulatedFunctions() { ...@@ -516,6 +516,15 @@ void testTabulatedFunctions() {
expectedEnergy += 0.5*(r12+r13+r23)*(c[i]+c[j]+c[k]); expectedEnergy += 0.5*(r12+r13+r23)*(c[i]+c[j]+c[k]);
} }
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5); ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
// Try updating the tabulated function.
for (int i = 0; i < values.size(); i++)
values[i] *= 0.5;
dynamic_cast<Discrete3DFunction&>(force->getTabulatedFunction(1)).setFunctionParameters(numParticles, numParticles, numParticles, values);
force->updateParametersInContext(context);
state = context.getState(State::Energy);
ASSERT_EQUAL_TOL(0.5*expectedEnergy, state.getPotentialEnergy(), 1e-5);
} }
void testTypeFilters() { void testTypeFilters() {
......
...@@ -355,6 +355,21 @@ void testContinuous1DFunction() { ...@@ -355,6 +355,21 @@ void testContinuous1DFunction() {
double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
} }
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(forceField->getTabulatedFunction(0)).setFunctionParameters(table, 1.0, 6.0);
forceField->updateParametersInContext(context);
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : 0.5*sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
} }
void testPeriodicContinuous1DFunction() { void testPeriodicContinuous1DFunction() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment