Unverified Commit eae83041 authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2611 from craabreu/periodic_variables_metadynamics

[WIP] Changes in Metadynamics
parents eec9cd69 691aa78c
...@@ -63,6 +63,13 @@ public: ...@@ -63,6 +63,13 @@ public:
* @deprecated This will be removed in a future release. * @deprecated This will be removed in a future release.
*/ */
virtual TabulatedFunction* Copy() const = 0; virtual TabulatedFunction* Copy() const = 0;
/**
* Get the periodicity status of the tabulated function.
*
*/
bool getPeriodic() const;
protected:
bool periodic;
}; };
/** /**
...@@ -78,8 +85,9 @@ public: ...@@ -78,8 +85,9 @@ public:
* The function is assumed to be zero for x < min or x > max. * The function is assumed to be zero for x < min or x > max.
* @param min the value of x corresponding to the first element of values * @param min the value of x corresponding to the first element of values
* @param max the value of x corresponding to the last element of values * @param max the value of x corresponding to the last element of values
* @param periodic whether the interpolated function is periodic
*/ */
Continuous1DFunction(const std::vector<double>& values, double min, double max); Continuous1DFunction(const std::vector<double>& values, double min, double max, bool periodic=false);
/** /**
* Get the parameters for the tabulated function. * Get the parameters for the tabulated function.
* *
...@@ -129,8 +137,9 @@ public: ...@@ -129,8 +137,9 @@ public:
* @param xmax the value of x corresponding to the last element of values * @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values * @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values * @param ymax the value of y corresponding to the last element of values
* @param periodic whether the interpolated function is periodic
*/ */
Continuous2DFunction(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax); Continuous2DFunction(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, bool periodic=false);
/** /**
* Get the parameters for the tabulated function. * Get the parameters for the tabulated function.
* *
...@@ -196,8 +205,9 @@ public: ...@@ -196,8 +205,9 @@ public:
* @param ymax the value of y corresponding to the last element of values * @param ymax the value of y corresponding to the last element of values
* @param zmin the value of z corresponding to the first element of values * @param zmin the value of z corresponding to the first element of values
* @param zmax the value of z corresponding to the last element of values * @param zmax the value of z corresponding to the last element of values
* @param periodic whether the interpolated function is periodic
*/ */
Continuous3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax); Continuous3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic=false);
/** /**
* Get the parameters for the tabulated function. * Get the parameters for the tabulated function.
* *
......
...@@ -43,6 +43,18 @@ namespace OpenMM { ...@@ -43,6 +43,18 @@ namespace OpenMM {
class OPENMM_EXPORT SplineFitter { class OPENMM_EXPORT SplineFitter {
public: public:
/**
* Fit a cubic spline to a set of data points. The resulting spline interpolates all the
* data points and has a continuous second derivative everywhere. The second derivatives are
* identical at the end points if periodic=true or 0 at the end points if periodic=false.
*
* @param x the values of the independent variable at the data points to interpolate. They must
* be strictly increasing: x[i] > x[i-1].
* @param y the values of the dependent variable at the data points to interpolate
* @param periodic whether the interpolated function is periodic
* @param deriv on exit, this contains the second derivative of the spline at each of the data points
*/
static void createSpline(const std::vector<double>& x, const std::vector<double>& y, bool periodic, std::vector<double>& deriv);
/** /**
* Fit a natural cubic spline to a set of data points. The resulting spline interpolates all the * Fit a natural cubic spline to a set of data points. The resulting spline interpolates all the
* data points, has a continuous second derivative everywhere, and has a second derivative of 0 at * data points, has a continuous second derivative everywhere, and has a second derivative of 0 at
...@@ -86,6 +98,21 @@ public: ...@@ -86,6 +98,21 @@ public:
* @return the value of the spline's derivative at the specified point * @return the value of the spline's derivative at the specified point
*/ */
static double evaluateSplineDerivative(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& deriv, double t); static double evaluateSplineDerivative(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& deriv, double t);
/**
* Fit a cubic spline surface f(x,y) to a 2D set of data points. The resulting spline interpolates all the
* data points and has a continuous second derivative everywhere. The second derivatives are identical at
* the boundary if periodic=true or 0 at the boundary if periodic=false.
*
* @param x the values of the first independent variable at the data points to interpolate. They must
* be strictly increasing: x[i] > x[i-1].
* @param y the values of the second independent variable at the data points to interpolate. They must
* be strictly increasing: y[i] > y[i-1].
* @param values the values of the dependent variable at the data points to interpolate. They must be ordered
* so that values[i+xsize*j] = f(x[i],y[j]), where xsize is the length of x.
* @param periodic whether the interpolated function is periodic
* @param c on exit, this contains the spline coefficients at each of the data points
*/
static void create2DSpline(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& values, bool periodic, std::vector<std::vector<double> >& c);
/** /**
* Fit a natural cubic spline surface f(x,y) to a 2D set of data points. The resulting spline interpolates all the * Fit a natural cubic spline surface f(x,y) to a 2D set of data points. The resulting spline interpolates all the
* data points, has a continuous second derivative everywhere, and has a second derivative of 0 at the boundary. * data points, has a continuous second derivative everywhere, and has a second derivative of 0 at the boundary.
...@@ -124,6 +151,24 @@ public: ...@@ -124,6 +151,24 @@ public:
* @param dy on exit, the y derivative of the spline at the specified point * @param dy on exit, the y derivative of the spline at the specified point
*/ */
static void evaluate2DSplineDerivatives(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& values, const std::vector<std::vector<double> >& c, double u, double v, double& dx, double& dy); static void evaluate2DSplineDerivatives(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& values, const std::vector<std::vector<double> >& c, double u, double v, double& dx, double& dy);
/**
* Fit a cubic spline surface f(x,y,z) to a 3D set of data points. The resulting spline interpolates all the
* data points and has a continuous second derivative everywhere. The second derivatives are identical at
* the boundary if periodic=true or 0 at the boundary if periodic=false.
*
* @param x the values of the first independent variable at the data points to interpolate. They must
* be strictly increasing: x[i] > x[i-1].
* @param y the values of the second independent variable at the data points to interpolate. They must
* be strictly increasing: y[i] > y[i-1].
* @param z the values of the third independent variable at the data points to interpolate. They must
* be strictly increasing: z[i] > z[i-1].
* @param values the values of the dependent variable at the data points to interpolate. They must be ordered
* so that values[i+xsize*j+xsize*ysize*k] = f(x[i],y[j],z[k]), where xsize is the length of x
* and ysize is the length of y.
* @param periodic whether the interpolated function is periodic
* @param c on exit, this contains the spline coefficients at each of the data points
*/
static void create3DSpline(const std::vector<double>& x, const std::vector<double>& y, const std::vector<double>& z, const std::vector<double>& values, bool periodic, std::vector<std::vector<double> >& c);
/** /**
* Fit a natural cubic spline surface f(x,y,z) to a 3D set of data points. The resulting spline interpolates all the * Fit a natural cubic spline surface f(x,y,z) to a 3D set of data points. The resulting spline interpolates all the
* data points, has a continuous second derivative everywhere, and has a second derivative of 0 at the boundary. * data points, has a continuous second derivative everywhere, and has a second derivative of 0 at the boundary.
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include <vector> #include <vector>
#include <math.h>
#include "openmm/internal/SplineFitter.h" #include "openmm/internal/SplineFitter.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
...@@ -37,6 +38,17 @@ ...@@ -37,6 +38,17 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
static bool notEqual(double a, double b) {
return (fabs(a-b) > 1e-15 + 1e-15*fabs(b));
}
void SplineFitter::createSpline(const vector<double>& x, const vector<double>& y, bool periodic, vector<double>& deriv) {
if (periodic)
SplineFitter::createPeriodicSpline(x, y, deriv);
else
SplineFitter::createNaturalSpline(x, y, deriv);
}
void SplineFitter::createNaturalSpline(const vector<double>& x, const vector<double>& y, vector<double>& deriv) { void SplineFitter::createNaturalSpline(const vector<double>& x, const vector<double>& y, vector<double>& deriv) {
int n = x.size(); int n = x.size();
if (y.size() != n) if (y.size() != n)
...@@ -80,7 +92,7 @@ void SplineFitter::createPeriodicSpline(const vector<double>& x, const vector<do ...@@ -80,7 +92,7 @@ void SplineFitter::createPeriodicSpline(const vector<double>& x, const vector<do
throw OpenMMException("createPeriodicSpline: x and y vectors must have same length"); throw OpenMMException("createPeriodicSpline: x and y vectors must have same length");
if (n < 3) if (n < 3)
throw OpenMMException("createPeriodicSpline: the length of the input array must be at least 3"); throw OpenMMException("createPeriodicSpline: the length of the input array must be at least 3");
if (y[0] != y[n-1]) if (notEqual(y[0], y[n-1]))
throw OpenMMException("createPeriodicSpline: the first and last points must have the same value"); throw OpenMMException("createPeriodicSpline: the first and last points must have the same value");
deriv.resize(n); deriv.resize(n);
...@@ -183,15 +195,19 @@ void SplineFitter::solveTridiagonalMatrix(const vector<double>& a, const vector< ...@@ -183,15 +195,19 @@ void SplineFitter::solveTridiagonalMatrix(const vector<double>& a, const vector<
sol[i] = (rhs[i]-a[i]*sol[i-1])/beta; sol[i] = (rhs[i]-a[i]*sol[i-1])/beta;
} }
// Perform backsubstitation. // Perform backsubstitution.
for (int i = n-2; i >= 0; i--) for (int i = n-2; i >= 0; i--)
sol[i] -= gamma[i+1]*sol[i+1]; sol[i] -= gamma[i+1]*sol[i+1];
} }
void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<double>& y, const vector<double>& values, vector<vector<double> >& c) { void SplineFitter::create2DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& values, bool periodic, vector<vector<double> >& c) {
int xsize = x.size(), ysize = y.size(); int xsize = x.size(), ysize = y.size();
if (xsize < 2 || ysize < 2) if (periodic) {
if (xsize < 3 || ysize < 3)
throw OpenMMException("create2DNaturalSpline: periodic spline must have at least three points along each axis");
}
else if (xsize < 2 || ysize < 2)
throw OpenMMException("create2DNaturalSpline: must have at least two points along each axis"); throw OpenMMException("create2DNaturalSpline: must have at least two points along each axis");
if (values.size() != xsize*ysize) if (values.size() != xsize*ysize)
throw OpenMMException("create2DNaturalSpline: incorrect number of values"); throw OpenMMException("create2DNaturalSpline: incorrect number of values");
...@@ -203,7 +219,7 @@ void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<d ...@@ -203,7 +219,7 @@ void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<d
for (int i = 0; i < ysize; i++) { for (int i = 0; i < ysize; i++) {
for (int j = 0; j < xsize; j++) for (int j = 0; j < xsize; j++)
t[j] = values[j+xsize*i]; t[j] = values[j+xsize*i];
SplineFitter::createNaturalSpline(x, t, deriv); SplineFitter::createSpline(x, t, periodic, deriv);
for (int j = 0; j < xsize; j++) for (int j = 0; j < xsize; j++)
d1[j+xsize*i] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[j]); d1[j+xsize*i] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[j]);
} }
...@@ -215,7 +231,7 @@ void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<d ...@@ -215,7 +231,7 @@ void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<d
for (int i = 0; i < xsize; i++) { for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) for (int j = 0; j < ysize; j++)
t[j] = values[i+xsize*j]; t[j] = values[i+xsize*j];
SplineFitter::createNaturalSpline(y, t, deriv); SplineFitter::createSpline(y, t, periodic, deriv);
for (int j = 0; j < ysize; j++) for (int j = 0; j < ysize; j++)
d2[i+xsize*j] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[j]); d2[i+xsize*j] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[j]);
} }
...@@ -227,7 +243,7 @@ void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<d ...@@ -227,7 +243,7 @@ void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<d
for (int i = 0; i < ysize; i++) { for (int i = 0; i < ysize; i++) {
for (int j = 0; j < xsize; j++) for (int j = 0; j < xsize; j++)
t[j] = d2[j+xsize*i]; t[j] = d2[j+xsize*i];
SplineFitter::createNaturalSpline(x, t, deriv); SplineFitter::createSpline(x, t, periodic, deriv);
for (int j = 0; j < xsize; j++) for (int j = 0; j < xsize; j++)
d12[j+xsize*i] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[j]); d12[j+xsize*i] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[j]);
} }
...@@ -285,6 +301,10 @@ void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<d ...@@ -285,6 +301,10 @@ void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<d
} }
} }
void SplineFitter::create2DNaturalSpline(const vector<double>& x, const vector<double>& y, const vector<double>& values, vector<vector<double> >& c) {
SplineFitter::create2DSpline(x, y, values, false, c);
}
double SplineFitter::evaluate2DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& values, const vector<vector<double> >& c, double u, double v) { double SplineFitter::evaluate2DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& values, const vector<vector<double> >& c, double u, double v) {
int xsize = x.size(); int xsize = x.size();
int ysize = y.size(); int ysize = y.size();
...@@ -369,11 +389,15 @@ void SplineFitter::evaluate2DSplineDerivatives(const vector<double>& x, const ve ...@@ -369,11 +389,15 @@ void SplineFitter::evaluate2DSplineDerivatives(const vector<double>& x, const ve
dy /= deltay; dy /= deltay;
} }
void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, vector<vector<double> >& c) { void SplineFitter::create3DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, bool periodic, vector<vector<double> >& c) {
int xsize = x.size(), ysize = y.size(), zsize = z.size(); int xsize = x.size(), ysize = y.size(), zsize = z.size();
int xysize = xsize*ysize; int xysize = xsize*ysize;
if (xsize < 2 || ysize < 2 || zsize < 2) if (periodic) {
throw OpenMMException("create2DNaturalSpline: must have at least two points along each axis"); if (xsize < 3 || ysize < 3 || zsize < 3)
throw OpenMMException("create3DNaturalSpline: periodic spline must have at least three points along each axis");
}
else if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("create3DNaturalSpline: must have at least two points along each axis");
if (values.size() != xsize*ysize*zsize) if (values.size() != xsize*ysize*zsize)
throw OpenMMException("create2DNaturalSpline: incorrect number of values"); throw OpenMMException("create2DNaturalSpline: incorrect number of values");
vector<double> d1(xsize*ysize*zsize), d2(xsize*ysize*zsize), d3(xsize*ysize*zsize); vector<double> d1(xsize*ysize*zsize), d2(xsize*ysize*zsize), d3(xsize*ysize*zsize);
...@@ -386,7 +410,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d ...@@ -386,7 +410,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d
for (int j = 0; j < zsize; j++) { for (int j = 0; j < zsize; j++) {
for (int k = 0; k < xsize; k++) for (int k = 0; k < xsize; k++)
t[k] = values[k+xsize*i+xysize*j]; t[k] = values[k+xsize*i+xysize*j];
SplineFitter::createNaturalSpline(x, t, deriv); SplineFitter::createSpline(x, t, periodic, deriv);
for (int k = 0; k < xsize; k++) for (int k = 0; k < xsize; k++)
d1[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]); d1[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]);
} }
...@@ -400,7 +424,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d ...@@ -400,7 +424,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d
for (int j = 0; j < zsize; j++) { for (int j = 0; j < zsize; j++) {
for (int k = 0; k < ysize; k++) for (int k = 0; k < ysize; k++)
t[k] = values[i+xsize*k+xysize*j]; t[k] = values[i+xsize*k+xysize*j];
SplineFitter::createNaturalSpline(y, t, deriv); SplineFitter::createSpline(y, t, periodic, deriv);
for (int k = 0; k < ysize; k++) for (int k = 0; k < ysize; k++)
d2[i+xsize*k+xysize*j] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[k]); d2[i+xsize*k+xysize*j] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[k]);
} }
...@@ -414,7 +438,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d ...@@ -414,7 +438,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d
for (int j = 0; j < ysize; j++) { for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) for (int k = 0; k < zsize; k++)
t[k] = values[i+xsize*j+xysize*k]; t[k] = values[i+xsize*j+xysize*k];
SplineFitter::createNaturalSpline(z, t, deriv); SplineFitter::createSpline(z, t, periodic, deriv);
for (int k = 0; k < zsize; k++) for (int k = 0; k < zsize; k++)
d3[i+xsize*j+xysize*k] = SplineFitter::evaluateSplineDerivative(z, t, deriv, z[k]); d3[i+xsize*j+xysize*k] = SplineFitter::evaluateSplineDerivative(z, t, deriv, z[k]);
} }
...@@ -428,7 +452,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d ...@@ -428,7 +452,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d
for (int j = 0; j < zsize; j++) { for (int j = 0; j < zsize; j++) {
for (int k = 0; k < xsize; k++) for (int k = 0; k < xsize; k++)
t[k] = d2[k+xsize*i+xysize*j]; t[k] = d2[k+xsize*i+xysize*j];
SplineFitter::createNaturalSpline(x, t, deriv); SplineFitter::createSpline(x, t, periodic, deriv);
for (int k = 0; k < xsize; k++) for (int k = 0; k < xsize; k++)
d12[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]); d12[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]);
} }
...@@ -442,7 +466,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d ...@@ -442,7 +466,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d
for (int j = 0; j < xsize; j++) { for (int j = 0; j < xsize; j++) {
for (int k = 0; k < ysize; k++) for (int k = 0; k < ysize; k++)
t[k] = d3[j+xsize*k+xysize*i]; t[k] = d3[j+xsize*k+xysize*i];
SplineFitter::createNaturalSpline(y, t, deriv); SplineFitter::createSpline(y, t, periodic, deriv);
for (int k = 0; k < ysize; k++) for (int k = 0; k < ysize; k++)
d23[j+xsize*k+xysize*i] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[k]); d23[j+xsize*k+xysize*i] = SplineFitter::evaluateSplineDerivative(y, t, deriv, y[k]);
} }
...@@ -456,7 +480,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d ...@@ -456,7 +480,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d
for (int j = 0; j < ysize; j++) { for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) for (int k = 0; k < zsize; k++)
t[k] = d1[i+xsize*j+xysize*k]; t[k] = d1[i+xsize*j+xysize*k];
SplineFitter::createNaturalSpline(z, t, deriv); SplineFitter::createSpline(z, t, periodic, deriv);
for (int k = 0; k < zsize; k++) for (int k = 0; k < zsize; k++)
d13[i+xsize*j+xysize*k] = SplineFitter::evaluateSplineDerivative(z, t, deriv, z[k]); d13[i+xsize*j+xysize*k] = SplineFitter::evaluateSplineDerivative(z, t, deriv, z[k]);
} }
...@@ -470,7 +494,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d ...@@ -470,7 +494,7 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d
for (int j = 0; j < zsize; j++) { for (int j = 0; j < zsize; j++) {
for (int k = 0; k < xsize; k++) for (int k = 0; k < xsize; k++)
t[k] = d23[k+xsize*i+xysize*j]; t[k] = d23[k+xsize*i+xysize*j];
SplineFitter::createNaturalSpline(x, t, deriv); SplineFitter::createSpline(x, t, periodic, deriv);
for (int k = 0; k < xsize; k++) for (int k = 0; k < xsize; k++)
d123[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]); d123[k+xsize*i+xysize*j] = SplineFitter::evaluateSplineDerivative(x, t, deriv, x[k]);
} }
...@@ -599,6 +623,10 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d ...@@ -599,6 +623,10 @@ void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<d
} }
} }
void SplineFitter::create3DNaturalSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, vector<vector<double> >& c) {
SplineFitter::create3DSpline(x, y, z, values, false, c);
}
double SplineFitter::evaluate3DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, const vector<vector<double> >& c, double u, double v, double w) { double SplineFitter::evaluate3DSpline(const vector<double>& x, const vector<double>& y, const vector<double>& z, const vector<double>& values, const vector<vector<double> >& c, double u, double v, double w) {
int xsize = x.size(); int xsize = x.size();
int ysize = y.size(); int ysize = y.size();
......
...@@ -35,14 +35,13 @@ ...@@ -35,14 +35,13 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max) { bool TabulatedFunction::getPeriodic() const {
if (max <= min) return periodic;
throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function."); }
if (values.size() < 2)
throw OpenMMException("Continuous1DFunction: a tabulated function must have at least two points"); Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max, bool periodic) {
this->values = values; this->periodic = periodic;
this->min = min; setFunctionParameters(values, min, max);
this->max = max;
} }
void Continuous1DFunction::getFunctionParameters(vector<double>& values, double& min, double& max) const { void Continuous1DFunction::getFunctionParameters(vector<double>& values, double& min, double& max) const {
...@@ -54,8 +53,16 @@ void Continuous1DFunction::getFunctionParameters(vector<double>& values, double& ...@@ -54,8 +53,16 @@ void Continuous1DFunction::getFunctionParameters(vector<double>& values, double&
void Continuous1DFunction::setFunctionParameters(const vector<double>& values, double min, double max) { void Continuous1DFunction::setFunctionParameters(const vector<double>& values, double min, double max) {
if (max <= min) if (max <= min)
throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function."); throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function.");
if (values.size() < 2) int n = values.size();
throw OpenMMException("Continuous1DFunction: a tabulated function must have at least two points"); if (periodic) {
if (n < 3)
throw OpenMMException("Continuous1DFunction: a periodic tabulated function must have at least three points");
// Note: value-matching at boundary is eventually checked at spline creation.
// if (values[0] != values[n-1])
// throw OpenMMException("Continuous1DFunction: with periodic=true, the first and last points must have the same value");
}
else if (n < 2)
throw OpenMMException("Continuous1DFunction: a non-periodic tabulated function must have at least two points");
this->values = values; this->values = values;
this->min = min; this->min = min;
this->max = max; this->max = max;
...@@ -68,22 +75,9 @@ Continuous1DFunction* Continuous1DFunction::Copy() const { ...@@ -68,22 +75,9 @@ Continuous1DFunction* Continuous1DFunction::Copy() const {
return new Continuous1DFunction(new_vec, min, max); return new Continuous1DFunction(new_vec, min, max);
} }
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) { Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, bool periodic) {
if (xsize < 2 || ysize < 2) this->periodic = periodic;
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis"); setFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
if (values.size() != xsize*ysize)
throw OpenMMException("Continuous2DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous2DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous2DFunction: ymax <= ymin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
} }
void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax) const { void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax) const {
...@@ -97,7 +91,12 @@ void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector< ...@@ -97,7 +91,12 @@ void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector<
} }
void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) { void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2) if (periodic) {
if (xsize < 3 || ysize < 3)
throw OpenMMException("Continuous2DFunction: must have at least three points along each axis if periodic");
// Note: value-matching at boundary is eventually checked at 2D-spline creation.
}
else if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis"); throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize) if (values.size() != xsize*ysize)
throw OpenMMException("Continuous2DFunction: incorrect number of values"); throw OpenMMException("Continuous2DFunction: incorrect number of values");
...@@ -121,27 +120,9 @@ Continuous2DFunction* Continuous2DFunction::Copy() const { ...@@ -121,27 +120,9 @@ Continuous2DFunction* Continuous2DFunction::Copy() const {
return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax); return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax);
} }
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) { Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic) {
if (xsize < 2 || ysize < 2 || zsize < 2) this->periodic = periodic;
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis"); setFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Continuous3DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous3DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous3DFunction: ymax <= ymin for a tabulated function.");
if (zmax <= zmin)
throw OpenMMException("Continuous3DFunction: zmax <= zmin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
this->zmin = zmin;
this->zmax = zmax;
} }
void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zsize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax, double& zmin, double& zmax) const { void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zsize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax, double& zmin, double& zmax) const {
...@@ -158,7 +139,12 @@ void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zs ...@@ -158,7 +139,12 @@ void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zs
} }
void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) { void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2) if (periodic) {
if (xsize < 3 || ysize < 3 || zsize < 3)
throw OpenMMException("Continuous3DFunction: must have at least three points along each axis if periodic");
// Note: value-matching at boundary is eventually checked at 3D-spline creation.
}
else if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis"); throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize*zsize) if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Continuous3DFunction: incorrect number of values"); throw OpenMMException("Continuous3DFunction: incorrect number of values");
......
...@@ -238,11 +238,19 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -238,11 +238,19 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
for (auto& suffix : suffixes) { for (auto& suffix : suffixes) {
out << "{\n"; out << "{\n";
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) { if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
int periodic = functionParams[i][4];
out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
if (periodic) {
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[5]<< ";\n";
out << "x = (x - floor(x))*" << paramsFloat[6] << ";\n";
out << "int index = (int) (floor(x));\n";
}
else {
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n"; out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n"; out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n"; out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) " << paramsInt[3] << ");\n"; out << "index = min(index, (int) " << paramsInt[3] << ");\n";
}
out << "float4 coeff = " << functionNames[i].second << "[index];\n"; out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n"; out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n"; out << "real a = 1.0f-b;\n";
...@@ -253,16 +261,28 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -253,16 +261,28 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
else else
out << nodeNames[j] << suffix << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n"; out << nodeNames[j] << suffix << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
} }
if (!periodic)
out << "}\n"; out << "}\n";
} }
else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
int periodic = functionParams[i][8];
out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n"; out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
if (periodic) {
out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[9] << ";\n";
out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[10] << ";\n";
out << "x = (x - floor(x))*" << paramsFloat[0] << ";\n";
out << "y = (y - floor(y))*" << paramsFloat[1] << ";\n";
out << "int s = (int) floor(x);\n";
out << "int t = (int) floor(y);\n";
}
else {
out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n"; out << "if (x >= " << paramsFloat[2] << " && x <= " << paramsFloat[3] << " && y >= " << paramsFloat[4] << " && y <= " << paramsFloat[5] << ") {\n";
out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n"; out << "x = (x - " << paramsFloat[2] << ")*" << paramsFloat[6] << ";\n";
out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n"; out << "y = (y - " << paramsFloat[4] << ")*" << paramsFloat[7] << ";\n";
out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n"; out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n"; out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
}
out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n"; out << "int coeffIndex = 4*(s+" << paramsInt[0] << "*t);\n";
out << "float4 c[4];\n"; out << "float4 c[4];\n";
for (int j = 0; j < 4; j++) for (int j = 0; j < 4; j++)
...@@ -294,12 +314,26 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -294,12 +314,26 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
else else
throw OpenMMException("Unsupported derivative order for Continuous2DFunction"); throw OpenMMException("Unsupported derivative order for Continuous2DFunction");
} }
if (!periodic)
out << "}\n"; out << "}\n";
} }
else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
int periodic = functionParams[i][12];
out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n"; out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
out << "real z = " << getTempName(node.getChildren()[2], temps) << suffix << ";\n"; out << "real z = " << getTempName(node.getChildren()[2], temps) << suffix << ";\n";
if (periodic) {
out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[13] << ";\n";
out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[14] << ";\n";
out << "z = (z - " << paramsFloat[7] << ")*" << paramsFloat[15] << ";\n";
out << "x = (x - floor(x))*" << paramsFloat[0] << ";\n";
out << "y = (y - floor(y))*" << paramsFloat[1] << ";\n";
out << "z = (z - floor(z))*" << paramsFloat[2] << ";\n";
out << "int s = (int) floor(x);\n";
out << "int t = (int) floor(y);\n";
out << "int u = (int) floor(z);\n";
}
else {
out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n"; out << "if (x >= " << paramsFloat[3] << " && x <= " << paramsFloat[4] << " && y >= " << paramsFloat[5] << " && y <= " << paramsFloat[6] << " && z >= " << paramsFloat[7] << " && z <= " << paramsFloat[8] << ") {\n";
out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n"; out << "x = (x - " << paramsFloat[3] << ")*" << paramsFloat[9] << ";\n";
out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n"; out << "y = (y - " << paramsFloat[5] << ")*" << paramsFloat[10] << ";\n";
...@@ -307,6 +341,7 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -307,6 +341,7 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n"; out << "int s = min((int) floor(x), " << paramsInt[0] << "-1);\n";
out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n"; out << "int t = min((int) floor(y), " << paramsInt[1] << "-1);\n";
out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n"; out << "int u = min((int) floor(z), " << paramsInt[2] << "-1);\n";
}
out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n"; out << "int coeffIndex = 16*(s+" << paramsInt[0] << "*(t+" << paramsInt[1] << "*u));\n";
out << "float4 c[16];\n"; out << "float4 c[16];\n";
for (int j = 0; j < 16; j++) for (int j = 0; j < 16; j++)
...@@ -360,6 +395,7 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT ...@@ -360,6 +395,7 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
else else
throw OpenMMException("Unsupported derivative order for Continuous3DFunction"); throw OpenMMException("Unsupported derivative order for Continuous3DFunction");
} }
if (!periodic)
out << "}\n"; out << "}\n";
} }
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
...@@ -720,11 +756,12 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu ...@@ -720,11 +756,12 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu
vector<double> values; vector<double> values;
double min, max; double min, max;
fn.getFunctionParameters(values, min, max); fn.getFunctionParameters(values, min, max);
bool periodic = fn.getPeriodic();
int numValues = values.size(); int numValues = values.size();
vector<double> x(numValues), derivs; vector<double> x(numValues), derivs;
for (int i = 0; i < numValues; i++) for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1); x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs); SplineFitter::createSpline(x, values, periodic, derivs);
vector<float> f(4*(numValues-1)); vector<float> f(4*(numValues-1));
for (int i = 0; i < (int) values.size()-1; i++) { for (int i = 0; i < (int) values.size()-1; i++) {
f[4*i] = (float) values[i]; f[4*i] = (float) values[i];
...@@ -743,13 +780,14 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu ...@@ -743,13 +780,14 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu
int xsize, ysize; int xsize, ysize;
double xmin, xmax, ymin, ymax; double xmin, xmax, ymin, ymax;
fn.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax); fn.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
bool periodic = fn.getPeriodic();
vector<double> x(xsize), y(ysize); vector<double> x(xsize), y(ysize);
for (int i = 0; i < xsize; i++) for (int i = 0; i < xsize; i++)
x[i] = xmin+i*(xmax-xmin)/(xsize-1); x[i] = xmin+i*(xmax-xmin)/(xsize-1);
for (int i = 0; i < ysize; i++) for (int i = 0; i < ysize; i++)
y[i] = ymin+i*(ymax-ymin)/(ysize-1); y[i] = ymin+i*(ymax-ymin)/(ysize-1);
vector<vector<double> > c; vector<vector<double> > c;
SplineFitter::create2DNaturalSpline(x, y, values, c); SplineFitter::create2DSpline(x, y, values, periodic, c);
vector<float> f(16*c.size()); vector<float> f(16*c.size());
for (int i = 0; i < (int) c.size(); i++) { for (int i = 0; i < (int) c.size(); i++) {
for (int j = 0; j < 16; j++) for (int j = 0; j < 16; j++)
...@@ -766,6 +804,7 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu ...@@ -766,6 +804,7 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu
int xsize, ysize, zsize; int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax; double xmin, xmax, ymin, ymax, zmin, zmax;
fn.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax); fn.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
bool periodic = fn.getPeriodic();
vector<double> x(xsize), y(ysize), z(zsize); vector<double> x(xsize), y(ysize), z(zsize);
for (int i = 0; i < xsize; i++) for (int i = 0; i < xsize; i++)
x[i] = xmin+i*(xmax-xmin)/(xsize-1); x[i] = xmin+i*(xmax-xmin)/(xsize-1);
...@@ -774,7 +813,7 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu ...@@ -774,7 +813,7 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu
for (int i = 0; i < zsize; i++) for (int i = 0; i < zsize; i++)
z[i] = zmin+i*(zmax-zmin)/(zsize-1); z[i] = zmin+i*(zmax-zmin)/(zsize-1);
vector<vector<double> > c; vector<vector<double> > c;
SplineFitter::create3DNaturalSpline(x, y, z, values, c); SplineFitter::create3DSpline(x, y, z, values, periodic, c);
vector<float> f(64*c.size()); vector<float> f(64*c.size());
for (int i = 0; i < (int) c.size(); i++) { for (int i = 0; i < (int) c.size(); i++) {
for (int j = 0; j < 64; j++) for (int j = 0; j < 64; j++)
...@@ -835,10 +874,14 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec ...@@ -835,10 +874,14 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec
vector<double> values; vector<double> values;
double min, max; double min, max;
fn.getFunctionParameters(values, min, max); fn.getFunctionParameters(values, min, max);
int periodic = (int) fn.getPeriodic();
params[i].push_back(min); params[i].push_back(min);
params[i].push_back(max); params[i].push_back(max);
params[i].push_back((values.size()-1)/(max-min)); params[i].push_back((values.size()-1)/(max-min));
params[i].push_back(values.size()-2); params[i].push_back(values.size()-2);
params[i].push_back(periodic);
params[i].push_back(1.0/(max-min));
params[i].push_back(values.size()-1);
} }
else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
const Continuous2DFunction& fn = dynamic_cast<const Continuous2DFunction&>(*functions[i]); const Continuous2DFunction& fn = dynamic_cast<const Continuous2DFunction&>(*functions[i]);
...@@ -846,6 +889,7 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec ...@@ -846,6 +889,7 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec
int xsize, ysize; int xsize, ysize;
double xmin, xmax, ymin, ymax; double xmin, xmax, ymin, ymax;
fn.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax); fn.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
int periodic = (int) fn.getPeriodic();
params[i].push_back(xsize-1); params[i].push_back(xsize-1);
params[i].push_back(ysize-1); params[i].push_back(ysize-1);
params[i].push_back(xmin); params[i].push_back(xmin);
...@@ -854,6 +898,9 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec ...@@ -854,6 +898,9 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec
params[i].push_back(ymax); params[i].push_back(ymax);
params[i].push_back((xsize-1)/(xmax-xmin)); params[i].push_back((xsize-1)/(xmax-xmin));
params[i].push_back((ysize-1)/(ymax-ymin)); params[i].push_back((ysize-1)/(ymax-ymin));
params[i].push_back(periodic);
params[i].push_back(1.0/(xmax-xmin));
params[i].push_back(1.0/(ymax-ymin));
} }
else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Continuous3DFunction*>(functions[i]) != NULL) {
const Continuous3DFunction& fn = dynamic_cast<const Continuous3DFunction&>(*functions[i]); const Continuous3DFunction& fn = dynamic_cast<const Continuous3DFunction&>(*functions[i]);
...@@ -861,6 +908,7 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec ...@@ -861,6 +908,7 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec
int xsize, ysize, zsize; int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax; double xmin, xmax, ymin, ymax, zmin, zmax;
fn.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax); fn.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
int periodic = (int) fn.getPeriodic();
params[i].push_back(xsize-1); params[i].push_back(xsize-1);
params[i].push_back(ysize-1); params[i].push_back(ysize-1);
params[i].push_back(zsize-1); params[i].push_back(zsize-1);
...@@ -873,6 +921,10 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec ...@@ -873,6 +921,10 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec
params[i].push_back((xsize-1)/(xmax-xmin)); params[i].push_back((xsize-1)/(xmax-xmin));
params[i].push_back((ysize-1)/(ymax-ymin)); params[i].push_back((ysize-1)/(ymax-ymin));
params[i].push_back((zsize-1)/(zmax-zmin)); params[i].push_back((zsize-1)/(zmax-zmin));
params[i].push_back(periodic);
params[i].push_back(1.0/(xmax-xmin));
params[i].push_back(1.0/(ymax-ymin));
params[i].push_back(1.0/(zmax-zmin));
} }
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]); const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
......
...@@ -59,6 +59,7 @@ private: ...@@ -59,6 +59,7 @@ private:
ReferenceContinuous1DFunction(const ReferenceContinuous1DFunction& other); ReferenceContinuous1DFunction(const ReferenceContinuous1DFunction& other);
const Continuous1DFunction& function; const Continuous1DFunction& function;
double min, max; double min, max;
bool periodic;
std::vector<double> x, values, derivs; std::vector<double> x, values, derivs;
}; };
...@@ -77,6 +78,7 @@ private: ...@@ -77,6 +78,7 @@ private:
const Continuous2DFunction& function; const Continuous2DFunction& function;
int xsize, ysize; int xsize, ysize;
double xmin, xmax, ymin, ymax; double xmin, xmax, ymin, ymax;
bool periodic;
std::vector<double> x, y, values; std::vector<double> x, y, values;
std::vector<std::vector<double> > c; std::vector<std::vector<double> > c;
}; };
...@@ -96,6 +98,7 @@ private: ...@@ -96,6 +98,7 @@ private:
const Continuous3DFunction& function; const Continuous3DFunction& function;
int xsize, ysize, zsize; int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax; double xmin, xmax, ymin, ymax, zmin, zmax;
bool periodic;
std::vector<double> x, y, z, values; std::vector<double> x, y, z, values;
std::vector<std::vector<double> > c; std::vector<std::vector<double> > c;
}; };
......
...@@ -51,6 +51,12 @@ static int round(double x) { ...@@ -51,6 +51,12 @@ static int round(double x) {
#include <cmath> #include <cmath>
#endif #endif
static double wrap(double t, double min, double max) {
double L = max - min;
double s = (t - min)/L;
return min + L*(s - floor(s));
}
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
using Lepton::CustomFunction; using Lepton::CustomFunction;
...@@ -75,15 +81,17 @@ extern "C" OPENMM_EXPORT CustomFunction* createReferenceTabulatedFunction(const ...@@ -75,15 +81,17 @@ extern "C" OPENMM_EXPORT CustomFunction* createReferenceTabulatedFunction(const
} }
ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const Continuous1DFunction& function) : function(function) { ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const Continuous1DFunction& function) : function(function) {
periodic = function.getPeriodic();
function.getFunctionParameters(values, min, max); function.getFunctionParameters(values, min, max);
int numValues = values.size(); int numValues = values.size();
x.resize(numValues); x.resize(numValues);
for (int i = 0; i < numValues; i++) for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1); x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs); SplineFitter::createSpline(x, values, periodic, derivs);
} }
ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const ReferenceContinuous1DFunction& other) : function(other.function) { ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const ReferenceContinuous1DFunction& other) : function(other.function) {
periodic = function.getPeriodic();
function.getFunctionParameters(values, min, max); function.getFunctionParameters(values, min, max);
x = other.x; x = other.x;
values = other.values; values = other.values;
...@@ -95,14 +103,14 @@ int ReferenceContinuous1DFunction::getNumArguments() const { ...@@ -95,14 +103,14 @@ int ReferenceContinuous1DFunction::getNumArguments() const {
} }
double ReferenceContinuous1DFunction::evaluate(const double* arguments) const { double ReferenceContinuous1DFunction::evaluate(const double* arguments) const {
double t = arguments[0]; double t = periodic ? wrap(arguments[0], min, max) : arguments[0];
if (t < min || t > max) if (t < min || t > max)
return 0.0; return 0.0;
return SplineFitter::evaluateSpline(x, values, derivs, t); return SplineFitter::evaluateSpline(x, values, derivs, t);
} }
double ReferenceContinuous1DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const { double ReferenceContinuous1DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
double t = arguments[0]; double t = periodic ? wrap(arguments[0], min, max) : arguments[0];
if (t < min || t > max) if (t < min || t > max)
return 0.0; return 0.0;
return SplineFitter::evaluateSplineDerivative(x, values, derivs, t); return SplineFitter::evaluateSplineDerivative(x, values, derivs, t);
...@@ -113,6 +121,7 @@ CustomFunction* ReferenceContinuous1DFunction::clone() const { ...@@ -113,6 +121,7 @@ CustomFunction* ReferenceContinuous1DFunction::clone() const {
} }
ReferenceContinuous2DFunction::ReferenceContinuous2DFunction(const Continuous2DFunction& function) : function(function) { ReferenceContinuous2DFunction::ReferenceContinuous2DFunction(const Continuous2DFunction& function) : function(function) {
periodic = function.getPeriodic();
function.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax); function.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
x.resize(xsize); x.resize(xsize);
y.resize(ysize); y.resize(ysize);
...@@ -120,10 +129,11 @@ ReferenceContinuous2DFunction::ReferenceContinuous2DFunction(const Continuous2DF ...@@ -120,10 +129,11 @@ ReferenceContinuous2DFunction::ReferenceContinuous2DFunction(const Continuous2DF
x[i] = xmin+i*(xmax-xmin)/(xsize-1); x[i] = xmin+i*(xmax-xmin)/(xsize-1);
for (int i = 0; i < ysize; i++) for (int i = 0; i < ysize; i++)
y[i] = ymin+i*(ymax-ymin)/(ysize-1); y[i] = ymin+i*(ymax-ymin)/(ysize-1);
SplineFitter::create2DNaturalSpline(x, y, values, c); SplineFitter::create2DSpline(x, y, values, periodic, c);
} }
ReferenceContinuous2DFunction::ReferenceContinuous2DFunction(const ReferenceContinuous2DFunction& other) : function(other.function) { ReferenceContinuous2DFunction::ReferenceContinuous2DFunction(const ReferenceContinuous2DFunction& other) : function(other.function) {
periodic = function.getPeriodic();
function.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax); function.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
x = other.x; x = other.x;
y = other.y; y = other.y;
...@@ -136,20 +146,20 @@ int ReferenceContinuous2DFunction::getNumArguments() const { ...@@ -136,20 +146,20 @@ int ReferenceContinuous2DFunction::getNumArguments() const {
} }
double ReferenceContinuous2DFunction::evaluate(const double* arguments) const { double ReferenceContinuous2DFunction::evaluate(const double* arguments) const {
double u = arguments[0]; double u = periodic ? wrap(arguments[0], xmin, xmax) : arguments[0];
if (u < xmin || u > xmax) if (u < xmin || u > xmax)
return 0.0; return 0.0;
double v = arguments[1]; double v = periodic ? wrap(arguments[1], ymin, ymax) : arguments[1];
if (v < ymin || v > ymax) if (v < ymin || v > ymax)
return 0.0; return 0.0;
return SplineFitter::evaluate2DSpline(x, y, values, c, u, v); return SplineFitter::evaluate2DSpline(x, y, values, c, u, v);
} }
double ReferenceContinuous2DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const { double ReferenceContinuous2DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
double u = arguments[0]; double u = periodic ? wrap(arguments[0], xmin, xmax) : arguments[0];
if (u < xmin || u > xmax) if (u < xmin || u > xmax)
return 0.0; return 0.0;
double v = arguments[1]; double v = periodic ? wrap(arguments[1], ymin, ymax) : arguments[1];
if (v < ymin || v > ymax) if (v < ymin || v > ymax)
return 0.0; return 0.0;
double dx, dy; double dx, dy;
...@@ -166,6 +176,7 @@ CustomFunction* ReferenceContinuous2DFunction::clone() const { ...@@ -166,6 +176,7 @@ CustomFunction* ReferenceContinuous2DFunction::clone() const {
} }
ReferenceContinuous3DFunction::ReferenceContinuous3DFunction(const Continuous3DFunction& function) : function(function) { ReferenceContinuous3DFunction::ReferenceContinuous3DFunction(const Continuous3DFunction& function) : function(function) {
periodic = function.getPeriodic();
function.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax); function.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
x.resize(xsize); x.resize(xsize);
y.resize(ysize); y.resize(ysize);
...@@ -176,10 +187,11 @@ ReferenceContinuous3DFunction::ReferenceContinuous3DFunction(const Continuous3DF ...@@ -176,10 +187,11 @@ ReferenceContinuous3DFunction::ReferenceContinuous3DFunction(const Continuous3DF
y[i] = ymin+i*(ymax-ymin)/(ysize-1); y[i] = ymin+i*(ymax-ymin)/(ysize-1);
for (int i = 0; i < zsize; i++) for (int i = 0; i < zsize; i++)
z[i] = zmin+i*(zmax-zmin)/(zsize-1); z[i] = zmin+i*(zmax-zmin)/(zsize-1);
SplineFitter::create3DNaturalSpline(x, y, z, values, c); SplineFitter::create3DSpline(x, y, z, values, periodic, c);
} }
ReferenceContinuous3DFunction::ReferenceContinuous3DFunction(const ReferenceContinuous3DFunction& other) : function(other.function) { ReferenceContinuous3DFunction::ReferenceContinuous3DFunction(const ReferenceContinuous3DFunction& other) : function(other.function) {
periodic = function.getPeriodic();
function.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax); function.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
x = other.x; x = other.x;
y = other.y; y = other.y;
...@@ -193,26 +205,26 @@ int ReferenceContinuous3DFunction::getNumArguments() const { ...@@ -193,26 +205,26 @@ int ReferenceContinuous3DFunction::getNumArguments() const {
} }
double ReferenceContinuous3DFunction::evaluate(const double* arguments) const { double ReferenceContinuous3DFunction::evaluate(const double* arguments) const {
double u = arguments[0]; double u = periodic ? wrap(arguments[0], xmin, xmax) : arguments[0];
if (u < xmin || u > xmax) if (u < xmin || u > xmax)
return 0.0; return 0.0;
double v = arguments[1]; double v = periodic ? wrap(arguments[1], ymin, ymax) : arguments[1];
if (v < ymin || v > ymax) if (v < ymin || v > ymax)
return 0.0; return 0.0;
double w = arguments[2]; double w = periodic ? wrap(arguments[2], zmin, zmax) : arguments[2];
if (w < zmin || w > zmax) if (w < zmin || w > zmax)
return 0.0; return 0.0;
return SplineFitter::evaluate3DSpline(x, y, z, values, c, u, v, w); return SplineFitter::evaluate3DSpline(x, y, z, values, c, u, v, w);
} }
double ReferenceContinuous3DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const { double ReferenceContinuous3DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
double u = arguments[0]; double u = periodic ? wrap(arguments[0], xmin, xmax) : arguments[0];
if (u < xmin || u > xmax) if (u < xmin || u > xmax)
return 0.0; return 0.0;
double v = arguments[1]; double v = periodic ? wrap(arguments[1], ymin, ymax) : arguments[1];
if (v < ymin || v > ymax) if (v < ymin || v > ymax)
return 0.0; return 0.0;
double w = arguments[2]; double w = periodic ? wrap(arguments[2], zmin, zmax) : arguments[2];
if (w < zmin || w > zmax) if (w < zmin || w > zmax)
return 0.0; return 0.0;
double dx, dy, dz; double dx, dy, dz;
......
...@@ -41,7 +41,7 @@ Continuous1DFunctionProxy::Continuous1DFunctionProxy() : SerializationProxy("Con ...@@ -41,7 +41,7 @@ Continuous1DFunctionProxy::Continuous1DFunctionProxy() : SerializationProxy("Con
} }
void Continuous1DFunctionProxy::serialize(const void* object, SerializationNode& node) const { void Continuous1DFunctionProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1); node.setIntProperty("version", 2);
const Continuous1DFunction& function = *reinterpret_cast<const Continuous1DFunction*>(object); const Continuous1DFunction& function = *reinterpret_cast<const Continuous1DFunction*>(object);
double min, max; double min, max;
vector<double> values; vector<double> values;
...@@ -51,23 +51,26 @@ void Continuous1DFunctionProxy::serialize(const void* object, SerializationNode& ...@@ -51,23 +51,26 @@ void Continuous1DFunctionProxy::serialize(const void* object, SerializationNode&
SerializationNode& valuesNode = node.createChildNode("Values"); SerializationNode& valuesNode = node.createChildNode("Values");
for (auto v : values) for (auto v : values)
valuesNode.createChildNode("Value").setDoubleProperty("v", v); valuesNode.createChildNode("Value").setDoubleProperty("v", v);
node.setBoolProperty("periodic", function.getPeriodic());
} }
void* Continuous1DFunctionProxy::deserialize(const SerializationNode& node) const { void* Continuous1DFunctionProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1) int version = node.getIntProperty("version");
if (!(version == 1 || version == 2))
throw OpenMMException("Unsupported version number"); throw OpenMMException("Unsupported version number");
const SerializationNode& valuesNode = node.getChildNode("Values"); const SerializationNode& valuesNode = node.getChildNode("Values");
vector<double> values; vector<double> values;
for (auto& child : valuesNode.getChildren()) for (auto& child : valuesNode.getChildren())
values.push_back(child.getDoubleProperty("v")); values.push_back(child.getDoubleProperty("v"));
return new Continuous1DFunction(values, node.getDoubleProperty("min"), node.getDoubleProperty("max")); bool periodic = version == 1 ? false : node.getBoolProperty("periodic");
return new Continuous1DFunction(values, node.getDoubleProperty("min"), node.getDoubleProperty("max"), periodic);
} }
Continuous2DFunctionProxy::Continuous2DFunctionProxy() : SerializationProxy("Continuous2DFunction") { Continuous2DFunctionProxy::Continuous2DFunctionProxy() : SerializationProxy("Continuous2DFunction") {
} }
void Continuous2DFunctionProxy::serialize(const void* object, SerializationNode& node) const { void Continuous2DFunctionProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1); node.setIntProperty("version", 2);
const Continuous2DFunction& function = *reinterpret_cast<const Continuous2DFunction*>(object); const Continuous2DFunction& function = *reinterpret_cast<const Continuous2DFunction*>(object);
int xsize, ysize; int xsize, ysize;
double xmin, xmax, ymin, ymax; double xmin, xmax, ymin, ymax;
...@@ -82,24 +85,28 @@ void Continuous2DFunctionProxy::serialize(const void* object, SerializationNode& ...@@ -82,24 +85,28 @@ void Continuous2DFunctionProxy::serialize(const void* object, SerializationNode&
SerializationNode& valuesNode = node.createChildNode("Values"); SerializationNode& valuesNode = node.createChildNode("Values");
for (auto v : values) for (auto v : values)
valuesNode.createChildNode("Value").setDoubleProperty("v", v); valuesNode.createChildNode("Value").setDoubleProperty("v", v);
node.setBoolProperty("periodic", function.getPeriodic());
} }
void* Continuous2DFunctionProxy::deserialize(const SerializationNode& node) const { void* Continuous2DFunctionProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1) int version = node.getIntProperty("version");
if (!(version == 1 || version == 2))
throw OpenMMException("Unsupported version number"); throw OpenMMException("Unsupported version number");
const SerializationNode& valuesNode = node.getChildNode("Values"); const SerializationNode& valuesNode = node.getChildNode("Values");
vector<double> values; vector<double> values;
for (auto& child : valuesNode.getChildren()) for (auto& child : valuesNode.getChildren())
values.push_back(child.getDoubleProperty("v")); values.push_back(child.getDoubleProperty("v"));
bool periodic = version == 1 ? false : node.getBoolProperty("periodic");
return new Continuous2DFunction(node.getIntProperty("xsize"), node.getIntProperty("ysize"), values, return new Continuous2DFunction(node.getIntProperty("xsize"), node.getIntProperty("ysize"), values,
node.getDoubleProperty("xmin"), node.getDoubleProperty("xmax"), node.getDoubleProperty("ymin"), node.getDoubleProperty("ymax")); node.getDoubleProperty("xmin"), node.getDoubleProperty("xmax"),
node.getDoubleProperty("ymin"),node.getDoubleProperty("ymax"), periodic);
} }
Continuous3DFunctionProxy::Continuous3DFunctionProxy() : SerializationProxy("Continuous3DFunction") { Continuous3DFunctionProxy::Continuous3DFunctionProxy() : SerializationProxy("Continuous3DFunction") {
} }
void Continuous3DFunctionProxy::serialize(const void* object, SerializationNode& node) const { void Continuous3DFunctionProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1); node.setIntProperty("version", 2);
const Continuous3DFunction& function = *reinterpret_cast<const Continuous3DFunction*>(object); const Continuous3DFunction& function = *reinterpret_cast<const Continuous3DFunction*>(object);
int xsize, ysize, zsize; int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax; double xmin, xmax, ymin, ymax, zmin, zmax;
...@@ -117,18 +124,21 @@ void Continuous3DFunctionProxy::serialize(const void* object, SerializationNode& ...@@ -117,18 +124,21 @@ void Continuous3DFunctionProxy::serialize(const void* object, SerializationNode&
SerializationNode& valuesNode = node.createChildNode("Values"); SerializationNode& valuesNode = node.createChildNode("Values");
for (auto v : values) for (auto v : values)
valuesNode.createChildNode("Value").setDoubleProperty("v", v); valuesNode.createChildNode("Value").setDoubleProperty("v", v);
node.setBoolProperty("periodic", function.getPeriodic());
} }
void* Continuous3DFunctionProxy::deserialize(const SerializationNode& node) const { void* Continuous3DFunctionProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1) int version = node.getIntProperty("version");
if (!(version == 1 || version == 2))
throw OpenMMException("Unsupported version number"); throw OpenMMException("Unsupported version number");
const SerializationNode& valuesNode = node.getChildNode("Values"); const SerializationNode& valuesNode = node.getChildNode("Values");
vector<double> values; vector<double> values;
for (auto& child : valuesNode.getChildren()) for (auto& child : valuesNode.getChildren())
values.push_back(child.getDoubleProperty("v")); values.push_back(child.getDoubleProperty("v"));
bool periodic = version == 1 ? false : node.getBoolProperty("periodic");
return new Continuous3DFunction(node.getIntProperty("xsize"), node.getIntProperty("ysize"), node.getIntProperty("zsize"), values, return new Continuous3DFunction(node.getIntProperty("xsize"), node.getIntProperty("ysize"), node.getIntProperty("zsize"), values,
node.getDoubleProperty("xmin"), node.getDoubleProperty("xmax"), node.getDoubleProperty("ymin"), node.getDoubleProperty("ymax"), node.getDoubleProperty("xmin"), node.getDoubleProperty("xmax"), node.getDoubleProperty("ymin"), node.getDoubleProperty("ymax"),
node.getDoubleProperty("zmin"), node.getDoubleProperty("zmax")); node.getDoubleProperty("zmin"), node.getDoubleProperty("zmax"), periodic);
} }
Discrete1DFunctionProxy::Discrete1DFunctionProxy() : SerializationProxy("Discrete1DFunction") { Discrete1DFunctionProxy::Discrete1DFunctionProxy() : SerializationProxy("Discrete1DFunction") {
......
...@@ -29,6 +29,9 @@ ...@@ -29,6 +29,9 @@
* USE OR OTHER DEALINGS IN THE SOFTWARE. * * USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#ifdef WIN32
#define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include "openmm/TabulatedFunction.h" #include "openmm/TabulatedFunction.h"
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include "openmm/serialization/XmlSerializer.h" #include "openmm/serialization/XmlSerializer.h"
...@@ -63,6 +66,36 @@ void testContinuous1DFunction() { ...@@ -63,6 +66,36 @@ void testContinuous1DFunction() {
ASSERT_EQUAL(values.size(), values2.size()); ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++) for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]); ASSERT_EQUAL(values[j], values2[j]);
ASSERT(!copy->getPeriodic());
}
void testPeriodicContinuous1DFunction() {
// Create a function.
double min = 0.0, max = 2.0*M_PI;
vector<double> values(60);
for (int i = 0; i < (int) values.size(); i++)
values[i] = sin(2.0*M_PI*i/(values.size()-1));
Continuous1DFunction function(values, min, max, true);
// Serialize and then deserialize it.
stringstream buffer;
XmlSerializer::serialize<Continuous1DFunction>(&function, "Function", buffer);
Continuous1DFunction* copy = XmlSerializer::deserialize<Continuous1DFunction>(buffer);
// Compare the two forces to see if they are identical.
double min2, max2;
vector<double> values2;
copy->getFunctionParameters(values2, min2, max2);
ASSERT_EQUAL(min, min2);
ASSERT_EQUAL(max, max2);
ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]);
ASSERT(copy->getPeriodic());
} }
void testContinuous2DFunction() { void testContinuous2DFunction() {
...@@ -96,6 +129,43 @@ void testContinuous2DFunction() { ...@@ -96,6 +129,43 @@ void testContinuous2DFunction() {
ASSERT_EQUAL(values.size(), values2.size()); ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++) for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]); ASSERT_EQUAL(values[j], values2[j]);
ASSERT(!copy->getPeriodic());
}
void testPeriodicContinuous2DFunction() {
// Create a function.
int xsize = 5, ysize = 12;
double xmin = 0.0, xmax = 2.0*M_PI, ymin = 0.0, ymax = 2.0*M_PI;
vector<double> values(xsize*ysize);
for (int i = 0; i < xsize; i++)
for (int j = 0; j < (int) ysize; j++)
values[i+j*xsize] = sin(2.0*M_PI*i/(xsize-1))*cos(2.0*M_PI*j/(ysize-1));
Continuous2DFunction function(xsize, ysize, values, xmin, xmax, ymin, ymax, true);
// Serialize and then deserialize it.
stringstream buffer;
XmlSerializer::serialize<Continuous2DFunction>(&function, "Function", buffer);
Continuous2DFunction* copy = XmlSerializer::deserialize<Continuous2DFunction>(buffer);
// Compare the two forces to see if they are identical.
int xsize2, ysize2;
double xmin2, xmax2, ymin2, ymax2;
vector<double> values2;
copy->getFunctionParameters(xsize2, ysize2, values2, xmin2, xmax2, ymin2, ymax2);
ASSERT_EQUAL(xsize, xsize2);
ASSERT_EQUAL(ysize, ysize2);
ASSERT_EQUAL(xmin, xmin2);
ASSERT_EQUAL(xmax, xmax2);
ASSERT_EQUAL(ymin, ymin2);
ASSERT_EQUAL(ymax, ymax2);
ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]);
ASSERT(copy->getPeriodic());
} }
void testContinuous3DFunction() { void testContinuous3DFunction() {
...@@ -132,6 +202,46 @@ void testContinuous3DFunction() { ...@@ -132,6 +202,46 @@ void testContinuous3DFunction() {
ASSERT_EQUAL(values.size(), values2.size()); ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++) for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]); ASSERT_EQUAL(values[j], values2[j]);
ASSERT(!copy->getPeriodic());
}
void testPeriodicContinuous3DFunction() {
// Create a function.
int xsize = 5, ysize = 4, zsize = 3;
double xmin = 0.0, xmax = 2.0*M_PI, ymin = 0.0, ymax = 2.0*M_PI, zmin = 0.0, zmax = 2.0*M_PI;
vector<double> values(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
values[i+j*xsize+k*xsize*ysize] = sin(2.0*M_PI*i/(xsize-1))*cos(2.0*M_PI*j/(ysize-1))*sin(2.0*M_PI*k/(zsize-1));
Continuous3DFunction function(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax, true);
// Serialize and then deserialize it.
stringstream buffer;
XmlSerializer::serialize<Continuous3DFunction>(&function, "Function", buffer);
Continuous3DFunction* copy = XmlSerializer::deserialize<Continuous3DFunction>(buffer);
// Compare the two forces to see if they are identical.
int xsize2, ysize2, zsize2;
double xmin2, xmax2, ymin2, ymax2, zmin2, zmax2;
vector<double> values2;
copy->getFunctionParameters(xsize2, ysize2, zsize2, values2, xmin2, xmax2, ymin2, ymax2, zmin2, zmax2);
ASSERT_EQUAL(xsize, xsize2);
ASSERT_EQUAL(ysize, ysize2);
ASSERT_EQUAL(zsize, zsize2);
ASSERT_EQUAL(xmin, xmin2);
ASSERT_EQUAL(xmax, xmax2);
ASSERT_EQUAL(ymin, ymin2);
ASSERT_EQUAL(ymax, ymax2);
ASSERT_EQUAL(zmin, zmin2);
ASSERT_EQUAL(zmax, zmax2);
ASSERT_EQUAL(values.size(), values2.size());
for (int j = 0; j < (int) values.size(); j++)
ASSERT_EQUAL(values[j], values2[j]);
ASSERT(copy->getPeriodic());
} }
void testDiscrete1DFunction() { void testDiscrete1DFunction() {
...@@ -215,8 +325,11 @@ void testDiscrete3DFunction() { ...@@ -215,8 +325,11 @@ void testDiscrete3DFunction() {
int main() { int main() {
try { try {
testContinuous1DFunction(); testContinuous1DFunction();
testPeriodicContinuous1DFunction();
testContinuous2DFunction(); testContinuous2DFunction();
testPeriodicContinuous2DFunction();
testContinuous3DFunction(); testContinuous3DFunction();
testPeriodicContinuous3DFunction();
testDiscrete1DFunction(); testDiscrete1DFunction();
testDiscrete2DFunction(); testDiscrete2DFunction();
testDiscrete3DFunction(); testDiscrete3DFunction();
......
...@@ -329,7 +329,8 @@ void testContinuous1DFunction() { ...@@ -329,7 +329,8 @@ void testContinuous1DFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addTabulatedFunction("fn", new Continuous1DFunction(table, 1.0, 6.0)); Continuous1DFunction* continuous1DFunction = new Continuous1DFunction(table, 1.0, 6.0);
forceField->addTabulatedFunction("fn", continuous1DFunction);
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -356,6 +357,46 @@ void testContinuous1DFunction() { ...@@ -356,6 +357,46 @@ void testContinuous1DFunction() {
} }
} }
void testPeriodicContinuous1DFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
int xsize = 20;
vector<double> table(xsize);
for (int i = 0; i < xsize; i++)
table[i] = sin(2.0*M_PI*i/(xsize-1));
Continuous1DFunction* continuous1DFunction = new Continuous1DFunction(table, 1.0, 2.0*M_PI+1.0, true);
forceField->addTabulatedFunction("fn", continuous1DFunction);
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 1; i < 30; i++) {
double x = (7.0/30.0)*i;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double force = -cos(x-1.0);
double energy = sin(x-1.0)+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
for (int i = 1; i < xsize; i++) {
double x = 2.0*M_PI*i/(xsize-1)+1.0;
positions[1] = Vec3(x, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Energy);
double energy = sin(x-1.0)+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
}
}
void testContinuous2DFunction() { void testContinuous2DFunction() {
const int xsize = 20; const int xsize = 20;
const int ysize = 21; const int ysize = 21;
...@@ -404,6 +445,50 @@ void testContinuous2DFunction() { ...@@ -404,6 +445,50 @@ void testContinuous2DFunction() {
} }
} }
void testPeriodicContinuous2DFunction() {
const int xsize = 20;
const int ysize = 21;
const double xmin = 1.0;
const double xmax = 1.0+8.0*M_PI;
const double ymin = 0.0;
const double ymax = 2.0*M_PI;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
double x = xmin + i*(xmax-xmin)/(xsize-1);
double y = ymin + j*(ymax-ymin)/(ysize-1);
table[i+xsize*j] = sin(0.25*x)*cos(y);
}
}
forceField->addTabulatedFunction("fn", new Continuous2DFunction(xsize, ysize, table, xmin, xmax, ymin, ymax, true));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin-0.15; x < xmax+0.2; x += 1.0) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.5) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = sin(0.25*x)*cos(y)+1.0;
double force = -0.25*cos(0.25*x)*cos(y);
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
}
}
}
void testContinuous3DFunction() { void testContinuous3DFunction() {
const int xsize = 10; const int xsize = 10;
const int ysize = 11; const int ysize = 11;
...@@ -462,6 +547,60 @@ void testContinuous3DFunction() { ...@@ -462,6 +547,60 @@ void testContinuous3DFunction() {
} }
} }
void testPeriodicContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
const int zsize = 12;
const double xmin = 1.0;
const double xmax = 1.0+8.0*M_PI;
const double ymin = 0.0;
const double ymax = 2.0*M_PI;
const double zmin = 0.0;
const double zmax = 2.0*M_PI;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double x = xmin + i*(xmax-xmin)/(xsize-1);
double y = ymin + j*(ymax-ymin)/(ysize-1);
double z = zmin + k*(zmax-zmin)/(zsize-1);
table[i+xsize*j+xsize*ysize*k] = sin(0.25*x)*cos(y)*(1.0-sin(z));
}
}
}
forceField->addTabulatedFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax, true));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin; x < xmax+0.2; x += 1.0) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.5) {
for (double z = zmin-0.15; z < zmax+0.2; z += 0.5) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setParameter("b", z);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = sin(0.25*x)*cos(y)*(1.0-sin(z))+1.0;
double force = -0.25*cos(0.25*x)*cos(y)*(1.0-sin(z));
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
}
void testDiscrete1DFunction() { void testDiscrete1DFunction() {
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -1303,8 +1442,11 @@ int main(int argc, char* argv[]) { ...@@ -1303,8 +1442,11 @@ int main(int argc, char* argv[]) {
testPeriodic(); testPeriodic();
testTriclinic(); testTriclinic();
testContinuous1DFunction(); testContinuous1DFunction();
testPeriodicContinuous1DFunction();
testContinuous2DFunction(); testContinuous2DFunction();
testPeriodicContinuous2DFunction();
testContinuous3DFunction(); testContinuous3DFunction();
testPeriodicContinuous3DFunction();
testDiscrete1DFunction(); testDiscrete1DFunction();
testDiscrete2DFunction(); testDiscrete2DFunction();
testDiscrete3DFunction(); testDiscrete3DFunction();
......
...@@ -134,6 +134,40 @@ void test2DSpline() { ...@@ -134,6 +134,40 @@ void test2DSpline() {
} }
} }
void testPeriodic2DSpline() {
const int xsize = 15;
const int ysize = 17;
vector<double> x(xsize);
vector<double> y(ysize);
vector<double> f(xsize*ysize);
for (int i = 0; i < xsize; i++)
x[i] = 2.0*M_PI*i/(xsize-1);
for (int i = 0; i < ysize; i++)
y[i] = 2.0*M_PI*i/(ysize-1);
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
f[i+j*xsize] = sin(x[i])*cos(y[j]);
vector<vector<double> > c;
SplineFitter::create2DSpline(x, y, f, true, c);
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++) {
double value = SplineFitter::evaluate2DSpline(x, y, f, c, x[i], y[j]);
ASSERT_EQUAL_TOL(f[i+j*xsize], value, 1e-6);
}
for (int i = 0; i < 10; i++) {
for (int j = 0; j < 10; j++) {
double s = x[0]+(i+1)*(x[xsize-1]-x[0])/12.0;
double t = y[0]+(j+1)*(y[ysize-1]-y[0])/12.0;
double value = SplineFitter::evaluate2DSpline(x, y, f, c, s, t);
ASSERT_EQUAL_TOL(sin(s)*cos(t), value, 0.02);
double dx, dy;
SplineFitter::evaluate2DSplineDerivatives(x, y, f, c, s, t, dx, dy);
ASSERT_EQUAL_TOL(cos(s)*cos(t), dx, 0.05);
ASSERT_EQUAL_TOL(-sin(s)*sin(t), dy, 0.05);
}
}
}
void test3DSpline() { void test3DSpline() {
const int xsize = 8; const int xsize = 8;
const int ysize = 9; const int ysize = 9;
...@@ -179,12 +213,59 @@ void test3DSpline() { ...@@ -179,12 +213,59 @@ void test3DSpline() {
} }
} }
void testPeriodic3DSpline() {
const int xsize = 11;
const int ysize = 13;
const int zsize = 15;
vector<double> x(xsize);
vector<double> y(ysize);
vector<double> z(zsize);
vector<double> f(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++)
x[i] = 2.0*M_PI*i/(xsize-1);
for (int i = 0; i < ysize; i++)
y[i] = 2.0*M_PI*i/(ysize-1);
for (int i = 0; i < zsize; i++)
z[i] = 2.0*M_PI*i/(zsize-1);
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
f[i+j*xsize+k*xsize*ysize] = sin(x[i])*cos(y[j])*(1.0-sin(z[k]));
vector<vector<double> > c;
SplineFitter::create3DSpline(x, y, z, f, true, c);
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double value = SplineFitter::evaluate3DSpline(x, y, z, f, c, x[i], y[j], z[k]);
ASSERT_EQUAL_TOL(f[i+j*xsize+k*xsize*ysize], value, 1e-6);
}
}
for (int i = 0; i < 10; i++) {
for (int j = 0; j < 10; j++) {
for (int k = 0; k < 10; k++) {
double s = x[0]+(i+1)*(x[xsize-1]-x[0])/12.0;
double t = y[0]+(j+1)*(y[ysize-1]-y[0])/12.0;
double u = z[0]+(k+1)*(z[zsize-1]-z[0])/12.0;
double value = SplineFitter::evaluate3DSpline(x, y, z, f, c, s, t, u);
ASSERT_EQUAL_TOL(sin(s)*cos(t)*(1.0-sin(u)), value, 0.02);
double dx, dy, dz;
SplineFitter::evaluate3DSplineDerivatives(x, y, z, f, c, s, t, u, dx, dy, dz);
ASSERT_EQUAL_TOL(cos(s)*cos(t)*(1.0-sin(u)), dx, 0.1);
ASSERT_EQUAL_TOL(-sin(s)*sin(t)*(1.0-sin(u)), dy, 0.1);
ASSERT_EQUAL_TOL(-sin(s)*cos(t)*cos(u), dz, 0.1);
}
}
}
}
int main() { int main() {
try { try {
testNaturalSpline(); testNaturalSpline();
testPeriodicSpline(); testPeriodicSpline();
test2DSpline(); test2DSpline();
testPeriodic2DSpline();
test3DSpline(); test3DSpline();
testPeriodic3DSpline();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
...@@ -86,6 +86,12 @@ def _parseFunctions(element): ...@@ -86,6 +86,12 @@ def _parseFunctions(element):
params[key] = int(function.attrib[key]) params[key] = int(function.attrib[key])
elif key.endswith('min') or key.endswith('max'): elif key.endswith('min') or key.endswith('max'):
params[key] = float(function.attrib[key]) params[key] = float(function.attrib[key])
if functionType.startswith('Continuous'):
periodicStr = function.attrib.get('periodic', 'false').lower()
if periodicStr in ['true', 'false', 'yes', 'no', '1', '0']:
params['periodic'] = periodicStr in ['true', 'yes', '1']
else:
raise ValueError('ForceField: non-boolean value for periodic attribute in tabulated function definition')
functions.append((function.attrib['name'], functionType, values, params)) functions.append((function.attrib['name'], functionType, values, params))
return functions return functions
...@@ -93,11 +99,33 @@ def _createFunctions(force, functions): ...@@ -93,11 +99,33 @@ def _createFunctions(force, functions):
"""Add TabulatedFunctions to a Force based on the information that was recorded by _parseFunctions().""" """Add TabulatedFunctions to a Force based on the information that was recorded by _parseFunctions()."""
for (name, type, values, params) in functions: for (name, type, values, params) in functions:
if type == 'Continuous1D': if type == 'Continuous1D':
force.addTabulatedFunction(name, mm.Continuous1DFunction(values, params['min'], params['max'])) force.addTabulatedFunction(
name,
mm.Continuous1DFunction(values, params['min'], params['max'], params['periodic']),
)
elif type == 'Continuous2D': elif type == 'Continuous2D':
force.addTabulatedFunction(name, mm.Continuous2DFunction(params['xsize'], params['ysize'], values, params['xmin'], params['xmax'], params['ymin'], params['ymax'])) force.addTabulatedFunction(
name,
mm.Continuous2DFunction(
params['xsize'], params['ysize'],
values,
params['xmin'], params['xmax'],
params['ymin'], params['ymax'],
params['periodic'],
),
)
elif type == 'Continuous3D': elif type == 'Continuous3D':
force.addTabulatedFunction(name, mm.Continuous2DFunction(params['xsize'], params['ysize'], params['zsize'], values, params['xmin'], params['xmax'], params['ymin'], params['ymax'], params['zmin'], params['zmax'])) force.addTabulatedFunction(
name,
mm.Continuous2DFunction(
params['xsize'], params['ysize'], params['zsize'],
values,
params['xmin'], params['xmax'],
params['ymin'], params['ymax'],
params['zmin'], params['zmax'],
params['periodic'],
),
)
elif type == 'Discrete1D': elif type == 'Discrete1D':
force.addTabulatedFunction(name, mm.Discrete1DFunction(values)) force.addTabulatedFunction(name, mm.Discrete1DFunction(values))
elif type == 'Discrete2D': elif type == 'Discrete2D':
...@@ -3187,7 +3215,7 @@ class CustomManyParticleGenerator(object): ...@@ -3187,7 +3215,7 @@ class CustomManyParticleGenerator(object):
force.setTypeFilter(index, types) force.setTypeFilter(index, types)
for (name, type, values, params) in self.functions: for (name, type, values, params) in self.functions:
if type == 'Continuous1D': if type == 'Continuous1D':
force.addTabulatedFunction(name, mm.Continuous1DFunction(values, params['min'], params['max'])) force.addTabulatedFunction(name, mm.Continuous1DFunction(values, params['min'], params['max'], params['periodic']))
elif type == 'Continuous2D': elif type == 'Continuous2D':
force.addTabulatedFunction(name, mm.Continuous2DFunction(params['xsize'], params['ysize'], values, params['xmin'], params['xmax'], params['ymin'], params['ymax'])) force.addTabulatedFunction(name, mm.Continuous2DFunction(params['xsize'], params['ysize'], values, params['xmin'], params['xmax'], params['ymin'], params['ymax']))
elif type == 'Continuous3D': elif type == 'Continuous3D':
......
...@@ -120,27 +120,31 @@ class Metadynamics(object): ...@@ -120,27 +120,31 @@ class Metadynamics(object):
self.saveFrequency = saveFrequency self.saveFrequency = saveFrequency
self._id = np.random.randint(0x7FFFFFFF) self._id = np.random.randint(0x7FFFFFFF)
self._saveIndex = 0 self._saveIndex = 0
self._selfBias = np.zeros(tuple(v.gridWidth for v in variables)) self._selfBias = np.zeros(tuple(v.gridWidth for v in reversed(variables)))
self._totalBias = np.zeros(tuple(v.gridWidth for v in variables)) self._totalBias = np.zeros(tuple(v.gridWidth for v in reversed(variables)))
self._loadedBiases = {} self._loadedBiases = {}
self._deltaT = temperature*(biasFactor-1) self._deltaT = temperature*(biasFactor-1)
varNames = ['cv%d' % i for i in range(len(variables))] varNames = ['cv%d' % i for i in range(len(variables))]
self._force = mm.CustomCVForce('table(%s)' % ', '.join(varNames)) self._force = mm.CustomCVForce('table(%s)' % ', '.join(varNames))
for name, var in zip(varNames, variables): for name, var in zip(varNames, variables):
self._force.addCollectiveVariable(name, var.force) self._force.addCollectiveVariable(name, var.force)
widths = [v.gridWidth for v in variables] self._widths = [v.gridWidth for v in variables]
mins = [v.minValue for v in variables] self._limits = sum(([v.minValue, v.maxValue] for v in variables), [])
maxs = [v.maxValue for v in variables] numPeriodics = sum(v.periodic for v in variables)
if numPeriodics not in [0, len(variables)]:
raise ValueError('Metadynamics cannot handle mixed periodic/non-periodic variables')
periodic = numPeriodics == len(variables)
if len(variables) == 1: if len(variables) == 1:
self._table = mm.Continuous1DFunction(self._totalBias.flatten(), mins[0], maxs[0]) self._table = mm.Continuous1DFunction(self._totalBias.flatten(), *self._limits, periodic)
elif len(variables) == 2: elif len(variables) == 2:
self._table = mm.Continuous2DFunction(widths[0], widths[1], self._totalBias.flatten(), mins[0], maxs[0], mins[1], maxs[1]) self._table = mm.Continuous2DFunction(*self._widths, self._totalBias.flatten(), *self._limits, periodic)
elif len(variables) == 3: elif len(variables) == 3:
self._table = mm.Continuous3DFunction(widths[0], widths[1], widths[2], self._totalBias.flatten(), mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2]) self._table = mm.Continuous3DFunction(*self._widths, self._totalBias.flatten(), *self._limits, periodic)
else: else:
raise ValueError('Metadynamics requires 1, 2, or 3 collective variables') raise ValueError('Metadynamics requires 1, 2, or 3 collective variables')
self._force.addTabulatedFunction('table', self._table) self._force.addTabulatedFunction('table', self._table)
self._force.setForceGroup(31) freeGroups = set(range(32)) - set(force.getForceGroup() for force in system.getForces())
self._force.setForceGroup(max(freeGroups))
system.addForce(self._force) system.addForce(self._force)
self._syncWithDisk() self._syncWithDisk()
...@@ -196,7 +200,8 @@ class Metadynamics(object): ...@@ -196,7 +200,8 @@ class Metadynamics(object):
dist = np.abs(np.linspace(0, 1.0, num=v.gridWidth) - x) dist = np.abs(np.linspace(0, 1.0, num=v.gridWidth) - x)
if v.periodic: if v.periodic:
dist = np.min(np.array([dist, np.abs(dist-1)]), axis=0) dist = np.min(np.array([dist, np.abs(dist-1)]), axis=0)
axisGaussians.append(np.exp(-dist*dist*v.gridWidth/v.biasWidth)) dist[-1] = dist[0]
axisGaussians.append(np.exp(-0.5*dist*dist/v._scaledVariance))
# Compute their outer product. # Compute their outer product.
...@@ -210,15 +215,10 @@ class Metadynamics(object): ...@@ -210,15 +215,10 @@ class Metadynamics(object):
height = height.value_in_unit(unit.kilojoules_per_mole) height = height.value_in_unit(unit.kilojoules_per_mole)
self._selfBias += height*gaussian self._selfBias += height*gaussian
self._totalBias += height*gaussian self._totalBias += height*gaussian
widths = [v.gridWidth for v in self.variables]
mins = [v.minValue for v in self.variables]
maxs = [v.maxValue for v in self.variables]
if len(self.variables) == 1: if len(self.variables) == 1:
self._table.setFunctionParameters(self._totalBias.flatten(), mins[0], maxs[0]) self._table.setFunctionParameters(self._totalBias.flatten(), *self._limits)
elif len(self.variables) == 2: else:
self._table.setFunctionParameters(widths[0], widths[1], self._totalBias.flatten(), mins[0], maxs[0], mins[1], maxs[1]) self._table.setFunctionParameters(*self._widths, self._totalBias.flatten(), *self._limits)
elif len(self.variables) == 3:
self._table.setFunctionParameters(widths[0], widths[1], widths[2], self._totalBias.flatten(), mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2])
self._force.updateParametersInContext(context) self._force.updateParametersInContext(context)
def _syncWithDisk(self): def _syncWithDisk(self):
...@@ -275,29 +275,37 @@ class BiasVariable(object): ...@@ -275,29 +275,37 @@ class BiasVariable(object):
---------- ----------
force: Force force: Force
the Force object whose potential energy defines the collective variable the Force object whose potential energy defines the collective variable
minValue: float minValue: float or unit.Quantity
the minimum value the collective variable can take. If it should ever go below this, the minimum value the collective variable can take. If it should ever go below this,
the bias force will be set to 0. the bias force will be set to 0.
maxValue: float maxValue: float or unit.Quantity
the maximum value the collective variable can take. If it should ever go above this, the maximum value the collective variable can take. If it should ever go above this,
the bias force will be set to 0. the bias force will be set to 0.
biasWidth: float biasWidth: float or unit.Quantity
the width (standard deviation) of the Gaussians added to the bias during metadynamics the width (standard deviation) of the Gaussians added to the bias during metadynamics
periodic: bool periodic: bool (optional)
whether this is a periodic variable, such that minValue and maxValue are physical equivalent whether this is a periodic variable, such that minValue and maxValue are physical equivalent
gridWidth: int gridWidth: int (optional)
the number of grid points to use when tabulating the bias function. If this is omitted, the number of grid points to use when tabulating the bias function. If this is omitted,
a reasonable value is chosen automatically. a reasonable value is chosen automatically.
""" """
self.force = force self.force = force
self.minValue = minValue self.minValue = self._standardize(minValue)
self.maxValue = maxValue self.maxValue = self._standardize(maxValue)
self.biasWidth = biasWidth self.biasWidth = self._standardize(biasWidth)
if not isinstance(periodic, bool):
raise ValueError("BiasVariable: invalid argument")
self.periodic = periodic self.periodic = periodic
if gridWidth is None: if gridWidth is None:
self.gridWidth = int(np.ceil(5*(maxValue-minValue)/biasWidth)) self.gridWidth = int(np.ceil(5*(maxValue-minValue)/biasWidth))
else: else:
self.gridWidth = gridWidth self.gridWidth = gridWidth
self._scaledVariance = (self.biasWidth/(self.maxValue-self.minValue))**2
def _standardize(self, quantity):
if unit.is_quantity(quantity):
return quantity.value_in_unit_system(unit.md_unit_system)
else:
return quantity
_LoadedBias = namedtuple('LoadedBias', ['id', 'index', 'bias']) _LoadedBias = namedtuple('LoadedBias', ['id', 'index', 'bias'])
...@@ -16,9 +16,10 @@ class TestMetadynamics(unittest.TestCase): ...@@ -16,9 +16,10 @@ class TestMetadynamics(unittest.TestCase):
system.addForce(force) system.addForce(force)
cv = CustomBondForce('r') cv = CustomBondForce('r')
cv.addBond(0, 1) cv.addBond(0, 1)
bias = BiasVariable(cv, 0.94, 1.06, 0.02) bias = BiasVariable(cv, 0.94, 1.06, 0.00431, gridWidth=31)
meta = Metadynamics(system, [bias], 300*kelvin, 3.0, 5.0, 10) meta = Metadynamics(system, [bias], 300*kelvin, 3.0, 5.0, 10)
integrator = LangevinIntegrator(300*kelvin, 10/picosecond, 0.001*picosecond) integrator = LangevinIntegrator(300*kelvin, 10/picosecond, 0.001*picosecond)
integrator.setRandomNumberSeed(4321)
topology = Topology() topology = Topology()
chain = topology.addChain() chain = topology.addChain()
residue = topology.addResidue('H2', chain) residue = topology.addResidue('H2', chain)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment