Unverified Commit 27bcb657 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

updateParametersInContext() can change tabulated functions (#3307)

* updateParametersInContext() can change tabulated functions

* Fixed error in building C wrappers

* updateParametersInContext() can change tabulated functions for CustomCentroidBondForce

* CustomNonbondedForce can update tabulated functions

* CustomGBForce can update tabulated functions

* CustomManyParticleForce can update tabulated functions

* CustomHbondForce can update tabulated functions
parent d83c2724
......@@ -358,15 +358,16 @@ public:
*/
const std::string& getTabulatedFunctionName(int index) const;
/**
* Update the per-bond parameters in a Context to match those stored in this Force object. This method provides
* Update the per-bond parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setBondParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-bond parameters.
* All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* This method has several limitations. The only information it updates is the values of per-bond parameters and tabulated
* functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. Neither the definitions of groups nor the set of groups involved in a bond can be changed, nor can new
* bonds be added.
* bonds be added. Also, while the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -338,14 +338,15 @@ public:
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Update the per-bond parameters in a Context to match those stored in this Force object. This method provides
* Update the per-bond parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setBondParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-bond parameters.
* All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. The set of particles involved in a bond cannot be changed, nor can new bonds be added.
* This method has several limitations. The only information it updates is the values of per-bond parameters and tabulated
* functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. The set of particles involved in a bond cannot be changed, nor can new bonds be added. Also, while the
* tabulated values of a function can change, everything else about it (its dimensions, the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -544,14 +544,15 @@ public:
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Update the per-particle parameters in a Context to match those stored in this Force object. This method provides
* Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-particle parameters.
* All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. Also, this method cannot be used to add new particles, only to change the parameters of existing ones.
* This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated
* functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. Also, this method cannot be used to add new particles, only to change the parameters of existing ones. While
* the tabulated values of a function can change, everything else about it (its dimensions, the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -443,15 +443,16 @@ public:
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Update the per-donor and per-acceptor parameters in a Context to match those stored in this Force object. This method
* Update the per-donor and per-acceptor parameters and tabulated functions in a Context to match those stored in this Force object. This method
* provides an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setDonorParameters() and setAcceptorParameters() to modify this object's parameters, then call
* updateParametersInContext() to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-donor and per-acceptor parameters.
* All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can only
* This method has several limitations. The only information it updates is the values of per-donor and per-acceptor parameters and tabulated
* functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can only
* be changed by reinitializing the Context. The set of particles involved in a donor or acceptor cannot be changed, nor can
* new donors or acceptors be added.
* new donors or acceptors be added. While the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -480,15 +480,16 @@ public:
*/
const std::string& getTabulatedFunctionName(int index) const;
/**
* Update the per-particle parameters in a Context to match those stored in this Force object. This method provides
* Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-particle parameters.
* All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can
* This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated
* functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can
* only be changed by reinitializing the Context. Also, this method cannot be used to add new particles, only to change
* the parameters of existing ones.
* the parameters of existing ones. While the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -490,15 +490,16 @@ public:
*/
void setInteractionGroupParameters(int index, const std::set<int>& set1, const std::set<int>& set2);
/**
* Update the per-particle parameters in a Context to match those stored in this Force object. This method provides
* Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-particle parameters.
* All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can
* This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated
* functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can
* only be changed by reinitializing the Context. Also, this method cannot be used to add new particles, only to change
* the parameters of existing ones.
* the parameters of existing ones. While the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -65,9 +65,12 @@ public:
virtual TabulatedFunction* Copy() const = 0;
/**
* Get the periodicity status of the tabulated function.
*
*/
bool getPeriodic() const;
virtual bool operator==(const TabulatedFunction& other) const = 0;
virtual bool operator!=(const TabulatedFunction& other) const {
return !(*this == other);
}
protected:
bool periodic;
};
......@@ -114,6 +117,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Continuous1DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
std::vector<double> values;
double min, max;
......@@ -176,6 +180,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Continuous2DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
std::vector<double> values;
int xsize, ysize;
......@@ -254,6 +259,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Continuous3DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
std::vector<double> values;
int xsize, ysize, zsize;
......@@ -291,6 +297,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Discrete1DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
std::vector<double> values;
};
......@@ -335,6 +342,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Discrete2DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
int xsize, ysize;
std::vector<double> values;
......@@ -383,6 +391,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Discrete3DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
int xsize, ysize, zsize;
std::vector<double> values;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014 Stanford University and the Authors. *
* Portions copyright (c) 2014-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -75,6 +75,15 @@ Continuous1DFunction* Continuous1DFunction::Copy() const {
return new Continuous1DFunction(new_vec, min, max);
}
bool Continuous1DFunction::operator==(const TabulatedFunction& other) const {
const Continuous1DFunction* fn = dynamic_cast<const Continuous1DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->min != min || fn->max != max)
return false;
return (fn->values == values);
}
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, bool periodic) {
this->periodic = periodic;
setFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
......@@ -120,6 +129,19 @@ Continuous2DFunction* Continuous2DFunction::Copy() const {
return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax);
}
bool Continuous2DFunction::operator==(const TabulatedFunction& other) const {
const Continuous2DFunction* fn = dynamic_cast<const Continuous2DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize)
return false;
if (fn->xmin != xmin || fn->xmax != xmax)
return false;
if (fn->ymin != ymin || fn->ymax != ymax)
return false;
return (fn->values == values);
}
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic) {
this->periodic = periodic;
setFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
......@@ -173,6 +195,20 @@ Continuous3DFunction* Continuous3DFunction::Copy() const {
return new Continuous3DFunction(xsize, ysize, zsize, new_vec, xmin, xmax, ymin, ymax, zmin, zmax);
}
bool Continuous3DFunction::operator==(const TabulatedFunction& other) const {
const Continuous3DFunction* fn = dynamic_cast<const Continuous3DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize || fn->zsize != zsize)
return false;
if (fn->xmin != xmin || fn->xmax != xmax)
return false;
if (fn->ymin != ymin || fn->ymax != ymax)
return false;
if (fn->zmin != zmin || fn->zmax != zmax)
return false;
return (fn->values == values);
}
Discrete1DFunction::Discrete1DFunction(const vector<double>& values) {
this->values = values;
......@@ -193,6 +229,13 @@ Discrete1DFunction* Discrete1DFunction::Copy() const {
return new Discrete1DFunction(new_vec);
}
bool Discrete1DFunction::operator==(const TabulatedFunction& other) const {
const Discrete1DFunction* fn = dynamic_cast<const Discrete1DFunction*>(&other);
if (fn == NULL)
return false;
return (fn->values == values);
}
Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) {
if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values");
......@@ -222,6 +265,15 @@ Discrete2DFunction* Discrete2DFunction::Copy() const {
return new Discrete2DFunction(xsize, ysize, new_vec);
}
bool Discrete2DFunction::operator==(const TabulatedFunction& other) const {
const Discrete2DFunction* fn = dynamic_cast<const Discrete2DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize)
return false;
return (fn->values == values);
}
Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) {
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values");
......@@ -253,3 +305,12 @@ Discrete3DFunction* Discrete3DFunction::Copy() const {
new_vec[i] = values[i];
return new Discrete3DFunction(xsize, ysize, zsize, new_vec);
}
bool Discrete3DFunction::operator==(const TabulatedFunction& other) const {
const Discrete3DFunction* fn = dynamic_cast<const Discrete3DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize || fn->zsize != zsize)
return false;
return (fn->values == values);
}
......@@ -529,7 +529,8 @@ private:
ComputeArray globals;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
const System& system;
};
......@@ -577,7 +578,8 @@ private:
ComputeArray groupForces, bondGroups, centerPositions;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::vector<void*> groupForcesArgs;
ComputeKernel computeCentersKernel, groupForcesKernel, applyForcesKernel;
const System& system;
......@@ -628,7 +630,8 @@ private:
std::vector<void*> interactionGroupArgs;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
double longRangeCoefficient;
std::vector<double> longRangeCoefficientDerivs;
bool hasInitializedLongRangeCorrection, hasInitializedKernel, hasParamDerivs, useNeighborList;
......@@ -728,7 +731,8 @@ private:
ComputeArray longEnergyDerivs, globals, valueBuffers;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::vector<bool> pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue;
const System& system;
ComputeKernel pairValueKernel, perParticleValueKernel, pairEnergyKernel, perParticleEnergyKernel, gradientChainRuleKernel;
......@@ -785,7 +789,8 @@ private:
ComputeArray acceptorExclusions;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
const System& system;
ComputeKernel donorKernel, acceptorKernel;
};
......@@ -836,7 +841,8 @@ private:
ComputeArray neighborPairs, numNeighborPairs, neighborStartIndex, numNeighborsForAtom, neighbors;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
const System& system;
ComputeKernel forceKernel, blockBoundsKernel, neighborsKernel, startIndicesKernel, copyPairsKernel;
ComputeEvent event;
......
......@@ -35,6 +35,7 @@
#include "openmm/internal/CustomCompoundBondForceImpl.h"
#include "openmm/internal/CustomHbondForceImpl.h"
#include "openmm/internal/CustomManyParticleForceImpl.h"
#include "openmm/serialization/XmlSerializer.h"
#include "CommonKernelSources.h"
#include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h"
......@@ -1289,16 +1290,17 @@ void CommonCalcCustomCompoundBondForceKernel::initialize(const System& system, c
map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
tabulatedFunctions.resize(force.getNumTabulatedFunctions());
tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f);
string arrayName = cc.getBondedUtilities().addArgument(tabulatedFunctions[i], width == 1 ? "float" : "float"+cc.intToString(width));
tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctionArrays[i].upload(f);
string arrayName = cc.getBondedUtilities().addArgument(tabulatedFunctionArrays[i], width == 1 ? "float" : "float"+cc.intToString(width));
functionDefinitions.push_back(make_pair(name, arrayName));
}
......@@ -1397,9 +1399,9 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp
throw OpenMMException("updateParametersInContext: The number of bonds has changed");
if (numBonds == 0)
return;
// Record the per-bond parameters.
vector<vector<float> > paramVector(numBonds);
vector<int> particles;
vector<double> parameters;
......@@ -1410,9 +1412,21 @@ void CommonCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImp
paramVector[i][j] = (float) parameters[j];
}
params->setParameterValues(paramVector);
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid.
cc.invalidateMolecules(info);
}
......@@ -1535,17 +1549,18 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
stringstream extraArgs;
tabulatedFunctions.resize(force.getNumTabulatedFunctions());
tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f);
tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctionArrays[i].upload(f);
extraArgs << ", GLOBAL const float";
if (width > 1)
extraArgs << width;
......@@ -1667,7 +1682,7 @@ void CommonCalcCustomCentroidBondForceKernel::initialize(const System& system, c
groupForcesKernel->addArg(); // Periodic box information will be set just before it is executed.
if (needEnergyParamDerivs)
groupForcesKernel->addArg(); // Deriv buffer hasn't been created yet.
for (auto& function : tabulatedFunctions)
for (auto& function : tabulatedFunctionArrays)
groupForcesKernel->addArg(function);
if (globals.isInitialized())
groupForcesKernel->addArg(globals);
......@@ -1714,9 +1729,9 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp
throw OpenMMException("updateParametersInContext: The number of bonds has changed");
if (numBonds == 0)
return;
// Record the per-bond parameters.
vector<vector<float> > paramVector(numBonds);
vector<int> particles;
vector<double> parameters;
......@@ -1727,9 +1742,21 @@ void CommonCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImp
paramVector[i][j] = (float) parameters[j];
}
params->setParameterValues(paramVector);
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid.
cc.invalidateMolecules(info);
}
......@@ -1868,18 +1895,19 @@ void CommonCalcCustomNonbondedForceKernel::initialize(const System& system, cons
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
vector<string> tableTypes;
tabulatedFunctions.resize(force.getNumTabulatedFunctions());
tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f);
cc.getNonbondedUtilities().addArgument(ComputeParameterInfo(tabulatedFunctions[i], arrayName, "float", width));
tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctionArrays[i].upload(f);
cc.getNonbondedUtilities().addArgument(ComputeParameterInfo(tabulatedFunctionArrays[i], arrayName, "float", width));
if (width == 1)
tableTypes.push_back("float");
else
......@@ -2166,7 +2194,7 @@ void CommonCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNon
stringstream args;
for (int i = 0; i < (int) buffers.size(); i++)
args<<", GLOBAL const "<<buffers[i].getType()<<"* RESTRICT global_params"<<(i+1);
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
for (int i = 0; i < (int) tabulatedFunctionArrays.size(); i++)
args << ", GLOBAL const " << tableTypes[i]<< "* RESTRICT table" << i;
if (globals.isInitialized())
args<<", GLOBAL const float* RESTRICT globals";
......@@ -2289,7 +2317,7 @@ double CommonCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool
interactionGroupKernel->addArg(); // Periodic box information will be set just before it is executed.
for (auto& parameter : params->getParameterInfos())
interactionGroupKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctions)
for (auto& function : tabulatedFunctionArrays)
interactionGroupKernel->addArg(function);
if (globals.isInitialized())
interactionGroupKernel->addArg(globals);
......@@ -2342,18 +2370,30 @@ void CommonCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl&
paramVector[i][j] = (float) parameters[j];
}
params->setParameterValues(paramVector);
// If necessary, recompute the long range correction.
if (forceCopy != NULL) {
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force);
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, cc.getThreadPool());
hasInitializedLongRangeCorrection = false;
*forceCopy = force;
}
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid.
cc.invalidateMolecules(info);
}
......@@ -2679,18 +2719,19 @@ void CommonCalcCustomGBForceKernel::initialize(const System& system, const Custo
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
stringstream tableArgs;
tabulatedFunctions.resize(force.getNumTabulatedFunctions());
tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
string arrayName = prefix+"table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f);
nb.addArgument(ComputeParameterInfo(tabulatedFunctions[i], arrayName, "float", width));
tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctionArrays[i].upload(f);
nb.addArgument(ComputeParameterInfo(tabulatedFunctionArrays[i], arrayName, "float", width));
tableArgs << ", GLOBAL const float";
if (width > 1)
tableArgs << width;
......@@ -3510,7 +3551,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
}
for (auto& d : dValue0dParam)
pairValueKernel->addArg(d);
for (auto& function : tabulatedFunctions)
for (auto& function : tabulatedFunctionArrays)
pairValueKernel->addArg(function);
perParticleValueKernel->addArg(cc.getPosq());
perParticleValueKernel->addArg(valueBuffers);
......@@ -3529,7 +3570,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
for (int j = 0; j < dValuedParam[i]->getParameterInfos().size(); j++)
perParticleValueKernel->addArg(dValuedParam[i]->getParameterInfos()[j].getArray());
}
for (auto& function : tabulatedFunctions)
for (auto& function : tabulatedFunctionArrays)
perParticleValueKernel->addArg(function);
pairEnergyKernel->addArg(useLong ? cc.getLongForceBuffer() : cc.getForceBuffers());
pairEnergyKernel->addArg(cc.getEnergyBuffer());
......@@ -3570,7 +3611,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
pairEnergyKernel->addArg(buffer.getArray());
if (needEnergyParamDerivs)
pairEnergyKernel->addArg(cc.getEnergyParamDerivBuffer());
for (auto& function : tabulatedFunctions)
for (auto& function : tabulatedFunctionArrays)
pairEnergyKernel->addArg(function);
perParticleEnergyKernel->addArg(cc.getEnergyBuffer());
perParticleEnergyKernel->addArg(cc.getPosq());
......@@ -3595,7 +3636,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
perParticleEnergyKernel->addArg(longEnergyDerivs);
if (needEnergyParamDerivs)
perParticleEnergyKernel->addArg(cc.getEnergyParamDerivBuffer());
for (auto& function : tabulatedFunctions)
for (auto& function : tabulatedFunctionArrays)
perParticleEnergyKernel->addArg(function);
if (needParameterGradient || needEnergyParamDerivs) {
gradientChainRuleKernel->addArg(cc.getPosq());
......@@ -3614,7 +3655,7 @@ double CommonCalcCustomGBForceKernel::execute(ContextImpl& context, bool include
for (auto& buffer : d->getParameterInfos())
gradientChainRuleKernel->addArg(buffer.getArray());
}
for (auto& function : tabulatedFunctions)
for (auto& function : tabulatedFunctionArrays)
gradientChainRuleKernel->addArg(function);
}
}
......@@ -3653,9 +3694,9 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context
int numParticles = force.getNumParticles();
if (numParticles != cc.getNumAtoms())
throw OpenMMException("updateParametersInContext: The number of particles has changed");
// Record the per-particle parameters.
vector<vector<float> > paramVector(cc.getPaddedNumAtoms(), vector<float>(force.getNumPerParticleParameters(), 0));
vector<double> parameters;
for (int i = 0; i < numParticles; i++) {
......@@ -3664,9 +3705,21 @@ void CommonCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context
paramVector[i][j] = (float) parameters[j];
}
params->setParameterValues(paramVector);
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid.
cc.invalidateMolecules(info);
}
......@@ -3880,17 +3933,18 @@ void CommonCalcCustomHbondForceKernel::initialize(const System& system, const Cu
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
stringstream tableArgs;
tabulatedFunctions.resize(force.getNumTabulatedFunctions());
tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f);
tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctionArrays[i].upload(f);
tableArgs << ", GLOBAL const float";
if (width > 1)
tableArgs << width;
......@@ -4132,7 +4186,7 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl
donorKernel->addArg(parameter.getArray());
for (auto& parameter : acceptorParams->getParameterInfos())
donorKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctions)
for (auto& function : tabulatedFunctionArrays)
donorKernel->addArg(function);
if (cc.getSupports64BitGlobalAtomics())
acceptorKernel->addArg(cc.getLongForceBuffer());
......@@ -4153,7 +4207,7 @@ double CommonCalcCustomHbondForceKernel::execute(ContextImpl& context, bool incl
acceptorKernel->addArg(parameter.getArray());
for (auto& parameter : acceptorParams->getParameterInfos())
acceptorKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctions)
for (auto& function : tabulatedFunctionArrays)
acceptorKernel->addArg(function);
}
setPeriodicBoxArgs(cc, donorKernel, cc.getSupports64BitGlobalAtomics() ? 6 : 7);
......@@ -4172,9 +4226,9 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont
throw OpenMMException("updateParametersInContext: The number of donors has changed");
if (numAcceptors != force.getNumAcceptors())
throw OpenMMException("updateParametersInContext: The number of acceptors has changed");
// Record the per-donor parameters.
if (numDonors > 0) {
vector<vector<float> > donorParamVector(numDonors);
vector<double> parameters;
......@@ -4187,9 +4241,9 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont
}
donorParams->setParameterValues(donorParamVector);
}
// Record the per-acceptor parameters.
if (numAcceptors > 0) {
vector<vector<float> > acceptorParamVector(numAcceptors);
vector<double> parameters;
......@@ -4202,9 +4256,21 @@ void CommonCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& cont
}
acceptorParams->setParameterValues(acceptorParamVector);
}
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid.
cc.invalidateMolecules(info);
}
......@@ -4280,17 +4346,18 @@ void CommonCalcCustomManyParticleForceKernel::initialize(const System& system, c
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
stringstream tableArgs;
tabulatedFunctions.resize(force.getNumTabulatedFunctions());
tabulatedFunctionArrays.resize(force.getNumTabulatedFunctions());
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
string arrayName = "table"+cc.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = cc.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctions[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctions[i].upload(f);
tabulatedFunctionArrays[i].initialize<float>(cc, f.size(), "TabulatedFunction");
tabulatedFunctionArrays[i].upload(f);
tableArgs << ", GLOBAL const float";
if (width > 1)
tableArgs << width;
......@@ -4593,7 +4660,7 @@ double CommonCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bo
forceKernel->addArg(globals);
for (auto& parameter : params->getParameterInfos())
forceKernel->addArg(parameter.getArray());
for (auto& function : tabulatedFunctions)
for (auto& function : tabulatedFunctionArrays)
forceKernel->addArg(function);
if (nonbondedMethod != NoCutoff) {
......@@ -4695,9 +4762,9 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp
int numParticles = force.getNumParticles();
if (numParticles != cc.getNumAtoms())
throw OpenMMException("updateParametersInContext: The number of particles has changed");
// Record the per-particle parameters.
vector<vector<float> > paramVector(numParticles);
vector<double> parameters;
int type;
......@@ -4708,9 +4775,21 @@ void CommonCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImp
paramVector[i][j] = (float) parameters[j];
}
params->setParameterValues(paramVector);
// See if any tabulated functions have changed.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
int width;
vector<float> f = cc.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
tabulatedFunctionArrays[i].upload(f);
}
}
// Mark that the current reordering may be invalid.
cc.invalidateMolecules(info);
}
......
......@@ -320,7 +320,8 @@ public:
* @param force the CustomNonbondedForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force);
private:
private:
void createInteraction(const CustomNonbondedForce& force);
CpuPlatform::PlatformData& data;
int numParticles;
std::vector<std::vector<double> > particleParamArray;
......@@ -333,6 +334,7 @@ private:
std::vector<std::string> parameterNames, globalParameterNames, energyParamDerivNames;
std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups;
std::vector<double> longRangeCoefficientDerivs;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
CpuCustomNonbondedForce* nonbonded;
};
......@@ -410,6 +412,7 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomGBForce& force);
private:
void createInteraction(const CustomGBForce& force);
CpuPlatform::PlatformData& data;
int numParticles;
bool isPeriodic;
......@@ -421,6 +424,7 @@ private:
std::vector<std::string> particleParameterNames, globalParameterNames, energyParamDerivNames, valueNames;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
};
......@@ -463,6 +467,7 @@ private:
std::vector<std::vector<double> > particleParamArray;
CpuCustomManyParticleForce* ixn;
std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
};
......
......@@ -45,6 +45,7 @@
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/vectorize.h"
#include "openmm/serialization/XmlSerializer.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CustomFunction.h"
#include "lepton/Operation.h"
......@@ -868,7 +869,6 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
// Build the arrays.
int numParameters = force.getNumPerParticleParameters();
particleParamArray.resize(numParticles);
for (int i = 0; i < numParticles; ++i)
force.getParticleParameters(i, particleParamArray[i]);
......@@ -882,10 +882,41 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
switchingDistance = force.getSwitchingDistance();
}
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Record information for the long range correction.
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
forceCopy = new CustomNonbondedForce(force);
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
// Record the interaction groups.
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
set<int> set1, set2;
force.getInteractionGroupParameters(i, set1, set2);
interactionGroups.push_back(make_pair(set1, set2));
}
data.isPeriodic |= (nonbondedMethod == CutoffPeriodic);
// Create the interaction.
createInteraction(force);
}
void CpuCalcCustomNonbondedForceKernel::createInteraction(const CustomNonbondedForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++)
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the various expressions used to calculate the force.
......@@ -893,7 +924,7 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
Lepton::CompiledExpression energyExpression = expression.createCompiledExpression();
Lepton::CompiledExpression forceExpression = expression.differentiate("r").createCompiledExpression();
for (int i = 0; i < numParameters; i++)
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
parameterNames.push_back(force.getPerParticleParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParameterNames.push_back(force.getGlobalParameterName(i));
......@@ -907,7 +938,7 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
}
set<string> variables;
variables.insert("r");
for (int i = 0; i < numParameters; i++) {
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
variables.insert(parameterNames[i]+"1");
variables.insert(parameterNames[i]+"2");
}
......@@ -918,26 +949,9 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
for (auto& function : functions)
delete function.second;
// Record information for the long range correction.
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
forceCopy = new CustomNonbondedForce(force);
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
// Record the interaction groups.
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
set<int> set1, set2;
force.getInteractionGroupParameters(i, set1, set2);
interactionGroups.push_back(make_pair(set1, set2));
}
data.isPeriodic |= (nonbondedMethod == CutoffPeriodic);
// Create the object that computes the interaction.
nonbonded = new CpuCustomNonbondedForce(energyExpression, forceExpression, parameterNames, exclusions, energyParamDerivExpressions, data.threads);
if (interactionGroups.size() > 0)
nonbonded->setInteractionGroups(interactionGroups);
......@@ -1011,6 +1025,22 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
hasInitializedLongRangeCorrection = true;
*forceCopy = force;
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete nonbonded;
nonbonded = NULL;
createInteraction(force);
}
}
CpuCalcGBSAOBCForceKernel::~CpuCalcGBSAOBCForceKernel() {
......@@ -1101,11 +1131,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
// Build the arrays.
int numPerParticleParameters = force.getNumPerParticleParameters();
particleParamArray.resize(numParticles);
for (int i = 0; i < numParticles; ++i)
force.getParticleParameters(i, particleParamArray[i]);
for (int i = 0; i < numPerParticleParameters; i++)
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
particleParameterNames.push_back(force.getPerParticleParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
......@@ -1113,15 +1142,30 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
nonbondedCutoff = force.getCutoffDistance();
if (nonbondedMethod != NoCutoff)
neighborList = new CpuNeighborList(4);
data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic);
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
createInteraction(force);
}
void CpuCalcCustomGBForceKernel::createInteraction(const CustomGBForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++)
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the expressions for computed values.
valueTypes.clear();
valueNames.clear();
energyParamDerivNames.clear();
vector<vector<Lepton::CompiledExpression> > valueDerivExpressions(force.getNumComputedValues());
vector<vector<Lepton::CompiledExpression> > valueGradientExpressions(force.getNumComputedValues());
vector<vector<Lepton::CompiledExpression> > valueParamDerivExpressions(force.getNumComputedValues());
......@@ -1132,7 +1176,7 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
particleVariables.insert("x");
particleVariables.insert("y");
particleVariables.insert("z");
for (int i = 0; i < numPerParticleParameters; i++) {
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
particleVariables.insert(particleParameterNames[i]);
pairVariables.insert(particleParameterNames[i]+"1");
pairVariables.insert(particleParameterNames[i]+"2");
......@@ -1171,6 +1215,7 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
// Parse the expressions for energy terms.
energyTypes.clear();
vector<vector<Lepton::CompiledExpression> > energyDerivExpressions(force.getNumEnergyTerms());
vector<vector<Lepton::CompiledExpression> > energyGradientExpressions(force.getNumEnergyTerms());
vector<vector<Lepton::CompiledExpression> > energyParamDerivExpressions(force.getNumEnergyTerms());
......@@ -1208,7 +1253,6 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
ixn = new CpuCustomGBForce(numParticles, exclusions, valueExpressions, valueDerivExpressions, valueGradientExpressions, valueParamDerivExpressions,
valueNames, valueTypes, energyExpressions, energyDerivExpressions, energyGradientExpressions, energyParamDerivExpressions, energyTypes,
particleParameterNames, data.threads);
data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic);
}
double CpuCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
......@@ -1247,6 +1291,22 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c
for (int j = 0; j < numParameters; j++)
particleParamArray[i][j] = static_cast<double>(parameters[j]);
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
createInteraction(force);
}
}
CpuCalcCustomManyParticleForceKernel::~CpuCalcCustomManyParticleForceKernel() {
......@@ -1266,6 +1326,14 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons
}
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
ixn = new CpuCustomManyParticleForce(force, data.threads);
nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod());
cutoffDistance = force.getCutoffDistance();
......@@ -1303,6 +1371,22 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl&
for (int j = 0; j < numParameters; j++)
particleParamArray[i][j] = static_cast<double>(parameters[j]);
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
ixn = new CpuCustomManyParticleForce(force, data.threads);
}
}
CpuCalcGayBerneForceKernel::~CpuCalcGayBerneForceKernel() {
......
......@@ -682,6 +682,7 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force);
private:
void createExpressions(const CustomNonbondedForce& force);
int numParticles;
std::vector<std::vector<double> > particleParamArray;
double nonbondedCutoff, switchingDistance, periodicBoxSize[3], longRangeCoefficient;
......@@ -695,6 +696,7 @@ private:
std::vector<std::string> parameterNames, globalParameterNames, energyParamDerivNames;
std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups;
std::vector<double> longRangeCoefficientDerivs;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
NeighborList* neighborList;
};
......@@ -768,6 +770,7 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomGBForce& force);
private:
void createExpressions(const CustomGBForce& force);
int numParticles;
bool isPeriodic;
std::vector<std::vector<double> > particleParamArray;
......@@ -784,6 +787,7 @@ private:
std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
NeighborList* neighborList;
};
......@@ -861,13 +865,16 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomHbondForce& force);
private:
void createInteraction(const CustomHbondForce& force);
int numDonors, numAcceptors, numParticles;
bool isPeriodic;
std::vector<std::vector<int> > donorParticles, acceptorParticles;
std::vector<std::vector<double> > donorParamArray, acceptorParamArray;
double nonbondedCutoff;
ReferenceCustomHbondIxn* ixn;
std::vector<std::set<int> > exclusions;
std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
};
/**
......@@ -902,10 +909,15 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomCentroidBondForce& force);
private:
void createInteraction(const CustomCentroidBondForce& force);
int numBonds, numParticles;
std::vector<std::vector<int> > bondGroups;
std::vector<std::vector<int> > groupAtoms;
std::vector<std::vector<double> > normalizedWeights;
std::vector<std::vector<double> > bondParamArray;
ReferenceCustomCentroidBondIxn* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
bool usePeriodic;
Vec3* boxVectors;
};
......@@ -942,10 +954,13 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomCompoundBondForce& force);
private:
void createInteraction(const CustomCompoundBondForce& force);
int numBonds;
std::vector<std::vector<int> > bondParticles;
std::vector<std::vector<double> > bondParamArray;
ReferenceCustomCompoundBondIxn* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
bool usePeriodic;
Vec3* boxVectors;
};
......@@ -987,6 +1002,7 @@ private:
std::vector<std::vector<double> > particleParamArray;
ReferenceCustomManyParticleIxn* ixn;
std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
};
......
......@@ -80,6 +80,7 @@
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/Integrator.h"
#include "openmm/OpenMMException.h"
#include "openmm/serialization/XmlSerializer.h"
#include "SimTKOpenMMUtilities.h"
#include "lepton/CustomFunction.h"
#include "lepton/Operation.h"
......@@ -1151,7 +1152,6 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
// Build the arrays.
int numParameters = force.getNumPerParticleParameters();
particleParamArray.resize(numParticles);
for (int i = 0; i < numParticles; ++i)
force.getParticleParameters(i, particleParamArray[i]);
......@@ -1167,10 +1167,40 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
switchingDistance = force.getSwitchingDistance();
}
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the expressions.
createExpressions(force);
// Record information for the long range correction.
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
forceCopy = new CustomNonbondedForce(force);
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
// Record the interaction groups.
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
set<int> set1, set2;
force.getInteractionGroupParameters(i, set1, set2);
interactionGroups.push_back(make_pair(set1, set2));
}
}
void ReferenceCalcCustomNonbondedForceKernel::createExpressions(const CustomNonbondedForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++)
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the various expressions used to calculate the force.
......@@ -1178,7 +1208,12 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
energyExpression = expression.createCompiledExpression();
forceExpression = expression.differentiate("r").createCompiledExpression();
for (int i = 0; i < numParameters; i++)
parameterNames.clear();
globalParameterNames.clear();
globalParamValues.clear();
energyParamDerivNames.clear();
energyParamDerivExpressions.clear();
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
parameterNames.push_back(force.getPerParticleParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParameterNames.push_back(force.getGlobalParameterName(i));
......@@ -1191,7 +1226,7 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
}
set<string> variables;
variables.insert("r");
for (int i = 0; i < numParameters; i++) {
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
variables.insert(parameterNames[i]+"1");
variables.insert(parameterNames[i]+"2");
}
......@@ -1202,25 +1237,6 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
for (auto& function : functions)
delete function.second;
// Record information for the long range correction.
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
forceCopy = new CustomNonbondedForce(force);
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
// Record the interaction groups.
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
set<int> set1, set2;
force.getInteractionGroupParameters(i, set1, set2);
interactionGroups.push_back(make_pair(set1, set2));
}
}
double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
......@@ -1300,6 +1316,19 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
hasInitializedLongRangeCorrection = true;
*forceCopy = force;
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed)
createExpressions(force);
}
ReferenceCalcGBSAOBCForceKernel::~ReferenceCalcGBSAOBCForceKernel() {
......@@ -1395,11 +1424,10 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
// Build the arrays.
int numPerParticleParameters = force.getNumPerParticleParameters();
particleParamArray.resize(numParticles);
for (int i = 0; i < numParticles; ++i)
force.getParticleParameters(i, particleParamArray[i]);
for (int i = 0; i < numPerParticleParameters; i++)
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
particleParameterNames.push_back(force.getPerParticleParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
......@@ -1410,14 +1438,32 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
else
neighborList = new NeighborList();
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the expressions.
createExpressions(force);
}
void ReferenceCalcCustomGBForceKernel::createExpressions(const CustomGBForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++)
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the expressions for computed values.
valueExpressions.clear();
valueTypes.clear();
valueNames.clear();
energyParamDerivNames.clear();
valueDerivExpressions.clear();
valueGradientExpressions.clear();
valueParamDerivExpressions.clear();
valueDerivExpressions.resize(force.getNumComputedValues());
valueGradientExpressions.resize(force.getNumComputedValues());
valueParamDerivExpressions.resize(force.getNumComputedValues());
......@@ -1426,7 +1472,7 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
particleVariables.insert("x");
particleVariables.insert("y");
particleVariables.insert("z");
for (int i = 0; i < numPerParticleParameters; i++) {
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
particleVariables.insert(particleParameterNames[i]);
pairVariables.insert(particleParameterNames[i]+"1");
pairVariables.insert(particleParameterNames[i]+"2");
......@@ -1465,6 +1511,11 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
// Parse the expressions for energy terms.
energyExpressions.clear();
energyTypes.clear();
energyDerivExpressions.clear();
energyGradientExpressions.clear();
energyParamDerivExpressions.clear();
energyDerivExpressions.resize(force.getNumEnergyTerms());
energyGradientExpressions.resize(force.getNumEnergyTerms());
energyParamDerivExpressions.resize(force.getNumEnergyTerms());
......@@ -1540,6 +1591,19 @@ void ReferenceCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& cont
for (int j = 0; j < numParameters; j++)
particleParamArray[i][j] = parameters[j];
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed)
createExpressions(force);
}
ReferenceCalcCustomExternalForceKernel::~ReferenceCalcCustomExternalForceKernel() {
......@@ -1637,8 +1701,7 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
// Build the arrays.
vector<vector<int> > donorParticles(numDonors);
int numDonorParameters = force.getNumPerDonorParameters();
donorParticles.resize(numDonors);
donorParamArray.resize(numDonors);
for (int i = 0; i < numDonors; ++i) {
int d1, d2, d3;
......@@ -1647,8 +1710,7 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
donorParticles[i].push_back(d2);
donorParticles[i].push_back(d3);
}
vector<vector<int> > acceptorParticles(numAcceptors);
int numAcceptorParameters = force.getNumPerAcceptorParameters();
acceptorParticles.resize(numAcceptors);
acceptorParamArray.resize(numAcceptors);
for (int i = 0; i < numAcceptors; ++i) {
int a1, a2, a3;
......@@ -1657,13 +1719,25 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
acceptorParticles[i].push_back(a2);
acceptorParticles[i].push_back(a3);
}
NonbondedMethod nonbondedMethod = CalcCustomHbondForceKernel::NonbondedMethod(force.getNonbondedMethod());
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
nonbondedCutoff = force.getCutoffDistance();
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
createInteraction(force);
}
void ReferenceCalcCustomHbondForceKernel::createInteraction(const CustomHbondForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++)
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the expression and create the object used to calculate the interaction.
......@@ -1674,13 +1748,12 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
Lepton::ParsedExpression energyExpression = CustomHbondForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
vector<string> donorParameterNames;
vector<string> acceptorParameterNames;
for (int i = 0; i < numDonorParameters; i++)
for (int i = 0; i < force.getNumPerDonorParameters(); i++)
donorParameterNames.push_back(force.getPerDonorParameterName(i));
for (int i = 0; i < numAcceptorParameters; i++)
for (int i = 0; i < force.getNumPerAcceptorParameters(); i++)
acceptorParameterNames.push_back(force.getPerAcceptorParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
ixn = new ReferenceCustomHbondIxn(donorParticles, acceptorParticles, energyExpression, donorParameterNames, acceptorParameterNames, distances, angles, dihedrals);
NonbondedMethod nonbondedMethod = CalcCustomHbondForceKernel::NonbondedMethod(force.getNonbondedMethod());
isPeriodic = (nonbondedMethod == CutoffPeriodic);
if (nonbondedMethod != NoCutoff)
ixn->setUseCutoff(nonbondedCutoff);
......@@ -1733,6 +1806,22 @@ void ReferenceCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& c
for (int j = 0; j < numAcceptorParameters; j++)
acceptorParamArray[i][j] = parameters[j];
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
createInteraction(force);
}
}
ReferenceCalcCustomCentroidBondForceKernel::~ReferenceCalcCustomCentroidBondForceKernel() {
......@@ -1746,23 +1835,32 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system
// Build the arrays.
int numGroups = force.getNumGroups();
vector<vector<int> > groupAtoms(numGroups);
groupAtoms.resize(numGroups);
vector<double> ignored;
for (int i = 0; i < numGroups; i++)
force.getGroupParameters(i, groupAtoms[i], ignored);
vector<vector<double> > normalizedWeights;
CustomCentroidBondForceImpl::computeNormalizedWeights(force, system, normalizedWeights);
numBonds = force.getNumBonds();
vector<vector<int> > bondGroups(numBonds);
int numBondParameters = force.getNumPerBondParameters();
bondGroups.resize(numBonds);
bondParamArray.resize(numBonds);
for (int i = 0; i < numBonds; ++i)
force.getBondParameters(i, bondGroups[i], bondParamArray[i]);
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
createInteraction(force);
}
void ReferenceCalcCustomCentroidBondForceKernel::createInteraction(const CustomCentroidBondForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++)
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create implementations of point functions.
......@@ -1773,9 +1871,10 @@ void ReferenceCalcCustomCentroidBondForceKernel::initialize(const System& system
// Parse the expression and create the object used to calculate the interaction.
int numGroups = force.getNumGroups();
Lepton::ParsedExpression energyExpression = CustomCentroidBondForceImpl::prepareExpression(force, functions);
vector<string> bondParameterNames;
for (int i = 0; i < numBondParameters; i++)
for (int i = 0; i < force.getNumPerBondParameters(); i++)
bondParameterNames.push_back(force.getPerBondParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
......@@ -1830,6 +1929,22 @@ void ReferenceCalcCustomCentroidBondForceKernel::copyParametersToContext(Context
for (int j = 0; j < numParameters; j++)
bondParamArray[i][j] = params[j];
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
createInteraction(force);
}
}
ReferenceCalcCustomCompoundBondForceKernel::~ReferenceCalcCustomCompoundBondForceKernel() {
......@@ -1843,16 +1958,26 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system
// Build the arrays.
numBonds = force.getNumBonds();
vector<vector<int> > bondParticles(numBonds);
int numBondParameters = force.getNumPerBondParameters();
bondParticles.resize(numBonds);
bondParamArray.resize(numBonds);
for (int i = 0; i < numBonds; ++i)
force.getBondParameters(i, bondParticles[i], bondParamArray[i]);
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
createInteraction(force);
}
void ReferenceCalcCustomCompoundBondForceKernel::createInteraction(const CustomCompoundBondForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++)
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Create implementations of point functions.
......@@ -1865,7 +1990,7 @@ void ReferenceCalcCustomCompoundBondForceKernel::initialize(const System& system
Lepton::ParsedExpression energyExpression = CustomCompoundBondForceImpl::prepareExpression(force, functions);
vector<string> bondParameterNames;
for (int i = 0; i < numBondParameters; i++)
for (int i = 0; i < force.getNumPerBondParameters(); i++)
bondParameterNames.push_back(force.getPerBondParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
......@@ -1920,6 +2045,22 @@ void ReferenceCalcCustomCompoundBondForceKernel::copyParametersToContext(Context
for (int j = 0; j < numParameters; j++)
bondParamArray[i][j] = params[j];
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
createInteraction(force);
}
}
ReferenceCalcCustomManyParticleForceKernel::~ReferenceCalcCustomManyParticleForceKernel() {
......@@ -1928,7 +2069,6 @@ ReferenceCalcCustomManyParticleForceKernel::~ReferenceCalcCustomManyParticleForc
}
void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system, const CustomManyParticleForce& force) {
// Build the arrays.
numParticles = system.getNumParticles();
......@@ -1939,6 +2079,14 @@ void ReferenceCalcCustomManyParticleForceKernel::initialize(const System& system
}
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
ixn = new ReferenceCustomManyParticleIxn(force);
nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod());
cutoffDistance = force.getCutoffDistance();
......@@ -1977,6 +2125,22 @@ void ReferenceCalcCustomManyParticleForceKernel::copyParametersToContext(Context
for (int j = 0; j < numParameters; j++)
particleParamArray[i][j] = parameters[j];
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
ixn = new ReferenceCustomManyParticleIxn(force);
}
}
ReferenceCalcGayBerneForceKernel::~ReferenceCalcGayBerneForceKernel() {
......
......@@ -205,6 +205,18 @@ void testComplexFunction(bool byGroups) {
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], TOL);
}
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(compound->getTabulatedFunction(0)).setFunctionParameters(table, -1, 10);
dynamic_cast<Continuous1DFunction&>(centroid->getTabulatedFunction(0)).setFunctionParameters(table, -1, 10);
compound->updateParametersInContext(context);
centroid->updateParametersInContext(context);
State state1 = context.getState(State::Energy, false, 1<<0);
State state2 = context.getState(State::Energy, false, 1<<1);
ASSERT_EQUAL_TOL(state1.getPotentialEnergy(), state2.getPotentialEnergy(), TOL);
}
void testCustomWeights() {
......
......@@ -212,6 +212,31 @@ void testContinuous2DFunction() {
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
Continuous2DFunction& fn = dynamic_cast<Continuous2DFunction&>(forceField->getTabulatedFunction(0));
fn.setFunctionParameters(xsize, ysize, table, xmin, xmax, ymin, ymax);
forceField->updateParametersInContext(context);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[0] = Vec3(x, y, 1.5);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = 0.5*sin(0.25*x)*cos(0.33*y)+1;
force[0] = 0.5*(-0.25*cos(0.25*x)*cos(0.33*y));
force[1] = 0.5*0.3*sin(0.25*x)*sin(0.33*y);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
void testContinuous3DFunction() {
......
......@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -254,7 +254,7 @@ void testMembrane() {
double norm = 0.0;
for (int i = 0; i < (int) forces.size(); ++i)
norm += forces[i].dot(forces[i]);
norm = std::sqrt(norm);
norm = sqrt(norm);
const double stepSize = 1e-2;
double step = 0.5*stepSize/norm;
vector<Vec3> positions2(numParticles), positions3(numParticles);
......@@ -283,7 +283,7 @@ void testTabulatedFunction() {
force->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
table.push_back(sin(0.25*i));
force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(force);
Context context(system, integrator, platform);
......@@ -296,8 +296,8 @@ void testTabulatedFunction() {
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
double force = (x < 1.0 || x > 6.0 ? 0.0 : -cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
......@@ -308,7 +308,22 @@ void testTabulatedFunction() {
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(force->getTabulatedFunction(0)).setFunctionParameters(table, 1.0, 6.0);
force->updateParametersInContext(context);
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : 0.5*sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
}
......@@ -385,7 +400,7 @@ void testPositionDependence() {
double norm = 0.0;
for (int i = 0; i < (int) forces.size(); ++i)
norm += forces[i].dot(forces[i]);
norm = std::sqrt(norm);
norm = sqrt(norm);
const double stepSize = 1e-3;
double step = 0.5*stepSize/norm;
vector<Vec3> positions2(2), positions3(2);
......@@ -455,7 +470,7 @@ void testExclusions() {
double norm = 0.0;
for (int i = 0; i < (int) forces.size(); ++i)
norm += forces[i].dot(forces[i]);
norm = std::sqrt(norm);
norm = sqrt(norm);
if (norm > 0) {
const double stepSize = 1e-3;
double step = stepSize/norm;
......
......@@ -223,6 +223,15 @@ void testCustomFunctions() {
ASSERT_EQUAL_VEC(Vec3(0, -0.1, 0), forces[1], TOL);
ASSERT_EQUAL_VEC(Vec3(-0.1, 0, 0), forces[2], TOL);
ASSERT_EQUAL_TOL(0.1*2+0.1*2, state.getPotentialEnergy(), TOL);
// Try updating the tabulated function.
for (int i = 0; i < function.size(); i++)
function[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(custom->getTabulatedFunction(0)).setFunctionParameters(function, 0, 10);
custom->updateParametersInContext(context);
state = context.getState(State::Energy);
ASSERT_EQUAL_TOL(0.5*(0.1*2+0.1*2), state.getPotentialEnergy(), TOL);
}
void test2DFunction() {
......
......@@ -516,6 +516,15 @@ void testTabulatedFunctions() {
expectedEnergy += 0.5*(r12+r13+r23)*(c[i]+c[j]+c[k]);
}
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
// Try updating the tabulated function.
for (int i = 0; i < values.size(); i++)
values[i] *= 0.5;
dynamic_cast<Discrete3DFunction&>(force->getTabulatedFunction(1)).setFunctionParameters(numParticles, numParticles, numParticles, values);
force->updateParametersInContext(context);
state = context.getState(State::Energy);
ASSERT_EQUAL_TOL(0.5*expectedEnergy, state.getPotentialEnergy(), 1e-5);
}
void testTypeFilters() {
......
......@@ -355,6 +355,21 @@ void testContinuous1DFunction() {
double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(forceField->getTabulatedFunction(0)).setFunctionParameters(table, 1.0, 6.0);
forceField->updateParametersInContext(context);
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : 0.5*sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
}
void testPeriodicContinuous1DFunction() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment