Commit 4e1e1b11 authored by Peter Eastman's avatar Peter Eastman
Browse files

Custom functions are now represented by natural splines

parent 06a98e93
...@@ -134,7 +134,7 @@ namespace OpenMM { ...@@ -134,7 +134,7 @@ namespace OpenMM {
* an expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * an expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of * In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of
* values, and an interpolating or approximating spline is created from them. That function can then appear in expressions. * values, and a natural spline is created from them. That function can then appear in expressions.
*/ */
class OPENMM_EXPORT CustomGBForce : public Force { class OPENMM_EXPORT CustomGBForce : public Force {
...@@ -458,11 +458,9 @@ public: ...@@ -458,11 +458,9 @@ public:
* The function is assumed to be zero for x < min or x > max. * The function is assumed to be zero for x < min or x > max.
* @param min the value of the independent variable corresponding to the first element of values * @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values * @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in the energy expression. * Get the parameters for a tabulated function that may appear in the energy expression.
* *
...@@ -472,10 +470,8 @@ public: ...@@ -472,10 +470,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max. * The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values * @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values * @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in algebraic expressions. * Set the parameters for a tabulated function that may appear in algebraic expressions.
* *
...@@ -485,10 +481,8 @@ public: ...@@ -485,10 +481,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max. * The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values * @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values * @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
protected: protected:
ForceImpl* createImpl(); ForceImpl* createImpl();
private: private:
...@@ -573,11 +567,10 @@ public: ...@@ -573,11 +567,10 @@ public:
std::string name; std::string name;
std::vector<double> values; std::vector<double> values;
double min, max; double min, max;
bool interpolating;
FunctionInfo() { FunctionInfo() {
} }
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) : FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max) :
name(name), values(values), min(min), max(max), interpolating(interpolating) { name(name), values(values), min(min), max(max) {
} }
}; };
......
...@@ -92,7 +92,7 @@ namespace OpenMM { ...@@ -92,7 +92,7 @@ namespace OpenMM {
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of * In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of
* values, and an interpolating or approximating spline is created from them. That function can then appear in the expression. * values, and a natural spline is created from them. That function can then appear in the expression.
*/ */
class OPENMM_EXPORT CustomHbondForce : public Force { class OPENMM_EXPORT CustomHbondForce : public Force {
...@@ -378,11 +378,9 @@ public: ...@@ -378,11 +378,9 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max. * The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values * @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values * @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in the energy expression. * Get the parameters for a tabulated function that may appear in the energy expression.
* *
...@@ -392,10 +390,8 @@ public: ...@@ -392,10 +390,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max. * The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values * @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values * @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in algebraic expressions. * Set the parameters for a tabulated function that may appear in algebraic expressions.
* *
...@@ -405,10 +401,8 @@ public: ...@@ -405,10 +401,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max. * The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values * @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values * @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
protected: protected:
ForceImpl* createImpl(); ForceImpl* createImpl();
private: private:
...@@ -495,11 +489,10 @@ public: ...@@ -495,11 +489,10 @@ public:
std::string name; std::string name;
std::vector<double> values; std::vector<double> values;
double min, max; double min, max;
bool interpolating;
FunctionInfo() { FunctionInfo() {
} }
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) : FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max) :
name(name), values(values), min(min), max(max), interpolating(interpolating) { name(name), values(values), min(min), max(max) {
} }
}; };
......
...@@ -80,7 +80,7 @@ namespace OpenMM { ...@@ -80,7 +80,7 @@ namespace OpenMM {
* the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator. * the expression may also involve intermediate quantities that are defined following the main expression, using ";" as a separator.
* *
* In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of * In addition, you can call addFunction() to define a new function based on tabulated values. You specify a vector of
* values, and an interpolating or approximating spline is created from them. That function can then appear in the expression. * values, and a natural spline is created from them. That function can then appear in the expression.
*/ */
class OPENMM_EXPORT CustomNonbondedForce : public Force { class OPENMM_EXPORT CustomNonbondedForce : public Force {
...@@ -282,11 +282,9 @@ public: ...@@ -282,11 +282,9 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max. * The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values * @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values * @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
* @return the index of the function that was added * @return the index of the function that was added
*/ */
int addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating); int addFunction(const std::string& name, const std::vector<double>& values, double min, double max);
/** /**
* Get the parameters for a tabulated function that may appear in the energy expression. * Get the parameters for a tabulated function that may appear in the energy expression.
* *
...@@ -296,10 +294,8 @@ public: ...@@ -296,10 +294,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max. * The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values * @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values * @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/ */
void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const; void getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const;
/** /**
* Set the parameters for a tabulated function that may appear in algebraic expressions. * Set the parameters for a tabulated function that may appear in algebraic expressions.
* *
...@@ -309,10 +305,8 @@ public: ...@@ -309,10 +305,8 @@ public:
* The function is assumed to be zero for x &lt; min or x &gt; max. * The function is assumed to be zero for x &lt; min or x &gt; max.
* @param min the value of the independent variable corresponding to the first element of values * @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values * @param max the value of the independent variable corresponding to the last element of values
* @param interpolating if true, an interpolating (Catmull-Rom) spline will be used to represent the function.
* If false, an approximating spline (B-spline) will be used.
*/ */
void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating); void setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max);
protected: protected:
ForceImpl* createImpl(); ForceImpl* createImpl();
private: private:
...@@ -395,11 +389,10 @@ public: ...@@ -395,11 +389,10 @@ public:
std::string name; std::string name;
std::vector<double> values; std::vector<double> values;
double min, max; double min, max;
bool interpolating;
FunctionInfo() { FunctionInfo() {
} }
FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) : FunctionInfo(const std::string& name, const std::vector<double>& values, double min, double max) :
name(name), values(values), min(min), max(max), interpolating(interpolating) { name(name), values(values), min(min), max(max) {
} }
}; };
......
...@@ -158,24 +158,23 @@ void CustomGBForce::setExclusionParticles(int index, int particle1, int particle ...@@ -158,24 +158,23 @@ void CustomGBForce::setExclusionParticles(int index, int particle1, int particle
exclusions[index].particle2 = particle2; exclusions[index].particle2 = particle2;
} }
int CustomGBForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) { int CustomGBForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min) if (max <= min)
throw OpenMMException("CustomGBForce: max <= min for a tabulated function."); throw OpenMMException("CustomGBForce: max <= min for a tabulated function.");
if (values.size() < 2) if (values.size() < 2)
throw OpenMMException("CustomGBForce: a tabulated function must have at least two points"); throw OpenMMException("CustomGBForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max, interpolating)); functions.push_back(FunctionInfo(name, values, min, max));
return functions.size()-1; return functions.size()-1;
} }
void CustomGBForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const { void CustomGBForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const {
name = functions[index].name; name = functions[index].name;
values = functions[index].values; values = functions[index].values;
min = functions[index].min; min = functions[index].min;
max = functions[index].max; max = functions[index].max;
interpolating = functions[index].interpolating;
} }
void CustomGBForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) { void CustomGBForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min) if (max <= min)
throw OpenMMException("CustomGBForce: max <= min for a tabulated function."); throw OpenMMException("CustomGBForce: max <= min for a tabulated function.");
if (values.size() < 2) if (values.size() < 2)
...@@ -184,7 +183,6 @@ void CustomGBForce::setFunctionParameters(int index, const std::string& name, co ...@@ -184,7 +183,6 @@ void CustomGBForce::setFunctionParameters(int index, const std::string& name, co
functions[index].values = values; functions[index].values = values;
functions[index].min = min; functions[index].min = min;
functions[index].max = max; functions[index].max = max;
functions[index].interpolating = interpolating;
} }
ForceImpl* CustomGBForce::createImpl() { ForceImpl* CustomGBForce::createImpl() {
......
...@@ -172,24 +172,23 @@ void CustomHbondForce::setExclusionParticles(int index, int donor, int acceptor) ...@@ -172,24 +172,23 @@ void CustomHbondForce::setExclusionParticles(int index, int donor, int acceptor)
exclusions[index].acceptor = acceptor; exclusions[index].acceptor = acceptor;
} }
int CustomHbondForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) { int CustomHbondForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min) if (max <= min)
throw OpenMMException("CustomHbondForce: max <= min for a tabulated function."); throw OpenMMException("CustomHbondForce: max <= min for a tabulated function.");
if (values.size() < 2) if (values.size() < 2)
throw OpenMMException("CustomHbondForce: a tabulated function must have at least two points"); throw OpenMMException("CustomHbondForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max, interpolating)); functions.push_back(FunctionInfo(name, values, min, max));
return functions.size()-1; return functions.size()-1;
} }
void CustomHbondForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const { void CustomHbondForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const {
name = functions[index].name; name = functions[index].name;
values = functions[index].values; values = functions[index].values;
min = functions[index].min; min = functions[index].min;
max = functions[index].max; max = functions[index].max;
interpolating = functions[index].interpolating;
} }
void CustomHbondForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) { void CustomHbondForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min) if (max <= min)
throw OpenMMException("CustomHbondForce: max <= min for a tabulated function."); throw OpenMMException("CustomHbondForce: max <= min for a tabulated function.");
if (values.size() < 2) if (values.size() < 2)
...@@ -198,7 +197,6 @@ void CustomHbondForce::setFunctionParameters(int index, const std::string& name, ...@@ -198,7 +197,6 @@ void CustomHbondForce::setFunctionParameters(int index, const std::string& name,
functions[index].values = values; functions[index].values = values;
functions[index].min = min; functions[index].min = min;
functions[index].max = max; functions[index].max = max;
functions[index].interpolating = interpolating;
} }
ForceImpl* CustomHbondForce::createImpl() { ForceImpl* CustomHbondForce::createImpl() {
......
...@@ -134,24 +134,23 @@ void CustomNonbondedForce::setExclusionParticles(int index, int particle1, int p ...@@ -134,24 +134,23 @@ void CustomNonbondedForce::setExclusionParticles(int index, int particle1, int p
exclusions[index].particle2 = particle2; exclusions[index].particle2 = particle2;
} }
int CustomNonbondedForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) { int CustomNonbondedForce::addFunction(const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min) if (max <= min)
throw OpenMMException("CustomNonbondedForce: max <= min for a tabulated function."); throw OpenMMException("CustomNonbondedForce: max <= min for a tabulated function.");
if (values.size() < 2) if (values.size() < 2)
throw OpenMMException("CustomNonbondedForce: a tabulated function must have at least two points"); throw OpenMMException("CustomNonbondedForce: a tabulated function must have at least two points");
functions.push_back(FunctionInfo(name, values, min, max, interpolating)); functions.push_back(FunctionInfo(name, values, min, max));
return functions.size()-1; return functions.size()-1;
} }
void CustomNonbondedForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max, bool& interpolating) const { void CustomNonbondedForce::getFunctionParameters(int index, std::string& name, std::vector<double>& values, double& min, double& max) const {
name = functions[index].name; name = functions[index].name;
values = functions[index].values; values = functions[index].values;
min = functions[index].min; min = functions[index].min;
max = functions[index].max; max = functions[index].max;
interpolating = functions[index].interpolating;
} }
void CustomNonbondedForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating) { void CustomNonbondedForce::setFunctionParameters(int index, const std::string& name, const std::vector<double>& values, double min, double max) {
if (max <= min) if (max <= min)
throw OpenMMException("CustomNonbondedForce: max <= min for a tabulated function."); throw OpenMMException("CustomNonbondedForce: max <= min for a tabulated function.");
if (values.size() < 2) if (values.size() < 2)
...@@ -160,7 +159,6 @@ void CustomNonbondedForce::setFunctionParameters(int index, const std::string& n ...@@ -160,7 +159,6 @@ void CustomNonbondedForce::setFunctionParameters(int index, const std::string& n
functions[index].values = values; functions[index].values = values;
functions[index].min = min; functions[index].min = min;
functions[index].max = max; functions[index].max = max;
functions[index].interpolating = interpolating;
} }
ForceImpl* CustomNonbondedForce::createImpl() { ForceImpl* CustomNonbondedForce::createImpl() {
......
...@@ -41,9 +41,15 @@ void SplineFitter::createNaturalSpline(const vector<double>& x, const vector<dou ...@@ -41,9 +41,15 @@ void SplineFitter::createNaturalSpline(const vector<double>& x, const vector<dou
int n = x.size(); int n = x.size();
if (y.size() != n) if (y.size() != n)
throw OpenMMException("createNaturalSpline: x and y vectors must have same length"); throw OpenMMException("createNaturalSpline: x and y vectors must have same length");
if (n < 3) if (n < 2)
throw OpenMMException("createNaturalSpline: the length of the input array must be at least 3"); throw OpenMMException("createNaturalSpline: the length of the input array must be at least 2");
deriv.resize(n); deriv.resize(n);
if (n == 2) {
// This is just a straight line.
deriv[0] = 0;
deriv[1] = 0;
}
// Create the system of equations to solve. // Create the system of equations to solve.
......
...@@ -940,9 +940,8 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -940,9 +940,8 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; force.getFunctionParameters(i, name, values, min, max);
force.getFunctionParameters(i, name, values, min, max, interpolating); gpuSetTabulatedFunction(gpu, i, name, values, min, max);
gpuSetTabulatedFunction(gpu, i, name, values, min, max, interpolating);
} }
// Record information for the expressions. // Record information for the expressions.
......
...@@ -50,6 +50,7 @@ using namespace std; ...@@ -50,6 +50,7 @@ using namespace std;
#include "cudaKernels.h" #include "cudaKernels.h"
#include "hilbert.h" #include "hilbert.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include "quern.h" #include "quern.h"
#include "Lepton.h" #include "Lepton.h"
#include "rng.h" #include "rng.h"
...@@ -614,7 +615,7 @@ void gpuSetNonbondedCutoff(gpuContext gpu, float cutoffDistance, float solventDi ...@@ -614,7 +615,7 @@ void gpuSetNonbondedCutoff(gpuContext gpu, float cutoffDistance, float solventDi
} }
extern "C" extern "C"
void gpuSetTabulatedFunction(gpuContext gpu, int index, const string& name, const vector<double>& values, double min, double max, bool interpolating) void gpuSetTabulatedFunction(gpuContext gpu, int index, const string& name, const vector<double>& values, double min, double max)
{ {
if (index < 0 || index >= MAX_TABULATED_FUNCTIONS) { if (index < 0 || index >= MAX_TABULATED_FUNCTIONS) {
stringstream str; stringstream str;
...@@ -631,32 +632,15 @@ void gpuSetTabulatedFunction(gpuContext gpu, int index, const string& name, cons ...@@ -631,32 +632,15 @@ void gpuSetTabulatedFunction(gpuContext gpu, int index, const string& name, cons
gpu->tabulatedFunctions[index].max = max; gpu->tabulatedFunctions[index].max = max;
gpu->tabulatedFunctionsChanged = true; gpu->tabulatedFunctionsChanged = true;
// First create a padded set of function values. // Compute the spline coefficients.
vector<double> padded(values.size()+2); int numValues = values.size();
padded[0] = 2*values[0]-values[1]; vector<double> x(numValues), derivs;
for (int i = 0; i < (int) values.size(); i++) for (int i = 0; i < numValues; i++)
padded[i+1] = values[i]; x[i] = min+i*(max-min)/(numValues-1);
padded[padded.size()-1] = 2*values[values.size()-1]-values[values.size()-2]; OpenMM::SplineFitter::createNaturalSpline(x, values, derivs);
for (int i = 0; i < (int) values.size()-1; i++)
// Now compute the spline coefficients. (*coeff)[i] = make_float4((float) values[i], (float) values[i+1], (float) (derivs[i]/6.0), (float) (derivs[i+1]/6.0));
for (int i = 0; i < (int) values.size()-1; i++) {
float4 c;
if (interpolating) {
c.x = (float) padded[i+1];
c.y = (float) (0.5*(-padded[i]+padded[i+2]));
c.z = (float) (0.5*(2.0*padded[i]-5.0*padded[i+1]+4.0*padded[i+2]-padded[i+3]));
c.w = (float) (0.5*(-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3]));
}
else {
c.x = (float) ((padded[i]+4.0*padded[i+1]+padded[i+2])/6.0);
c.y = (float) ((-3.0*padded[i]+3.0*padded[i+2])/6.0);
c.z = (float) ((3.0*padded[i]-6.0*padded[i+1]+3.0*padded[i+2])/6.0);
c.w = (float) ((-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3])/6.0);
}
(*coeff)[i] = c;
}
coeff->Upload(); coeff->Upload();
} }
...@@ -914,7 +898,7 @@ void gpuSetCustomNonbondedParameters(gpuContext gpu, const vector<vector<double> ...@@ -914,7 +898,7 @@ void gpuSetCustomNonbondedParameters(gpuContext gpu, const vector<vector<double>
for (int i = 0; i < MAX_TABULATED_FUNCTIONS; i++) { for (int i = 0; i < MAX_TABULATED_FUNCTIONS; i++) {
gpuTabulatedFunction& func = gpu->tabulatedFunctions[i]; gpuTabulatedFunction& func = gpu->tabulatedFunctions[i];
if (func.coefficients != NULL) { if (func.coefficients != NULL) {
(*gpu->psTabulatedFunctionParams)[i] = make_float4((float) func.min, (float) func.max, (float) (func.coefficients->_length/(func.max-func.min)), 0.0f); (*gpu->psTabulatedFunctionParams)[i] = make_float4((float) func.min, (float) func.max, (float) (func.coefficients->_length/(func.max-func.min)), (float) (func.coefficients->_length-1));
functions[func.name] = fp; functions[func.name] = fp;
} }
} }
......
...@@ -219,7 +219,7 @@ extern "C" ...@@ -219,7 +219,7 @@ extern "C"
void gpuSetNonbondedCutoff(gpuContext gpu, float cutoffDistance, float solventDielectric); void gpuSetNonbondedCutoff(gpuContext gpu, float cutoffDistance, float solventDielectric);
extern "C" extern "C"
void gpuSetTabulatedFunction(gpuContext gpu, int index, const std::string& name, const std::vector<double>& values, double min, double max, bool interpolating); void gpuSetTabulatedFunction(gpuContext gpu, int index, const std::string& name, const std::vector<double>& values, double min, double max);
extern "C" extern "C"
void gpuSetCustomBondParameters(gpuContext gpu, const std::vector<int>& bondAtom1, const std::vector<int>& bondAtom2, const std::vector<std::vector<double> >& bondParams, void gpuSetCustomBondParameters(gpuContext gpu, const std::vector<int>& bondAtom1, const std::vector<int>& bondAtom2, const std::vector<std::vector<double> >& bondParams,
......
...@@ -97,7 +97,9 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float* ...@@ -97,7 +97,9 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float*
STACK(stackPointer) = 0.0f; STACK(stackPointer) = 0.0f;
else else
{ {
int index = floor((x-params.x)*params.z); x = (x-params.x)*params.z;
int index = floor(x);
index = min(index, (int) params.w);
float4 coeff; float4 coeff;
if (function == 0) if (function == 0)
coeff = tex1Dfetch(texRef0, index); coeff = tex1Dfetch(texRef0, index);
...@@ -107,11 +109,12 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float* ...@@ -107,11 +109,12 @@ __device__ float kEvaluateExpression_kernel(Expression<SIZE>* expression, float*
coeff = tex1Dfetch(texRef2, index); coeff = tex1Dfetch(texRef2, index);
else else
coeff = tex1Dfetch(texRef3, index); coeff = tex1Dfetch(texRef3, index);
x = (x-params.x)*params.z-index; float b = x-index;
float a = 1.0f-b;
if (op == CUSTOM) if (op == CUSTOM)
STACK(stackPointer) = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w)); STACK(stackPointer) = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);
else else
STACK(stackPointer) = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z; STACK(stackPointer) = (coeff.y-coeff.x)*params.z-((3.0f*a*a-1.0f)*coeff.z+(4.0f*b*b-1.0f)*coeff.w)/params.z;
} }
} }
} }
......
...@@ -203,7 +203,7 @@ void testPeriodic() { ...@@ -203,7 +203,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
} }
void testTabulatedFunction(bool interpolating) { void testTabulatedFunction() {
CudaPlatform platform; CudaPlatform platform;
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -215,7 +215,7 @@ void testTabulatedFunction(bool interpolating) { ...@@ -215,7 +215,7 @@ void testTabulatedFunction(bool interpolating) {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0, interpolating); forceField->addFunction("fn", table, 1.0, 6.0);
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -233,6 +233,14 @@ void testTabulatedFunction(bool interpolating) { ...@@ -233,6 +233,14 @@ void testTabulatedFunction(bool interpolating) {
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
} }
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
} }
void testCoulombLennardJones() { void testCoulombLennardJones() {
...@@ -327,8 +335,7 @@ int main() { ...@@ -327,8 +335,7 @@ int main() {
testExclusions(); testExclusions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
testTabulatedFunction(true); testTabulatedFunction();
testTabulatedFunction(false);
testCoulombLennardJones(); testCoulombLennardJones();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "OpenCLExpressionUtilities.h" #include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
using namespace OpenMM; using namespace OpenMM;
...@@ -120,14 +121,16 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -120,14 +121,16 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << "float4 params = " << functionParams << "[" << i << "];\n"; out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n"; out << "if (x >= params.x && x <= params.y) {\n";
out << "int index = (int) (floor((x-params.x)*params.z));\n"; out << "x = (x-params.x)*params.z;\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) params.w);\n"; out << "index = min(index, (int) params.w);\n";
out << "float4 coeff = " << functions[i].second << "[index];\n"; out << "float4 coeff = " << functions[i].second << "[index];\n";
out << "x = (x-params.x)*params.z-index;\n"; out << "float b = x-index;\n";
out << "float a = 1.0f-b;\n";
if (valueNode != NULL) if (valueNode != NULL)
out << valueName << " = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));\n"; out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
if (derivNode != NULL) if (derivNode != NULL)
out << derivName << " = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;\n"; out << derivName << " = (coeff.y-coeff.x)*params.z-((3.0f*a*a-1.0f)*coeff.z+(4.0f*b*b-1.0f)*coeff.w)/params.z;\n";
out << "}\n"; out << "}\n";
out << "}"; out << "}";
break; break;
...@@ -338,29 +341,16 @@ void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node ...@@ -338,29 +341,16 @@ void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node
findRelatedPowers(node, searchNode.getChildren()[i], powers); findRelatedPowers(node, searchNode.getChildren()[i], powers);
} }
vector<mm_float4> OpenCLExpressionUtilities::computeFunctionCoefficients(const vector<double>& values, bool interpolating) { vector<mm_float4> OpenCLExpressionUtilities::computeFunctionCoefficients(const vector<double>& values, double min, double max) {
// First create a padded set of function values. // Compute the spline coefficients.
vector<double> padded(values.size()+2); int numValues = values.size();
padded[0] = 2*values[0]-values[1]; vector<double> x(numValues), derivs;
for (int i = 0; i < (int) values.size(); i++) for (int i = 0; i < numValues; i++)
padded[i+1] = values[i]; x[i] = min+i*(max-min)/(numValues-1);
padded[padded.size()-1] = 2*values[values.size()-1]-values[values.size()-2]; SplineFitter::createNaturalSpline(x, values, derivs);
vector<mm_float4> f(numValues-1);
// Now compute the spline coefficients. for (int i = 0; i < (int) values.size()-1; i++)
f[i] = mm_float4((cl_float) values[i], (cl_float) values[i+1], (cl_float) (derivs[i]/6.0), (cl_float) (derivs[i+1]/6.0));
vector<mm_float4> f(values.size()-1);
for (int i = 0; i < (int) values.size()-1; i++) {
if (interpolating)
f[i] = mm_float4((cl_float) padded[i+1],
(cl_float) (0.5*(-padded[i]+padded[i+2])),
(cl_float) (0.5*(2.0*padded[i]-5.0*padded[i+1]+4.0*padded[i+2]-padded[i+3])),
(cl_float) (0.5*(-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3])));
else
f[i] = mm_float4((cl_float) ((padded[i]+4.0*padded[i+1]+padded[i+2])/6.0),
(cl_float) ((-3.0*padded[i]+3.0*padded[i+2])/6.0),
(cl_float) ((3.0*padded[i]-6.0*padded[i+1]+3.0*padded[i+2])/6.0),
(cl_float) ((-padded[i]+3.0*padded[i+1]-3.0*padded[i+2]+padded[i+3])/6.0));
}
return f; return f;
} }
...@@ -60,10 +60,11 @@ public: ...@@ -60,10 +60,11 @@ public:
* Calculate the spline coefficients for a tabulated function that appears in expressions. * Calculate the spline coefficients for a tabulated function that appears in expressions.
* *
* @param values the tabulated values of the function * @param values the tabulated values of the function
* @param interpolating true if an interpolating spline should be used, false if an approximating spline should be used * @param min the value of the independent variable corresponding to the first element of values
* @param max the value of the independent variable corresponding to the last element of values
* @return the spline coefficients * @return the spline coefficients
*/ */
static std::vector<mm_float4> computeFunctionCoefficients(const std::vector<double>& values, bool interpolating); static std::vector<mm_float4> computeFunctionCoefficients(const std::vector<double>& values, double min, double max);
/** /**
* Convert a number to a string in a format suitable for including in a kernel. * Convert a number to a string in a format suitable for including in a kernel.
*/ */
......
...@@ -1533,13 +1533,12 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -1533,13 +1533,12 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; force.getFunctionParameters(i, name, values, min, max);
force.getFunctionParameters(i, name, values, min, max, interpolating);
string arrayName = prefix+"table"+intToString(i); string arrayName = prefix+"table"+intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = &fp;
tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), values.size()-2); tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), values.size()-2);
vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, interpolating); vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, min, max);
tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction")); tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
...@@ -1874,13 +1873,12 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -1874,13 +1873,12 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; force.getFunctionParameters(i, name, values, min, max);
force.getFunctionParameters(i, name, values, min, max, interpolating);
string arrayName = prefix+"table"+intToString(i); string arrayName = prefix+"table"+intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = &fp;
tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), values.size()-2); tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), values.size()-2);
vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, interpolating); vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, min, max);
tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction")); tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
...@@ -2917,13 +2915,12 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -2917,13 +2915,12 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; force.getFunctionParameters(i, name, values, min, max);
force.getFunctionParameters(i, name, values, min, max, interpolating);
string arrayName = "table"+intToString(i); string arrayName = "table"+intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = &fp;
tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), values.size()-2); tabulatedFunctionParamsVec[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), values.size()-2);
vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, interpolating); vector<mm_float4> f = OpenCLExpressionUtilities::computeFunctionCoefficients(values, min, max);
tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction")); tabulatedFunctions.push_back(new OpenCLArray<mm_float4>(cl, values.size()-1, "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", __global float4* " << arrayName; tableArgs << ", __global float4* " << arrayName;
......
...@@ -230,7 +230,7 @@ void testMembrane() { ...@@ -230,7 +230,7 @@ void testMembrane() {
ASSERT_EQUAL_TOL(norm, (state2.getPotentialEnergy()-state.getPotentialEnergy())/stepSize, 1e-2); ASSERT_EQUAL_TOL(norm, (state2.getPotentialEnergy()-state.getPotentialEnergy())/stepSize, 1e-2);
} }
void testTabulatedFunction(bool interpolating) { void testTabulatedFunction() {
OpenCLPlatform platform; OpenCLPlatform platform;
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -244,7 +244,7 @@ void testTabulatedFunction(bool interpolating) { ...@@ -244,7 +244,7 @@ void testTabulatedFunction(bool interpolating) {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0, interpolating); force->addFunction("fn", table, 1.0, 6.0);
system.addForce(force); system.addForce(force);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -261,6 +261,14 @@ void testTabulatedFunction(bool interpolating) { ...@@ -261,6 +261,14 @@ void testTabulatedFunction(bool interpolating) {
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
} }
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
} }
void testMultipleChainRules() { void testMultipleChainRules() {
...@@ -424,8 +432,7 @@ int main() { ...@@ -424,8 +432,7 @@ int main() {
testOBC(GBSAOBCForce::CutoffNonPeriodic, CustomGBForce::CutoffNonPeriodic); testOBC(GBSAOBCForce::CutoffNonPeriodic, CustomGBForce::CutoffNonPeriodic);
testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic); testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic);
testMembrane(); testMembrane();
testTabulatedFunction(true); testTabulatedFunction();
testTabulatedFunction(false);
testMultipleChainRules(); testMultipleChainRules();
testPositionDependence(); testPositionDependence();
testExclusions(); testExclusions();
......
...@@ -195,7 +195,7 @@ void testCustomFunctions() { ...@@ -195,7 +195,7 @@ void testCustomFunctions() {
vector<double> function(2); vector<double> function(2);
function[0] = 0; function[0] = 0;
function[1] = 1; function[1] = 1;
custom->addFunction("foo", function, 0, 10, true); custom->addFunction("foo", function, 0, 10);
system.addForce(custom); system.addForce(custom);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(3); vector<Vec3> positions(3);
......
...@@ -242,7 +242,7 @@ void testPeriodic() { ...@@ -242,7 +242,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
} }
void testTabulatedFunction(bool interpolating) { void testTabulatedFunction() {
OpenCLPlatform platform; OpenCLPlatform platform;
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -254,7 +254,7 @@ void testTabulatedFunction(bool interpolating) { ...@@ -254,7 +254,7 @@ void testTabulatedFunction(bool interpolating) {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0, interpolating); forceField->addFunction("fn", table, 1.0, 6.0);
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -272,6 +272,14 @@ void testTabulatedFunction(bool interpolating) { ...@@ -272,6 +272,14 @@ void testTabulatedFunction(bool interpolating) {
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
} }
for (int i = 1; i < 20; i++) {
double x = 0.25*i+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
} }
void testCoulombLennardJones() { void testCoulombLennardJones() {
...@@ -357,8 +365,7 @@ int main() { ...@@ -357,8 +365,7 @@ int main() {
testExclusions(); testExclusions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
testTabulatedFunction(true); testTabulatedFunction();
testTabulatedFunction(false);
testCoulombLennardJones(); testCoulombLennardJones();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -63,6 +63,7 @@ ...@@ -63,6 +63,7 @@
#include "openmm/internal/CustomHbondForceImpl.h" #include "openmm/internal/CustomHbondForceImpl.h"
#include "openmm/internal/CMAPTorsionForceImpl.h" #include "openmm/internal/CMAPTorsionForceImpl.h"
#include "openmm/internal/NonbondedForceImpl.h" #include "openmm/internal/NonbondedForceImpl.h"
#include "openmm/internal/SplineFitter.h"
#include "openmm/Integrator.h" #include "openmm/Integrator.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "SimTKUtilities/SimTKOpenMMUtilities.h" #include "SimTKUtilities/SimTKOpenMMUtilities.h"
...@@ -713,61 +714,34 @@ double ReferenceCalcNonbondedForceKernel::execute(ContextImpl& context, bool inc ...@@ -713,61 +714,34 @@ double ReferenceCalcNonbondedForceKernel::execute(ContextImpl& context, bool inc
class ReferenceTabulatedFunction : public Lepton::CustomFunction { class ReferenceTabulatedFunction : public Lepton::CustomFunction {
public: public:
ReferenceTabulatedFunction(double min, double max, const vector<double>& values, bool interpolating) : ReferenceTabulatedFunction(double min, double max, const vector<double>& values) :
min(min), max(max), values(values), interpolating(interpolating) { min(min), max(max), values(values) {
int numValues = values.size();
x.resize(numValues);
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs);
} }
int getNumArguments() const { int getNumArguments() const {
return 1; return 1;
} }
/**
* Given the function argument, find the local spline coefficients.
*/
void findCoefficients(double& x, double* coeff) const {
int length = values.size();
double scale = (length-1)/(max-min);
int index = (int) std::floor((x-min)*scale);
double points[4];
points[0] = (index == 0 ? 2*values[0]-values[1] : values[index-1]);
points[1] = values[index];
points[2] = (index > length-2 ? values[length-1] : values[index+1]);
points[3] = (index > length-3 ? 2*values[length-1]-values[length-2] : values[index+2]);
if (interpolating) {
coeff[0] = points[1];
coeff[1] = 0.5*(-points[0]+points[2]);
coeff[2] = 0.5*(2.0*points[0]-5.0*points[1]+4.0*points[2]-points[3]);
coeff[3] = 0.5*(-points[0]+3.0*points[1]-3.0*points[2]+points[3]);
}
else {
coeff[0] = (points[0]+4.0*points[1]+points[2])/6.0;
coeff[1] = (-3.0*points[0]+3.0*points[2])/6.0;
coeff[2] = (3.0*points[0]-6.0*points[1]+3.0*points[2])/6.0;
coeff[3] = (-points[0]+3.0*points[1]-3.0*points[2]+points[3])/6.0;
}
x = (x-min)*scale-index;
}
double evaluate(const double* arguments) const { double evaluate(const double* arguments) const {
double x = arguments[0]; double t = arguments[0];
if (x < min || x > max) if (t < min || t > max)
return 0.0; return 0.0;
double coeff[4]; return SplineFitter::evaluateSpline(x, values, derivs, t);
findCoefficients(x, coeff);
return coeff[0]+x*(coeff[1]+x*(coeff[2]+x*coeff[3]));
} }
double evaluateDerivative(const double* arguments, const int* derivOrder) const { double evaluateDerivative(const double* arguments, const int* derivOrder) const {
double x = arguments[0]; double t = arguments[0];
if (x < min || x > max) if (t < min || t > max)
return 0.0; return 0.0;
double coeff[4]; return SplineFitter::evaluateSplineDerivative(x, values, derivs, t);
findCoefficients(x, coeff);
double scale = (values.size()-1)/(max-min);
return scale*(coeff[1]+x*(2.0*coeff[2]+x*3.0*coeff[3])); // We assume a first derivative, because that's the only order ever used by CustomNonbondedForce.
} }
CustomFunction* clone() const { CustomFunction* clone() const {
return new ReferenceTabulatedFunction(min, max, values, interpolating); return new ReferenceTabulatedFunction(min, max, values);
} }
double min, max; double min, max;
vector<double> values; vector<double> x, values, derivs;
bool interpolating;
}; };
ReferenceCalcCustomNonbondedForceKernel::~ReferenceCalcCustomNonbondedForceKernel() { ReferenceCalcCustomNonbondedForceKernel::~ReferenceCalcCustomNonbondedForceKernel() {
...@@ -822,9 +796,8 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c ...@@ -822,9 +796,8 @@ void ReferenceCalcCustomNonbondedForceKernel::initialize(const System& system, c
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; force.getFunctionParameters(i, name, values, min, max);
force.getFunctionParameters(i, name, values, min, max, interpolating); functions[name] = new ReferenceTabulatedFunction(min, max, values);
functions[name] = new ReferenceTabulatedFunction(min, max, values, interpolating);
} }
// Parse the various expressions used to calculate the force. // Parse the various expressions used to calculate the force.
...@@ -1010,9 +983,8 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu ...@@ -1010,9 +983,8 @@ void ReferenceCalcCustomGBForceKernel::initialize(const System& system, const Cu
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; force.getFunctionParameters(i, name, values, min, max);
force.getFunctionParameters(i, name, values, min, max, interpolating); functions[name] = new ReferenceTabulatedFunction(min, max, values);
functions[name] = new ReferenceTabulatedFunction(min, max, values, interpolating);
} }
// Parse the expressions for computed values. // Parse the expressions for computed values.
...@@ -1204,9 +1176,8 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const ...@@ -1204,9 +1176,8 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
string name; string name;
vector<double> values; vector<double> values;
double min, max; double min, max;
bool interpolating; force.getFunctionParameters(i, name, values, min, max);
force.getFunctionParameters(i, name, values, min, max, interpolating); functions[name] = new ReferenceTabulatedFunction(min, max, values);
functions[name] = new ReferenceTabulatedFunction(min, max, values, interpolating);
} }
// Parse the expression and create the object used to calculate the interaction. // Parse the expression and create the object used to calculate the interaction.
......
...@@ -232,7 +232,7 @@ void testMembrane() { ...@@ -232,7 +232,7 @@ void testMembrane() {
ASSERT_EQUAL_TOL(norm, (state2.getPotentialEnergy()-state.getPotentialEnergy())/stepSize, 1e-2); ASSERT_EQUAL_TOL(norm, (state2.getPotentialEnergy()-state.getPotentialEnergy())/stepSize, 1e-2);
} }
void testTabulatedFunction(bool interpolating) { void testTabulatedFunction() {
ReferencePlatform platform; ReferencePlatform platform;
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -246,7 +246,7 @@ void testTabulatedFunction(bool interpolating) { ...@@ -246,7 +246,7 @@ void testTabulatedFunction(bool interpolating) {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
force->addFunction("fn", table, 1.0, 6.0, interpolating); force->addFunction("fn", table, 1.0, 6.0);
system.addForce(force); system.addForce(force);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -865,8 +865,7 @@ int main() { ...@@ -865,8 +865,7 @@ int main() {
testOBC(GBSAOBCForce::CutoffNonPeriodic, CustomGBForce::CutoffNonPeriodic); testOBC(GBSAOBCForce::CutoffNonPeriodic, CustomGBForce::CutoffNonPeriodic);
testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic); testOBC(GBSAOBCForce::CutoffPeriodic, CustomGBForce::CutoffPeriodic);
testMembrane(); testMembrane();
testTabulatedFunction(true); testTabulatedFunction();
testTabulatedFunction(false);
testMultipleChainRules(); testMultipleChainRules();
testPositionDependence(); testPositionDependence();
testExclusions(); testExclusions();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment