Commit 8f0e6d3c authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented tabulated functions in reference platform

parent 947455cd
...@@ -97,6 +97,12 @@ public: ...@@ -97,6 +97,12 @@ public:
int getNumGlobalParameters() const { int getNumGlobalParameters() const {
return globalParameters.size(); return globalParameters.size();
} }
/**
* Get the number of tabulated functions that have been defined.
*/
int getNumFunctions() const {
return functions.size();
}
/** /**
* Get the algebraic expression that gives the interaction energy between two particles * Get the algebraic expression that gives the interaction energy between two particles
*/ */
...@@ -273,6 +279,45 @@ public: ...@@ -273,6 +279,45 @@ public:
* will cause the interaction to be completely omitted from force and energy calculations. * will cause the interaction to be completely omitted from force and energy calculations.
*/ */
void setExceptionParameters(int index, int particle1, int particle2, const std::vector<double>& parameters); void setExceptionParameters(int index, int particle1, int particle2, const std::vector<double>& parameters);
/**
* Add a tabulated function that may appear in algebraic expressions.
*
* @param name the name of the function as it appears in expressions
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
* @return the index of the function that was added
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating);
/**
* Get the parameters for a tabulated function that may appear in algebraic expressions.
*
* @param index the index of the function for which to get parameters
* @param name the name of the function as it appears in expressions
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const;
/**
* Set the parameters for a tabulated function that may appear in algebraic expressions.
*
* @param index the index of the function for which to set parameters
* @param name the name of the function as it appears in expressions
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min and max.
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating);
protected: protected:
ForceImpl* createImpl(); ForceImpl* createImpl();
private: private:
...@@ -280,6 +325,7 @@ private: ...@@ -280,6 +325,7 @@ private:
class ParameterInfo; class ParameterInfo;
class GlobalParameterInfo; class GlobalParameterInfo;
class ExceptionInfo; class ExceptionInfo;
class FunctionInfo;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
double cutoffDistance; double cutoffDistance;
Vec3 periodicBoxVectors[3]; Vec3 periodicBoxVectors[3];
...@@ -288,6 +334,7 @@ private: ...@@ -288,6 +334,7 @@ private:
std::vector<GlobalParameterInfo> globalParameters; std::vector<GlobalParameterInfo> globalParameters;
std::vector<ParticleInfo> particles; std::vector<ParticleInfo> particles;
std::vector<ExceptionInfo> exceptions; std::vector<ExceptionInfo> exceptions;
std::vector<FunctionInfo> functions;
std::map<std::pair<int, int>, int> exceptionMap; std::map<std::pair<int, int>, int> exceptionMap;
}; };
...@@ -331,6 +378,19 @@ public: ...@@ -331,6 +378,19 @@ public:
} }
}; };
class CustomNonbondedForce::FunctionInfo {
public:
std::string name;
std::vector<double> values;
double min, max;
bool interpolating;
FunctionInfo() {
}
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) :
name(name), values(values), min(min), max(max), interpolating(interpolating) {
}
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_CUSTOMNONBONDEDFORCE_H_*/ #endif /*OPENMM_CUSTOMNONBONDEDFORCE_H_*/
...@@ -186,6 +186,35 @@ void CustomNonbondedForce::setExceptionParameters(int index, int particle1, int ...@@ -186,6 +186,35 @@ void CustomNonbondedForce::setExceptionParameters(int index, int particle1, int
exceptions[index].parameters = parameters; exceptions[index].parameters = parameters;
} }
int CustomNonbondedForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) {
if (max <= min)
throw OpenMMException("CustomNonbondedForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomNonbondedForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max, interpolating));
return functions.size()-1;
}
void CustomNonbondedForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const {
name = functions[index].name;
values = functions[index].values;
min = functions[index].min;
max = functions[index].max;
interpolating = functions[index].interpolating;
}
void CustomNonbondedForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) {
if (max <= min)
throw OpenMMException("CustomNonbondedForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomNonbondedForce: a tabulated function must have at least two points");
functions[index].name = name;
functions[index].values = values;
functions[index].min = min;
functions[index].max = max;
functions[index].interpolating = interpolating;
}
ForceImpl* CustomNonbondedForce::createImpl() { ForceImpl* CustomNonbondedForce::createImpl() {
return new CustomNonbondedForceImpl(*this); return new CustomNonbondedForceImpl(*this);
} }
...@@ -53,6 +53,7 @@ ...@@ -53,6 +53,7 @@
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/Integrator.h" #include "openmm/Integrator.h"
#include "SimTKUtilities/SimTKOpenMMUtilities.h" #include "SimTKUtilities/SimTKOpenMMUtilities.h"
#include "lepton/CustomFunction.h"
#include "lepton/Parser.h" #include "lepton/Parser.h"
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
#include <cmath> #include <cmath>
...@@ -473,6 +474,65 @@ double ReferenceCalcNonbondedForceKernel::executeEnergy(ContextImpl& context) { ...@@ -473,6 +474,65 @@ double ReferenceCalcNonbondedForceKernel::executeEnergy(ContextImpl& context) {
return energy; return energy;
} }
class ReferenceCalcCustomNonbondedForceKernel::TabulatedFunction : public Lepton::CustomFunction {
public:
TabulatedFunction(double min, double max, const vector<double>& values, bool interpolating) :
min(min), max(max), values(values), interpolating(interpolating) {
}
int getNumArguments() const {
return 1;
}
/**
* Given the function argument, find the local spline coefficients.
*/
void findCoefficients(double& x, double* coeff) const {
int length = values.size();
double spacing = (max-min)/(length-1);
int index = std::floor((x-min)/spacing);
double points[4];
points[0] = (index == 0 ? 2*values[0]-values[1] : values[index-1]);
points[1] = values[index];
points[2] = (index > length-2 ? values[length-1] : values[index+1]);
points[3] = (index > length-3 ? 2*values[length-1]-values[length-2] : values[index+2]);
if (interpolating) {
coeff[0] = points[1];
coeff[1] = 0.5*(-points[0]+points[2]);
coeff[2] = 0.5*(2.0*points[0]-5.0*points[1]+4.0*points[2]-points[3]);
coeff[3] = 0.5*(-points[0]+3.0*points[1]-3.0*points[2]+points[3]);
}
else {
coeff[0] = (points[0]+4.0*points[1]+points[2])/6.0;
coeff[1] = (-3.0*points[0]+3.0*points[2])/6.0;
coeff[2] = (3.0*points[0]-6.0*points[1]+3.0*points[2])/6.0;
coeff[3] = (-points[0]+3.0*points[1]-3.0*points[2]+points[3])/6.0;
}
x = (x-min)/spacing-index;
}
double evaluate(const double* arguments) const {
double x = arguments[0];
if (x < min || x > max)
return 0.0;
double coeff[4];
findCoefficients(x, coeff);
return coeff[0]+x*(coeff[1]+x*(coeff[2]+x*coeff[3]));
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
double x = arguments[0];
if (x < min || x > max)
return 0.0;
double coeff[4];
findCoefficients(x, coeff);
double scale = (values.size()-1)/(max-min);
return scale*(coeff[1]+x*(2.0*coeff[2]+x*3.0*coeff[3]));
}
CustomFunction* clone() const {
return new TabulatedFunction(min, max, values, interpolating);
}
double min, max;
vector<double> values;
bool interpolating;
};
ReferenceCalcCustomNonbondedForceKernel::~ReferenceCalcCustomNonbondedForceKernel() { ReferenceCalcCustomNonbondedForceKernel::~ReferenceCalcCustomNonbondedForceKernel() {
disposeRealArray(particleParamArray, numParticles); disposeRealArray(particleParamArray, numParticles);
disposeIntArray(exclusionArray, numParticles); disposeIntArray(exclusionArray, numParticles);
...@@ -540,17 +600,34 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c ...@@ -540,17 +600,34 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
else else
neighborList = new NeighborList(); neighborList = new NeighborList();
// Create custom functions for the tabulated functions.
map<string, Lepton::CustomFunction*> functions;
for (int i = 0; i < force.getNumFunctions(); i++) {
string name;
vector<double> values;
double min, max;
bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating);
functions[name] = new TabulatedFunction(min, max, values, interpolating);
}
// Parse the various expressions used to calculate the force. // Parse the various expressions used to calculate the force.
Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction()).optimize(); Lepton::ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
energyExpression = expression.createProgram(); energyExpression = expression.createProgram();
forceExpression = expression.differentiate("r").optimize().createProgram(); forceExpression = expression.differentiate("r").optimize().createProgram();
for (int i = 0; i < numParameters; i++) { for (int i = 0; i < numParameters; i++) {
parameterNames.push_back(force.getParameterName(i)); parameterNames.push_back(force.getParameterName(i));
combiningRules.push_back(Lepton::Parser::parse(force.getParameterCombiningRule(i)).optimize().createProgram()); combiningRules.push_back(Lepton::Parser::parse(force.getParameterCombiningRule(i), functions).optimize().createProgram());
} }
for (int i = 0; i < force.getNumGlobalParameters(); i++) for (int i = 0; i < force.getNumGlobalParameters(); i++)
globalParameterNames.push_back(force.getGlobalParameterName(i)); globalParameterNames.push_back(force.getGlobalParameterName(i));
// Delete the custom functions.
for (map<string, Lepton::CustomFunction*>::iterator iter = functions.begin(); iter != functions.end(); iter++)
delete iter->second;
} }
void ReferenceCalcCustomNonbondedForceKernel::executeForces(ContextImpl& context) { void ReferenceCalcCustomNonbondedForceKernel::executeForces(ContextImpl& context) {
......
...@@ -314,6 +314,7 @@ private: ...@@ -314,6 +314,7 @@ private:
std::vector<Lepton::ExpressionProgram> combiningRules; std::vector<Lepton::ExpressionProgram> combiningRules;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
NeighborList* neighborList; NeighborList* neighborList;
class TabulatedFunction;
}; };
/** /**
......
...@@ -202,6 +202,38 @@ void testPeriodic() { ...@@ -202,6 +202,38 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
} }
void testTabulatedFunction(bool interpolating) {
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0, interpolating);
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
double tol = 0.01;
for (int i = 1; i < 30; i++) {
double x = (7.0/30.0)*i;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
}
int main() { int main() {
try { try {
testSimpleExpression(); testSimpleExpression();
...@@ -209,6 +241,8 @@ int main() { ...@@ -209,6 +241,8 @@ int main() {
testExceptions(); testExceptions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
testTabulatedFunction(true);
testTabulatedFunction(false);
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment