Commit 4e1e1b11 authored by Peter Eastman's avatar Peter Eastman
Browse files

Custom functions are now represented by natural splines

parent 06a98e93
......@@ -134,7 +134,7 @@ namespace OpenMM {
* an expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
*
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of
* values, and an interpolating or approximating spline is created from them. That function can then appear in expressions.
* values, and a natural spline is created from them. That function can then appear in expressions.
*/
class OPENMM_EXPORT CustomGBForce : public Force {
......@@ -458,11 +458,9 @@ public:
* The function is assumed to be zero for x < min or x > max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
* @return the index of the function that was added
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating);
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Get the parameters for a tabulated function that may appear in the energy expression.
*
......@@ -472,10 +470,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const;
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/**
* Set the parameters for a tabulated function that may appear in algebraic expressions.
*
......@@ -485,10 +481,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating);
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
protected:
ForceImpl* createImpl();
private:
......@@ -573,11 +567,10 @@ public:
std::string name;
std::vector<double> values;
double min, max;
bool interpolating;
FunctionInfo() {
}
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) :
name(name), values(values), min(min), max(max), interpolating(interpolating) {
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max) :
name(name), values(values), min(min), max(max) {
}
};
......
......@@ -92,7 +92,7 @@ namespace OpenMM {
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise.
*
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of
* values, and an interpolating or approximating spline is created from them. That function can then appear in the expression.
* values, and a natural spline is created from them. That function can then appear in the expression.
*/
class OPENMM_EXPORT CustomHbondForce : public Force {
......@@ -378,11 +378,9 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
* @return the index of the function that was added
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating);
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Get the parameters for a tabulated function that may appear in the energy expression.
*
......@@ -392,10 +390,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const;
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/**
* Set the parameters for a tabulated function that may appear in algebraic expressions.
*
......@@ -405,10 +401,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating);
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
protected:
ForceImpl* createImpl();
private:
......@@ -495,11 +489,10 @@ public:
std::string name;
std::vector<double> values;
double min, max;
bool interpolating;
FunctionInfo() {
}
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) :
name(name), values(values), min(min), max(max), interpolating(interpolating) {
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max) :
name(name), values(values), min(min), max(max) {
}
};
......
......@@ -80,7 +80,7 @@ namespace OpenMM {
* the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
*
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of
* values, and an interpolating or approximating spline is created from them. That function can then appear in the expression.
* values, and a natural spline is created from them. That function can then appear in the expression.
*/
class OPENMM_EXPORT CustomNonbondedForce : public Force {
......@@ -282,11 +282,9 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
* @return the index of the function that was added
*/
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating);
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/**
* Get the parameters for a tabulated function that may appear in the energy expression.
*
......@@ -296,10 +294,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const;
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/**
* Set the parameters for a tabulated function that may appear in algebraic expressions.
*
......@@ -309,10 +305,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating);
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
protected:
ForceImpl* createImpl();
private:
......@@ -395,11 +389,10 @@ public:
std::string name;
std::vector<double> values;
double min, max;
bool interpolating;
FunctionInfo() {
}
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) :
name(name), values(values), min(min), max(max), interpolating(interpolating) {
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max) :
name(name), values(values), min(min), max(max) {
}
};
......
......@@ -158,24 +158,23 @@ void CustomGBForce::setExclusionParticles(int index, int particle1, int particle
exclusions[index].particle2 = particle2;
}
int CustomGBForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) {
int CustomGBForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("CustomGBForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomGBForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max, interpolating));
functions.push_back(FunctionInfo(name, values, min, max));
return functions.size()-1;
}
void CustomGBForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const {
void CustomGBForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const {
name = functions[index].name;
values = functions[index].values;
min = functions[index].min;
max = functions[index].max;
interpolating = functions[index].interpolating;
}
void CustomGBForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) {
void CustomGBForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("CustomGBForce: max <= min for a tabulated function.");
if (values.size() < 2)
......@@ -184,7 +183,6 @@ void CustomGBForce::setFunctionParameters(int index, const std::string& name, co
functions[index].values = values;
functions[index].min = min;
functions[index].max = max;
functions[index].interpolating = interpolating;
}
ForceImpl* CustomGBForce::createImpl() {
......
......@@ -172,24 +172,23 @@ void CustomHbondForce::setExclusionParticles(int index, int donor, int acceptor)
exclusions[index].acceptor = acceptor;
}
int CustomHbondForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) {
int CustomHbondForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("CustomHbondForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomHbondForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max, interpolating));
functions.push_back(FunctionInfo(name, values, min, max));
return functions.size()-1;
}
void CustomHbondForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const {
void CustomHbondForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const {
name = functions[index].name;
values = functions[index].values;
min = functions[index].min;
max = functions[index].max;
interpolating = functions[index].interpolating;
}
void CustomHbondForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) {
void CustomHbondForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("CustomHbondForce: max <= min for a tabulated function.");
if (values.size() < 2)
......@@ -198,7 +197,6 @@ void CustomHbondForce::setFunctionParameters(int index, const std::string& name,
functions[index].values = values;
functions[index].min = min;
functions[index].max = max;
functions[index].interpolating = interpolating;
}
ForceImpl* CustomHbondForce::createImpl() {
......
......@@ -134,24 +134,23 @@ void CustomNonbondedForce::setExclusionParticles(int index, int particle1, int p
exclusions[index].particle2 = particle2;
}
int CustomNonbondedForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) {
int CustomNonbondedForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("CustomNonbondedForce: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("CustomNonbondedForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max, interpolating));
functions.push_back(FunctionInfo(name, values, min, max));
return functions.size()-1;
}
void CustomNonbondedForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const {
void CustomNonbondedForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const {
name = functions[index].name;
values = functions[index].values;
min = functions[index].min;
max = functions[index].max;
interpolating = functions[index].interpolating;
}
void CustomNonbondedForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) {
void CustomNonbondedForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("CustomNonbondedForce: max <= min for a tabulated function.");
if (values.size() < 2)
......@@ -160,7 +159,6 @@ void CustomNonbondedForce::setFunctionParameters(int index, const std::string& n
functions[index].values = values;
functions[index].min = min;
functions[index].max = max;
functions[index].interpolating = interpolating;
}
ForceImpl* CustomNonbondedForce::createImpl() {
......
......@@ -41,9 +41,15 @@ void SplineFitter::createNaturalSpline(const vector<double>& x, const vector<dou
int n = x.size();
if (y.size() != n)
throw OpenMMException("createNaturalSpline: x and y vectors must have same length");
if (n < 3)
throw OpenMMException("createNaturalSpline: the length of the input array must be at least 3");
if (n < 2)
throw OpenMMException("createNaturalSpline: the length of the input array must be at least 2");
deriv.resize(n);
if (n == 2) {
// This is just a straight line.
deriv[0] = 0;
deriv[1] = 0;
}
// Create the system of equations to solve.
......
......@@ -940,9 +940,8 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
string name;
vector<double> values;
double min, max;
bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating);
gpuSetTabulatedFunction(gpu, i, name, values, min, max, interpolating);
force.getFunctionParameters(i, name, values, min, max);
gpuSetTabulatedFunction(gpu, i, name, values, min, max);
}
// Record information for the expressions.
......
......@@ -50,6 +50,7 @@ using namespace std;
#include "cudaKernels.h"
#include "hilbert.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include "quern.h"
#include "Lepton.h"
#include "rng.h"
......@@ -614,7 +615,7 @@ void gpuSetNonbondedCutoff(gpuContext gpu, float cutoffDistance, float solventDi
}
extern "C"
void gpuSetTabulatedFunction(gpuContext gpu, int index, const string& name, const vector<double>& values, double min, double max, bool interpolating)
void gpuSetTabulatedFunction(gpuContext gpu, int index, const string& name, const vector<double>& values, double min, double max)
{
if (index < 0 || index >= MAX_TABULATED_FUNCTIONS) {
stringstream str;
......@@ -631,32 +632,15 @@ void gpuSetTabulatedFunction(gpuContext gpu, int index, const string& name, cons
gpu->tabulatedFunctions[index].max = max;
gpu->tabulatedFunctionsChanged = true;
// First create a padded set of function values.
// Compute the spline coefficients.
vector<double> padded(values.size()+2);
padded[0] = 2*values[0]-values[1];
for (int i = 0; i < (int) values.size(); i++)
padded[i+1] = values[i];
padded[padded.size()-1] = 2*values[values.size()-1]-values[values.size()-2];
// Now compute the spline coefficients.
for (int i = 0; i < (int) values.size()-1; i++) {
float4 c;
if (interpolating) {
c.x = (float) padded[i+1];
c.y = (float) (0.5*(-padded[i]+padded[i+2]));
c.z = (float) (0.5*(2.0*padded[i]-5.0*padded[i+1]+4.0*padded[i+2]-padded[i+3]));
c.w = (float) (0.5*(-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3]));
}
else {
c.x = (float) ((padded[i]+4.0*padded[i+1]+padded[i+2])/6.0);
c.y = (float) ((-3.0*padded[i]+3.0*padded[i+2])/6.0);
c.z = (float) ((3.0*padded[i]-6.0*padded[i+1]+3.0*padded[i+2])/6.0);
c.w = (float) ((-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3])/6.0);
}
(*coeff)[i] = c;
}
int numValues = values.size();
vector<double> x(numValues), derivs;
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
OpenMM::SplineFitter::createNaturalSpline(x, values, derivs);
for (int i = 0; i < (int) values.size()-1; i++)
(*coeff)[i] = make_float4((float) values[i], (float) values[i+1], (float) (derivs[i]/6.0), (float) (derivs[i+1]/6.0));
coeff->Upload();
}
......@@ -914,7 +898,7 @@ void gpuSetCustomNonbondedParameters(gpuContext gpu, const vector<vector<double>
for (int i = 0; i < MAX_TABULATED_FUNCTIONS; i++) {
gpuTabulatedFunction& func = gpu->tabulatedFunctions[i];
if (func.coefficients != NULL) {
(*gpu->psTabulatedFunctionParams)[i] = make_float4((float) func.min, (float) func.max, (float) (func.coefficients->_length/(func.max-func.min)), 0.0f);
(*gpu->psTabulatedFunctionParams)[i] = make_float4((float) func.min, (float) func.max, (float) (func.coefficients->_length/(func.max-func.min)), (float) (func.coefficients->_length-1));
functions[func.name] = fp;
}
}
......
......@@ -219,7 +219,7 @@ extern "C"
void gpuSetNonbondedCutoff(gpuContext gpu, float cutoffDistance, float solventDielectric);
extern "C"
void gpuSetTabulatedFunction(gpuContext gpu, int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating);
void gpuSetTabulatedFunction(gpuContext gpu, int index, const std::string& name, const std::vector<double>& values, double min, double max);
extern "C"
void gpuSetCustomBondParameters(gpuContext gpu, const std::vector<int>& bondAtom1, const std::vector<int>& bondAtom2, const std::vector<std::vector<double> >& bondParams,
......
......@@ -97,7 +97,9 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float*
STACK(stackPointer) = 0.0f;
else
{
int index = floor((x-params.x)*params.z);
x = (x-params.x)*params.z;
int index = floor(x);
index = min(index, (int) params.w);
float4 coeff;
if (function == 0)
coeff = tex1Dfetch(texRef0, index);
......@@ -107,11 +109,12 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float*
coeff = tex1Dfetch(texRef2, index);
else
coeff = tex1Dfetch(texRef3, index);
x = (x-params.x)*params.z-index;
float b = x-index;
float a = 1.0f-b;
if (op == CUSTOM)
STACK(stackPointer) = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));
STACK(stackPointer) = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);
else
STACK(stackPointer) = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;
STACK(stackPointer) = (coeff.y-coeff.x)*params.z-((3.0f*a*a-1.0f)*coeff.z+(4.0f*b*b-1.0f)*coeff.w)/params.z;
}
}
}
......
......@@ -203,7 +203,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
}
void testTabulatedFunction(bool interpolating) {
void testTabulatedFunction() {
CudaPlatform platform;
System system;
system.addParticle(1.0);
......@@ -215,7 +215,7 @@ void testTabulatedFunction(bool interpolating) {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0, interpolating);
forceField->addFunction("fn", table, 1.0, 6.0);
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -233,6 +233,14 @@ void testTabulatedFunction(bool interpolating) {
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
}
void testCoulombLennardJones() {
......@@ -327,8 +335,7 @@ int main() {
testExclusions();
testCutoff();
testPeriodic();
testTabulatedFunction(true);
testTabulatedFunction(false);
testTabulatedFunction();
testCoulombLennardJones();
}
catch(const exception& e) {
......
......@@ -26,6 +26,7 @@
#include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include "lepton/Operation.h"
using namespace OpenMM;
......@@ -120,14 +121,16 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n";
out << "int index = (int) (floor((x-params.x)*params.z));\n";
out << "x = (x-params.x)*params.z;\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) params.w);\n";
out << "float4 coeff = " << functions[i].second << "[index];\n";
out << "x = (x-params.x)*params.z-index;\n";
out << "float b = x-index;\n";
out << "float a = 1.0f-b;\n";
if (valueNode != NULL)
out << valueName << " = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));\n";
out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
if (derivNode != NULL)
out << derivName << " = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;\n";
out << derivName << " = (coeff.y-coeff.x)*params.z-((3.0f*a*a-1.0f)*coeff.z+(4.0f*b*b-1.0f)*coeff.w)/params.z;\n";
out << "}\n";
out << "}";
break;
......@@ -338,29 +341,16 @@ void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node
findRelatedPowers(node, searchNode.getChildren()[i], powers);
}
vector<mm_float4> OpenCLExpressionUtilities::computeFunctionCoefficients(const vector<double>& values, bool interpolating) {
// First create a padded set of function values.
vector<mm_float4> OpenCLExpressionUtilities::computeFunctionCoefficients(const vector<double>& values, double min, double max) {
// Compute the spline coefficients.
vector<double> padded(values.size()+2);
padded[0] = 2*values[0]-values[1];
for (int i = 0; i < (int) values.size(); i++)
padded[i+1] = values[i];
padded[padded.size()-1] = 2*values[values.size()-1]-values[values.size()-2];
// Now compute the spline coefficients.
vector<mm_float4> f(values.size()-1);
for (int i = 0; i < (int) values.size()-1; i++) {
if (interpolating)
f[i] = mm_float4((cl_float) padded[i+1],
(cl_float) (0.5*(-padded[i]+padded[i+2])),
(cl_float) (0.5*(2.0*padded[i]-5.0*padded[i+1]+4.0*padded[i+2]-padded[i+3])),
(cl_float) (0.5*(-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3])));
else
f[i] = mm_float4((cl_float) ((padded[i]+4.0*padded[i+1]+padded[i+2])/6.0),
(cl_float) ((-3.0*padded[i]+3.0*padded[i+2])/6.0),
(cl_float) ((3.0*padded[i]-6.0*padded[i+1]+3.0*padded[i+2])/6.0),
(cl_float) ((-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3])/6.0));
}
int numValues = values.size();
vector<double> x(numValues), derivs;
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs);
vector<mm_float4> f(numValues-1);
for (int i = 0; i < (int) values.size()-1; i++)
f[i] = mm_float4((cl_float) values[i], (cl_float) values[i+1], (cl_float) (derivs[i]/6.0), (cl_float) (derivs[i+1]/6.0));
return f;
}
......@@ -60,10 +60,11 @@ public:
* Calculate the spline coefficients for a tabulated function that appears in expressions.
*
* @param values the tabulated values of the function
* @param interpolating true if an interpolating spline should be used, false if an approximating spline should be used
* @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @return the spline coefficients
*/
static std::vector<mm_float4> computeFunctionCoefficients(const std::vector<double>& values, bool interpolating);
static std::vector<mm_float4> computeFunctionCoefficients(const std::vector<double>& values, double min, double max);
/**
* Convert a number to a string in a format suitable for including in a kernel.
*/
......
......@@ -1533,13 +1533,12 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
string name;
vector<double> values;
double min, max;
bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating);
force.getFunctionParameters(i, name, values, min, max);
string arrayName = prefix+"table"+intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp;
tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), values.size()-2);
vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, interpolating);
vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, min, max);
tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
......@@ -1874,13 +1873,12 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
string name;
vector<double> values;
double min, max;
bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating);
force.getFunctionParameters(i, name, values, min, max);
string arrayName = prefix+"table"+intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp;
tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), values.size()-2);
vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, interpolating);
vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, min, max);
tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
......@@ -2917,13 +2915,12 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
string name;
vector<double> values;
double min, max;
bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating);
force.getFunctionParameters(i, name, values, min, max);
string arrayName = "table"+intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp;
tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), values.size()-2);
vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, interpolating);
vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, min, max);
tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", __global float4* " << arrayName;
......
......@@ -230,7 +230,7 @@ void testMembrane() {
ASSERT_EQUAL_TOL(norm, (state2.getPotentialEnergy()-state.getPotentialEnergy())/stepSize, 1e-2);
}
void testTabulatedFunction(bool interpolating) {
void testTabulatedFunction() {
OpenCLPlatform platform;
System system;
system.addParticle(1.0);
......@@ -244,7 +244,7 @@ void testTabulatedFunction(bool interpolating) {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0, interpolating);
force->addFunction("fn", table, 1.0, 6.0);
system.addForce(force);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -261,6 +261,14 @@ void testTabulatedFunction(bool interpolating) {
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
}
void testMultipleChainRules() {
......@@ -424,8 +432,7 @@ int main() {
testOBC(GBSAOBCForce::CutoffNonPeriodic, CustomGBForce::CutoffNonPeriodic);
testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic);
testMembrane();
testTabulatedFunction(true);
testTabulatedFunction(false);
testTabulatedFunction();
testMultipleChainRules();
testPositionDependence();
testExclusions();
......
......@@ -195,7 +195,7 @@ void testCustomFunctions() {
vector<double> function(2);
function[0] = 0;
function[1] = 1;
custom->addFunction("foo", function, 0, 10, true);
custom->addFunction("foo", function, 0, 10);
system.addForce(custom);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
......
......@@ -242,7 +242,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
}
void testTabulatedFunction(bool interpolating) {
void testTabulatedFunction() {
OpenCLPlatform platform;
System system;
system.addParticle(1.0);
......@@ -254,7 +254,7 @@ void testTabulatedFunction(bool interpolating) {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0, interpolating);
forceField->addFunction("fn", table, 1.0, 6.0);
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -272,6 +272,14 @@ void testTabulatedFunction(bool interpolating) {
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
}
void testCoulombLennardJones() {
......@@ -357,8 +365,7 @@ int main() {
testExclusions();
testCutoff();
testPeriodic();
testTabulatedFunction(true);
testTabulatedFunction(false);
testTabulatedFunction();
testCoulombLennardJones();
}
catch(const exception& e) {
......
......@@ -63,6 +63,7 @@
#include "openmm/internal/CustomHbondForceImpl.h"
#include "openmm/internal/CMAPTorsionForceImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/SplineFitter.h"
#include "openmm/Integrator.h"
#include "openmm/OpenMMException.h"
#include "SimTKUtilities/SimTKOpenMMUtilities.h"
......@@ -713,61 +714,34 @@ double ReferenceCalcNonbondedForceKernel::execute(ContextImpl& context, bool inc
class ReferenceTabulatedFunction : public Lepton::CustomFunction {
public:
ReferenceTabulatedFunction(double min, double max, const vector<double>& values, bool interpolating) :
min(min), max(max), values(values), interpolating(interpolating) {
ReferenceTabulatedFunction(double min, double max, const vector<double>& values) :
min(min), max(max), values(values) {
int numValues = values.size();
x.resize(numValues);
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs);
}
int getNumArguments() const {
return 1;
}
/**
* Given the function argument, find the local spline coefficients.
*/
void findCoefficients(double& x, double* coeff) const {
int length = values.size();
double scale = (length-1)/(max-min);
int index = (int) std::floor((x-min)*scale);
double points[4];
points[0] = (index == 0 ? 2*values[0]-values[1] : values[index-1]);
points[1] = values[index];
points[2] = (index > length-2 ? values[length-1] : values[index+1]);
points[3] = (index > length-3 ? 2*values[length-1]-values[length-2] : values[index+2]);
if (interpolating) {
coeff[0] = points[1];
coeff[1] = 0.5*(-points[0]+points[2]);
coeff[2] = 0.5*(2.0*points[0]-5.0*points[1]+4.0*points[2]-points[3]);
coeff[3] = 0.5*(-points[0]+3.0*points[1]-3.0*points[2]+points[3]);
}
else {
coeff[0] = (points[0]+4.0*points[1]+points[2])/6.0;
coeff[1] = (-3.0*points[0]+3.0*points[2])/6.0;
coeff[2] = (3.0*points[0]-6.0*points[1]+3.0*points[2])/6.0;
coeff[3] = (-points[0]+3.0*points[1]-3.0*points[2]+points[3])/6.0;
}
x = (x-min)*scale-index;
}
double evaluate(const double* arguments) const {
double x = arguments[0];
if (x < min || x > max)
double t = arguments[0];
if (t < min || t > max)
return 0.0;
double coeff[4];
findCoefficients(x, coeff);
return coeff[0]+x*(coeff[1]+x*(coeff[2]+x*coeff[3]));
return SplineFitter::evaluateSpline(x, values, derivs, t);
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
double x = arguments[0];
if (x < min || x > max)
double t = arguments[0];
if (t < min || t > max)
return 0.0;
double coeff[4];
findCoefficients(x, coeff);
double scale = (values.size()-1)/(max-min);
return scale*(coeff[1]+x*(2.0*coeff[2]+x*3.0*coeff[3])); // We assume a first derivative, because that's the only order ever used by CustomNonbondedForce.
return SplineFitter::evaluateSplineDerivative(x, values, derivs, t);
}
CustomFunction* clone() const {
return new ReferenceTabulatedFunction(min, max, values, interpolating);
return new ReferenceTabulatedFunction(min, max, values);
}
double min, max;
vector<double> values;
bool interpolating;
vector<double> x, values, derivs;
};
ReferenceCalcCustomNonbondedForceKernel::~ReferenceCalcCustomNonbondedForceKernel() {
......@@ -822,9 +796,8 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
string name;
vector<double> values;
double min, max;
bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating);
functions[name] = new ReferenceTabulatedFunction(min, max, values, interpolating);
force.getFunctionParameters(i, name, values, min, max);
functions[name] = new ReferenceTabulatedFunction(min, max, values);
}
// Parse the various expressions used to calculate the force.
......@@ -1010,9 +983,8 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
string name;
vector<double> values;
double min, max;
bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating);
functions[name] = new ReferenceTabulatedFunction(min, max, values, interpolating);
force.getFunctionParameters(i, name, values, min, max);
functions[name] = new ReferenceTabulatedFunction(min, max, values);
}
// Parse the expressions for computed values.
......@@ -1204,9 +1176,8 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
string name;
vector<double> values;
double min, max;
bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating);
functions[name] = new ReferenceTabulatedFunction(min, max, values, interpolating);
force.getFunctionParameters(i, name, values, min, max);
functions[name] = new ReferenceTabulatedFunction(min, max, values);
}
// Parse the expression and create the object used to calculate the interaction.
......
......@@ -232,7 +232,7 @@ void testMembrane() {
ASSERT_EQUAL_TOL(norm, (state2.getPotentialEnergy()-state.getPotentialEnergy())/stepSize, 1e-2);
}
void testTabulatedFunction(bool interpolating) {
void testTabulatedFunction() {
ReferencePlatform platform;
System system;
system.addParticle(1.0);
......@@ -246,7 +246,7 @@ void testTabulatedFunction(bool interpolating) {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0, interpolating);
force->addFunction("fn", table, 1.0, 6.0);
system.addForce(force);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -865,8 +865,7 @@ int main() {
testOBC(GBSAOBCForce::CutoffNonPeriodic, CustomGBForce::CutoffNonPeriodic);
testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic);
testMembrane();
testTabulatedFunction(true);
testTabulatedFunction(false);
testTabulatedFunction();
testMultipleChainRules();
testPositionDependence();
testExclusions();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment