Unverified Commit 27bcb657 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

updateParametersInContext() can change tabulated functions (#3307)

* updateParametersInContext() can change tabulated functions

* Fixed error in building C wrappers

* updateParametersInContext() can change tabulated functions for CustomCentroidBondForce

* CustomNonbondedForce can update tabulated functions

* CustomGBForce can update tabulated functions

* CustomManyParticleForce can update tabulated functions

* CustomHbondForce can update tabulated functions
parent d83c2724
......@@ -358,15 +358,16 @@ public:
*/
const std::string& getTabulatedFunctionName(int index) const;
/**
* Update the per-bond parameters in a Context to match those stored in this Force object. This method provides
* Update the per-bond parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setBondParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-bond parameters.
* All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* This method has several limitations. The only information it updates is the values of per-bond parameters and tabulated
* functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. Neither the definitions of groups nor the set of groups involved in a bond can be changed, nor can new
* bonds be added.
* bonds be added. Also, while the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -338,14 +338,15 @@ public:
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Update the per-bond parameters in a Context to match those stored in this Force object. This method provides
* Update the per-bond parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setBondParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-bond parameters.
* All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. The set of particles involved in a bond cannot be changed, nor can new bonds be added.
* This method has several limitations. The only information it updates is the values of per-bond parameters and tabulated
* functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. The set of particles involved in a bond cannot be changed, nor can new bonds be added. Also, while the
* tabulated values of a function can change, everything else about it (its dimensions, the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -544,14 +544,15 @@ public:
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Update the per-particle parameters in a Context to match those stored in this Force object. This method provides
* Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-particle parameters.
* All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. Also, this method cannot be used to add new particles, only to change the parameters of existing ones.
* This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated
* functions. All other aspects of the Force (such as the energy function) are unaffected and can only be changed by reinitializing
* the Context. Also, this method cannot be used to add new particles, only to change the parameters of existing ones. While
* the tabulated values of a function can change, everything else about it (its dimensions, the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -443,15 +443,16 @@ public:
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Update the per-donor and per-acceptor parameters in a Context to match those stored in this Force object. This method
* Update the per-donor and per-acceptor parameters and tabulated functions in a Context to match those stored in this Force object. This method
* provides an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setDonorParameters() and setAcceptorParameters() to modify this object's parameters, then call
* updateParametersInContext() to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-donor and per-acceptor parameters.
* All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can only
* This method has several limitations. The only information it updates is the values of per-donor and per-acceptor parameters and tabulated
* functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can only
* be changed by reinitializing the Context. The set of particles involved in a donor or acceptor cannot be changed, nor can
* new donors or acceptors be added.
* new donors or acceptors be added. While the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -480,15 +480,16 @@ public:
*/
const std::string& getTabulatedFunctionName(int index) const;
/**
* Update the per-particle parameters in a Context to match those stored in this Force object. This method provides
* Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-particle parameters.
* All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can
* This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated
* functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can
* only be changed by reinitializing the Context. Also, this method cannot be used to add new particles, only to change
* the parameters of existing ones.
* the parameters of existing ones. While the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -490,15 +490,16 @@ public:
*/
void setInteractionGroupParameters(int index, const std::set<int>& set1, const std::set<int>& set2);
/**
* Update the per-particle parameters in a Context to match those stored in this Force object. This method provides
* Update the per-particle parameters and tabulated functions in a Context to match those stored in this Force object. This method provides
* an efficient method to update certain parameters in an existing Context without needing to reinitialize it.
* Simply call setParticleParameters() to modify this object's parameters, then call updateParametersInContext()
* to copy them over to the Context.
*
* This method has several limitations. The only information it updates is the values of per-particle parameters.
* All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can
* This method has several limitations. The only information it updates is the values of per-particle parameters and tabulated
* functions. All other aspects of the Force (the energy function, nonbonded method, cutoff distance, etc.) are unaffected and can
* only be changed by reinitializing the Context. Also, this method cannot be used to add new particles, only to change
* the parameters of existing ones.
* the parameters of existing ones. While the tabulated values of a function can change, everything else about it (its dimensions,
* the data range) must not be changed.
*/
void updateParametersInContext(Context& context);
/**
......
......@@ -65,9 +65,12 @@ public:
virtual TabulatedFunction* Copy() const = 0;
/**
* Get the periodicity status of the tabulated function.
*
*/
bool getPeriodic() const;
virtual bool operator==(const TabulatedFunction& other) const = 0;
virtual bool operator!=(const TabulatedFunction& other) const {
return !(*this == other);
}
protected:
bool periodic;
};
......@@ -114,6 +117,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Continuous1DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
std::vector<double> values;
double min, max;
......@@ -176,6 +180,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Continuous2DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
std::vector<double> values;
int xsize, ysize;
......@@ -254,6 +259,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Continuous3DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
std::vector<double> values;
int xsize, ysize, zsize;
......@@ -291,6 +297,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Discrete1DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
std::vector<double> values;
};
......@@ -335,6 +342,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Discrete2DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
int xsize, ysize;
std::vector<double> values;
......@@ -383,6 +391,7 @@ public:
* @deprecated This will be removed in a future release.
*/
Discrete3DFunction* Copy() const;
bool operator==(const TabulatedFunction& other) const;
private:
int xsize, ysize, zsize;
std::vector<double> values;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2014 Stanford University and the Authors. *
* Portions copyright (c) 2014-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -75,6 +75,15 @@ Continuous1DFunction* Continuous1DFunction::Copy() const {
return new Continuous1DFunction(new_vec, min, max);
}
bool Continuous1DFunction::operator==(const TabulatedFunction& other) const {
const Continuous1DFunction* fn = dynamic_cast<const Continuous1DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->min != min || fn->max != max)
return false;
return (fn->values == values);
}
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, bool periodic) {
this->periodic = periodic;
setFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
......@@ -120,6 +129,19 @@ Continuous2DFunction* Continuous2DFunction::Copy() const {
return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax);
}
bool Continuous2DFunction::operator==(const TabulatedFunction& other) const {
const Continuous2DFunction* fn = dynamic_cast<const Continuous2DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize)
return false;
if (fn->xmin != xmin || fn->xmax != xmax)
return false;
if (fn->ymin != ymin || fn->ymax != ymax)
return false;
return (fn->values == values);
}
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic) {
this->periodic = periodic;
setFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
......@@ -173,6 +195,20 @@ Continuous3DFunction* Continuous3DFunction::Copy() const {
return new Continuous3DFunction(xsize, ysize, zsize, new_vec, xmin, xmax, ymin, ymax, zmin, zmax);
}
bool Continuous3DFunction::operator==(const TabulatedFunction& other) const {
const Continuous3DFunction* fn = dynamic_cast<const Continuous3DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize || fn->zsize != zsize)
return false;
if (fn->xmin != xmin || fn->xmax != xmax)
return false;
if (fn->ymin != ymin || fn->ymax != ymax)
return false;
if (fn->zmin != zmin || fn->zmax != zmax)
return false;
return (fn->values == values);
}
Discrete1DFunction::Discrete1DFunction(const vector<double>& values) {
this->values = values;
......@@ -193,6 +229,13 @@ Discrete1DFunction* Discrete1DFunction::Copy() const {
return new Discrete1DFunction(new_vec);
}
bool Discrete1DFunction::operator==(const TabulatedFunction& other) const {
const Discrete1DFunction* fn = dynamic_cast<const Discrete1DFunction*>(&other);
if (fn == NULL)
return false;
return (fn->values == values);
}
Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const vector<double>& values) {
if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values");
......@@ -222,6 +265,15 @@ Discrete2DFunction* Discrete2DFunction::Copy() const {
return new Discrete2DFunction(xsize, ysize, new_vec);
}
bool Discrete2DFunction::operator==(const TabulatedFunction& other) const {
const Discrete2DFunction* fn = dynamic_cast<const Discrete2DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize)
return false;
return (fn->values == values);
}
Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const vector<double>& values) {
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values");
......@@ -253,3 +305,12 @@ Discrete3DFunction* Discrete3DFunction::Copy() const {
new_vec[i] = values[i];
return new Discrete3DFunction(xsize, ysize, zsize, new_vec);
}
bool Discrete3DFunction::operator==(const TabulatedFunction& other) const {
const Discrete3DFunction* fn = dynamic_cast<const Discrete3DFunction*>(&other);
if (fn == NULL)
return false;
if (fn->xsize != xsize || fn->ysize != ysize || fn->zsize != zsize)
return false;
return (fn->values == values);
}
......@@ -529,7 +529,8 @@ private:
ComputeArray globals;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
const System& system;
};
......@@ -577,7 +578,8 @@ private:
ComputeArray groupForces, bondGroups, centerPositions;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::vector<void*> groupForcesArgs;
ComputeKernel computeCentersKernel, groupForcesKernel, applyForcesKernel;
const System& system;
......@@ -628,7 +630,8 @@ private:
std::vector<void*> interactionGroupArgs;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
double longRangeCoefficient;
std::vector<double> longRangeCoefficientDerivs;
bool hasInitializedLongRangeCorrection, hasInitializedKernel, hasParamDerivs, useNeighborList;
......@@ -728,7 +731,8 @@ private:
ComputeArray longEnergyDerivs, globals, valueBuffers;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
std::vector<bool> pairValueUsesParam, pairEnergyUsesParam, pairEnergyUsesValue;
const System& system;
ComputeKernel pairValueKernel, perParticleValueKernel, pairEnergyKernel, perParticleEnergyKernel, gradientChainRuleKernel;
......@@ -785,7 +789,8 @@ private:
ComputeArray acceptorExclusions;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
const System& system;
ComputeKernel donorKernel, acceptorKernel;
};
......@@ -836,7 +841,8 @@ private:
ComputeArray neighborPairs, numNeighborPairs, neighborStartIndex, numNeighborsForAtom, neighbors;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<ComputeArray> tabulatedFunctions;
std::vector<ComputeArray> tabulatedFunctionArrays;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
const System& system;
ComputeKernel forceKernel, blockBoundsKernel, neighborsKernel, startIndicesKernel, copyPairsKernel;
ComputeEvent event;
......
This diff is collapsed.
......@@ -320,7 +320,8 @@ public:
* @param force the CustomNonbondedForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force);
private:
private:
void createInteraction(const CustomNonbondedForce& force);
CpuPlatform::PlatformData& data;
int numParticles;
std::vector<std::vector<double> > particleParamArray;
......@@ -333,6 +334,7 @@ private:
std::vector<std::string> parameterNames, globalParameterNames, energyParamDerivNames;
std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups;
std::vector<double> longRangeCoefficientDerivs;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
CpuCustomNonbondedForce* nonbonded;
};
......@@ -410,6 +412,7 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomGBForce& force);
private:
void createInteraction(const CustomGBForce& force);
CpuPlatform::PlatformData& data;
int numParticles;
bool isPeriodic;
......@@ -421,6 +424,7 @@ private:
std::vector<std::string> particleParameterNames, globalParameterNames, energyParamDerivNames, valueNames;
std::vector<OpenMM::CustomGBForce::ComputationType> valueTypes;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
};
......@@ -463,6 +467,7 @@ private:
std::vector<std::vector<double> > particleParamArray;
CpuCustomManyParticleForce* ixn;
std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
};
......
......@@ -45,6 +45,7 @@
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/vectorize.h"
#include "openmm/serialization/XmlSerializer.h"
#include "lepton/CompiledExpression.h"
#include "lepton/CustomFunction.h"
#include "lepton/Operation.h"
......@@ -868,7 +869,6 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
// Build the arrays.
int numParameters = force.getNumPerParticleParameters();
particleParamArray.resize(numParticles);
for (int i = 0; i < numParticles; ++i)
force.getParticleParameters(i, particleParamArray[i]);
......@@ -882,10 +882,41 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
switchingDistance = force.getSwitchingDistance();
}
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Record information for the long range correction.
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
forceCopy = new CustomNonbondedForce(force);
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
// Record the interaction groups.
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
set<int> set1, set2;
force.getInteractionGroupParameters(i, set1, set2);
interactionGroups.push_back(make_pair(set1, set2));
}
data.isPeriodic |= (nonbondedMethod == CutoffPeriodic);
// Create the interaction.
createInteraction(force);
}
void CpuCalcCustomNonbondedForceKernel::createInteraction(const CustomNonbondedForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++)
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the various expressions used to calculate the force.
......@@ -893,7 +924,7 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
Lepton::CompiledExpression energyExpression = expression.createCompiledExpression();
Lepton::CompiledExpression forceExpression = expression.differentiate("r").createCompiledExpression();
for (int i = 0; i < numParameters; i++)
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
parameterNames.push_back(force.getPerParticleParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
globalParameterNames.push_back(force.getGlobalParameterName(i));
......@@ -907,7 +938,7 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
}
set<string> variables;
variables.insert("r");
for (int i = 0; i < numParameters; i++) {
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
variables.insert(parameterNames[i]+"1");
variables.insert(parameterNames[i]+"2");
}
......@@ -918,26 +949,9 @@ void CpuCalcCustomNonbondedForceKernel::initialize(const System& system, const C
for (auto& function : functions)
delete function.second;
// Record information for the long range correction.
if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection()) {
forceCopy = new CustomNonbondedForce(force);
hasInitializedLongRangeCorrection = false;
}
else {
longRangeCoefficient = 0.0;
hasInitializedLongRangeCorrection = true;
}
// Record the interaction groups.
for (int i = 0; i < force.getNumInteractionGroups(); i++) {
set<int> set1, set2;
force.getInteractionGroupParameters(i, set1, set2);
interactionGroups.push_back(make_pair(set1, set2));
}
data.isPeriodic |= (nonbondedMethod == CutoffPeriodic);
// Create the object that computes the interaction.
nonbonded = new CpuCustomNonbondedForce(energyExpression, forceExpression, parameterNames, exclusions, energyParamDerivExpressions, data.threads);
if (interactionGroups.size() > 0)
nonbonded->setInteractionGroups(interactionGroups);
......@@ -1011,6 +1025,22 @@ void CpuCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& con
hasInitializedLongRangeCorrection = true;
*forceCopy = force;
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete nonbonded;
nonbonded = NULL;
createInteraction(force);
}
}
CpuCalcGBSAOBCForceKernel::~CpuCalcGBSAOBCForceKernel() {
......@@ -1101,11 +1131,10 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
// Build the arrays.
int numPerParticleParameters = force.getNumPerParticleParameters();
particleParamArray.resize(numParticles);
for (int i = 0; i < numParticles; ++i)
force.getParticleParameters(i, particleParamArray[i]);
for (int i = 0; i < numPerParticleParameters; i++)
for (int i = 0; i < force.getNumPerParticleParameters(); i++)
particleParameterNames.push_back(force.getPerParticleParameterName(i));
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
......@@ -1113,15 +1142,30 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
nonbondedCutoff = force.getCutoffDistance();
if (nonbondedMethod != NoCutoff)
neighborList = new CpuNeighborList(4);
data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic);
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
createInteraction(force);
}
void CpuCalcCustomGBForceKernel::createInteraction(const CustomGBForce& force) {
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++)
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));
// Parse the expressions for computed values.
valueTypes.clear();
valueNames.clear();
energyParamDerivNames.clear();
vector<vector<Lepton::CompiledExpression> > valueDerivExpressions(force.getNumComputedValues());
vector<vector<Lepton::CompiledExpression> > valueGradientExpressions(force.getNumComputedValues());
vector<vector<Lepton::CompiledExpression> > valueParamDerivExpressions(force.getNumComputedValues());
......@@ -1132,7 +1176,7 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
particleVariables.insert("x");
particleVariables.insert("y");
particleVariables.insert("z");
for (int i = 0; i < numPerParticleParameters; i++) {
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
particleVariables.insert(particleParameterNames[i]);
pairVariables.insert(particleParameterNames[i]+"1");
pairVariables.insert(particleParameterNames[i]+"2");
......@@ -1171,6 +1215,7 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
// Parse the expressions for energy terms.
energyTypes.clear();
vector<vector<Lepton::CompiledExpression> > energyDerivExpressions(force.getNumEnergyTerms());
vector<vector<Lepton::CompiledExpression> > energyGradientExpressions(force.getNumEnergyTerms());
vector<vector<Lepton::CompiledExpression> > energyParamDerivExpressions(force.getNumEnergyTerms());
......@@ -1208,7 +1253,6 @@ void CpuCalcCustomGBForceKernel::initialize(const System& system, const CustomGB
ixn = new CpuCustomGBForce(numParticles, exclusions, valueExpressions, valueDerivExpressions, valueGradientExpressions, valueParamDerivExpressions,
valueNames, valueTypes, energyExpressions, energyDerivExpressions, energyGradientExpressions, energyParamDerivExpressions, energyTypes,
particleParameterNames, data.threads);
data.isPeriodic |= (force.getNonbondedMethod() == CustomGBForce::CutoffPeriodic);
}
double CpuCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
......@@ -1247,6 +1291,22 @@ void CpuCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, c
for (int j = 0; j < numParameters; j++)
particleParamArray[i][j] = static_cast<double>(parameters[j]);
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
createInteraction(force);
}
}
CpuCalcCustomManyParticleForceKernel::~CpuCalcCustomManyParticleForceKernel() {
......@@ -1266,6 +1326,14 @@ void CpuCalcCustomManyParticleForceKernel::initialize(const System& system, cons
}
for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i));
// Record the tabulated functions for future reference.
for (int i = 0; i < force.getNumTabulatedFunctions(); i++)
tabulatedFunctions[force.getTabulatedFunctionName(i)] = XmlSerializer::clone(force.getTabulatedFunction(i));
// Create the interaction.
ixn = new CpuCustomManyParticleForce(force, data.threads);
nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod());
cutoffDistance = force.getCutoffDistance();
......@@ -1303,6 +1371,22 @@ void CpuCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl&
for (int j = 0; j < numParameters; j++)
particleParamArray[i][j] = static_cast<double>(parameters[j]);
}
// See if any tabulated functions have changed.
bool changed = false;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
string name = force.getTabulatedFunctionName(i);
if (force.getTabulatedFunction(i) != *tabulatedFunctions[name]) {
tabulatedFunctions[name] = XmlSerializer::clone(force.getTabulatedFunction(i));
changed = true;
}
}
if (changed) {
delete ixn;
ixn = NULL;
ixn = new CpuCustomManyParticleForce(force, data.threads);
}
}
CpuCalcGayBerneForceKernel::~CpuCalcGayBerneForceKernel() {
......
......@@ -682,6 +682,7 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force);
private:
void createExpressions(const CustomNonbondedForce& force);
int numParticles;
std::vector<std::vector<double> > particleParamArray;
double nonbondedCutoff, switchingDistance, periodicBoxSize[3], longRangeCoefficient;
......@@ -695,6 +696,7 @@ private:
std::vector<std::string> parameterNames, globalParameterNames, energyParamDerivNames;
std::vector<std::pair<std::set<int>, std::set<int> > > interactionGroups;
std::vector<double> longRangeCoefficientDerivs;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
NeighborList* neighborList;
};
......@@ -768,6 +770,7 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomGBForce& force);
private:
void createExpressions(const CustomGBForce& force);
int numParticles;
bool isPeriodic;
std::vector<std::vector<double> > particleParamArray;
......@@ -784,6 +787,7 @@ private:
std::vector<std::vector<Lepton::CompiledExpression> > energyGradientExpressions;
std::vector<std::vector<Lepton::CompiledExpression> > energyParamDerivExpressions;
std::vector<OpenMM::CustomGBForce::ComputationType> energyTypes;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
NeighborList* neighborList;
};
......@@ -861,13 +865,16 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomHbondForce& force);
private:
void createInteraction(const CustomHbondForce& force);
int numDonors, numAcceptors, numParticles;
bool isPeriodic;
std::vector<std::vector<int> > donorParticles, acceptorParticles;
std::vector<std::vector<double> > donorParamArray, acceptorParamArray;
double nonbondedCutoff;
ReferenceCustomHbondIxn* ixn;
std::vector<std::set<int> > exclusions;
std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
};
/**
......@@ -902,10 +909,15 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomCentroidBondForce& force);
private:
void createInteraction(const CustomCentroidBondForce& force);
int numBonds, numParticles;
std::vector<std::vector<int> > bondGroups;
std::vector<std::vector<int> > groupAtoms;
std::vector<std::vector<double> > normalizedWeights;
std::vector<std::vector<double> > bondParamArray;
ReferenceCustomCentroidBondIxn* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
bool usePeriodic;
Vec3* boxVectors;
};
......@@ -942,10 +954,13 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomCompoundBondForce& force);
private:
void createInteraction(const CustomCompoundBondForce& force);
int numBonds;
std::vector<std::vector<int> > bondParticles;
std::vector<std::vector<double> > bondParamArray;
ReferenceCustomCompoundBondIxn* ixn;
std::vector<std::string> globalParameterNames, energyParamDerivNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
bool usePeriodic;
Vec3* boxVectors;
};
......@@ -987,6 +1002,7 @@ private:
std::vector<std::vector<double> > particleParamArray;
ReferenceCustomManyParticleIxn* ixn;
std::vector<std::string> globalParameterNames;
std::map<std::string, const TabulatedFunction*> tabulatedFunctions;
NonbondedMethod nonbondedMethod;
};
......
......@@ -205,6 +205,18 @@ void testComplexFunction(bool byGroups) {
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], TOL);
}
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(compound->getTabulatedFunction(0)).setFunctionParameters(table, -1, 10);
dynamic_cast<Continuous1DFunction&>(centroid->getTabulatedFunction(0)).setFunctionParameters(table, -1, 10);
compound->updateParametersInContext(context);
centroid->updateParametersInContext(context);
State state1 = context.getState(State::Energy, false, 1<<0);
State state2 = context.getState(State::Energy, false, 1<<1);
ASSERT_EQUAL_TOL(state1.getPotentialEnergy(), state2.getPotentialEnergy(), TOL);
}
void testCustomWeights() {
......
......@@ -212,6 +212,31 @@ void testContinuous2DFunction() {
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
Continuous2DFunction& fn = dynamic_cast<Continuous2DFunction&>(forceField->getTabulatedFunction(0));
fn.setFunctionParameters(xsize, ysize, table, xmin, xmax, ymin, ymax);
forceField->updateParametersInContext(context);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
positions[0] = Vec3(x, y, 1.5);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
Vec3 force(0, 0, 0);
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax) {
energy = 0.5*sin(0.25*x)*cos(0.33*y)+1;
force[0] = 0.5*(-0.25*cos(0.25*x)*cos(0.33*y));
force[1] = 0.5*0.3*sin(0.25*x)*sin(0.33*y);
}
ASSERT_EQUAL_VEC(force, forces[0], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
void testContinuous3DFunction() {
......
......@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -254,7 +254,7 @@ void testMembrane() {
double norm = 0.0;
for (int i = 0; i < (int) forces.size(); ++i)
norm += forces[i].dot(forces[i]);
norm = std::sqrt(norm);
norm = sqrt(norm);
const double stepSize = 1e-2;
double step = 0.5*stepSize/norm;
vector<Vec3> positions2(numParticles), positions3(numParticles);
......@@ -283,7 +283,7 @@ void testTabulatedFunction() {
force->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
table.push_back(sin(0.25*i));
force->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(force);
Context context(system, integrator, platform);
......@@ -296,8 +296,8 @@ void testTabulatedFunction() {
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
double force = (x < 1.0 || x > 6.0 ? 0.0 : -cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
......@@ -308,7 +308,22 @@ void testTabulatedFunction() {
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(force->getTabulatedFunction(0)).setFunctionParameters(table, 1.0, 6.0);
force->updateParametersInContext(context);
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : 0.5*sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
}
......@@ -385,7 +400,7 @@ void testPositionDependence() {
double norm = 0.0;
for (int i = 0; i < (int) forces.size(); ++i)
norm += forces[i].dot(forces[i]);
norm = std::sqrt(norm);
norm = sqrt(norm);
const double stepSize = 1e-3;
double step = 0.5*stepSize/norm;
vector<Vec3> positions2(2), positions3(2);
......@@ -455,7 +470,7 @@ void testExclusions() {
double norm = 0.0;
for (int i = 0; i < (int) forces.size(); ++i)
norm += forces[i].dot(forces[i]);
norm = std::sqrt(norm);
norm = sqrt(norm);
if (norm > 0) {
const double stepSize = 1e-3;
double step = stepSize/norm;
......
......@@ -223,6 +223,15 @@ void testCustomFunctions() {
ASSERT_EQUAL_VEC(Vec3(0, -0.1, 0), forces[1], TOL);
ASSERT_EQUAL_VEC(Vec3(-0.1, 0, 0), forces[2], TOL);
ASSERT_EQUAL_TOL(0.1*2+0.1*2, state.getPotentialEnergy(), TOL);
// Try updating the tabulated function.
for (int i = 0; i < function.size(); i++)
function[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(custom->getTabulatedFunction(0)).setFunctionParameters(function, 0, 10);
custom->updateParametersInContext(context);
state = context.getState(State::Energy);
ASSERT_EQUAL_TOL(0.5*(0.1*2+0.1*2), state.getPotentialEnergy(), TOL);
}
void test2DFunction() {
......
......@@ -516,6 +516,15 @@ void testTabulatedFunctions() {
expectedEnergy += 0.5*(r12+r13+r23)*(c[i]+c[j]+c[k]);
}
ASSERT_EQUAL_TOL(expectedEnergy, state.getPotentialEnergy(), 1e-5);
// Try updating the tabulated function.
for (int i = 0; i < values.size(); i++)
values[i] *= 0.5;
dynamic_cast<Discrete3DFunction&>(force->getTabulatedFunction(1)).setFunctionParameters(numParticles, numParticles, numParticles, values);
force->updateParametersInContext(context);
state = context.getState(State::Energy);
ASSERT_EQUAL_TOL(0.5*expectedEnergy, state.getPotentialEnergy(), 1e-5);
}
void testTypeFilters() {
......
......@@ -355,6 +355,21 @@ void testContinuous1DFunction() {
double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
// Try updating the tabulated function.
for (int i = 0; i < table.size(); i++)
table[i] *= 0.5;
dynamic_cast<Continuous1DFunction&>(forceField->getTabulatedFunction(0)).setFunctionParameters(table, 1.0, 6.0);
forceField->updateParametersInContext(context);
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : 0.5*sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
}
void testPeriodicContinuous1DFunction() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment