Commit c9c218ac authored by peastman's avatar peastman
Browse files

Created reference implementation of Continuous3DFunction

parent 9d3636f4
......@@ -157,6 +157,77 @@ private:
double xmin, xmax, ymin, ymax;
};
/**
* This is a TabulatedFunction that computes a continuous three dimensional function.
*/
class OPENMM_EXPORT Continuous3DFunction : public TabulatedFunction {
public:
/**
* Create a Continuous3DFunction f(x,y,z) based on a set of tabulated values.
*
* @param values the tabulated values of the function f(x,y,z) at xsize uniformly spaced values of x between xmin
* and xmax, ysize values of y between ymin and ymax, and zsize values of z between zmin and zmax.
* A natural cubic spline is used to interpolate between the tabulated values. The function is
* assumed to be zero when x, y, or z is outside its specified range. The values should be ordered so
* that values[i+xsize*j+xsize*ysize*k] = f(x_i,y_j,z_k), where x_i is the i'th uniformly spaced value of x.
* This must be of length xsize*ysize*zsize.
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param ysize the number of table elements along the z direction
* @param xmin the value of x corresponding to the first element of values
* @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values
* @param zmin the value of z corresponding to the first element of values
* @param zmax the value of z corresponding to the last element of values
*/
Continuous3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax);
/**
* Get the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x,y,z) at xsize uniformly spaced values of x between xmin
* and xmax, ysize values of y between ymin and ymax, and zsize values of z between zmin and zmax.
* A natural cubic spline is used to interpolate between the tabulated values. The function is
* assumed to be zero when x, y, or z is outside its specified range. The values should be ordered so
* that values[i+xsize*j+xsize*ysize*k] = f(x_i,y_j,z_k), where x_i is the i'th uniformly spaced value of x.
* This must be of length xsize*ysize*zsize.
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param ysize the number of table elements along the z direction
* @param xmin the value of x corresponding to the first element of values
* @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values
* @param zmin the value of z corresponding to the first element of values
* @param zmax the value of z corresponding to the last element of values
*/
void getFunctionParameters(int& xsize, int& ysize, int& zsize, std::vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax, double& zmin, double& zmax) const;
/**
* Set the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x,y,z) at xsize uniformly spaced values of x between xmin
* and xmax, ysize values of y between ymin and ymax, and zsize values of z between zmin and zmax.
* A natural cubic spline is used to interpolate between the tabulated values. The function is
* assumed to be zero when x, y, or z is outside its specified range. The values should be ordered so
* that values[i+xsize*j+xsize*ysize*k] = f(x_i,y_j,z_k), where x_i is the i'th uniformly spaced value of x.
* This must be of length xsize*ysize*zsize.
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param ysize the number of table elements along the z direction
* @param xmin the value of x corresponding to the first element of values
* @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values
* @param zmin the value of z corresponding to the first element of values
* @param zmax the value of z corresponding to the last element of values
*/
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax);
private:
std::vector<double> values;
int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax;
};
/**
* This is a TabulatedFunction that computes a discrete one dimensional function f(x).
* To evaluate it, x is rounded to the nearest integer and the table element with that
......
......@@ -107,6 +107,65 @@ void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vec
this->ymax = ymax;
}
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Continuous3DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous3DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous3DFunction: ymax <= ymin for a tabulated function.");
if (zmax <= zmin)
throw OpenMMException("Continuous3DFunction: zmax <= zmin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
this->zmin = zmin;
this->zmax = zmax;
}
void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zsize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax, double& zmin, double& zmax) const {
values = this->values;
xsize = this->xsize;
ysize = this->ysize;
zsize = this->zsize;
xmin = this->xmin;
xmax = this->xmax;
ymin = this->ymin;
ymax = this->ymax;
zmin = this->zmin;
zmax = this->zmax;
}
void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Continuous3DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous3DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous3DFunction: ymax <= ymin for a tabulated function.");
if (zmax <= zmin)
throw OpenMMException("Continuous3DFunction: zmax <= zmin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
this->zmin = zmin;
this->zmax = zmax;
}
Discrete1DFunction::Discrete1DFunction(const vector<double>& values) {
this->values = values;
}
......
......@@ -78,6 +78,24 @@ private:
std::vector<std::vector<double> > c;
};
/**
* This class adapts a Continuous3DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceContinuous3DFunction : public Lepton::CustomFunction {
public:
ReferenceContinuous3DFunction(const Continuous3DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Continuous3DFunction& function;
int xsize, ysize, zsize;
double xmin, xmax, ymin, ymax, zmin, zmax;
std::vector<double> x, y, z, values;
std::vector<std::vector<double> > c;
};
/**
* This class adapts a Discrete1DFunction into a Lepton::CustomFunction.
*/
......
......@@ -43,6 +43,8 @@ extern "C" CustomFunction* createReferenceTabulatedFunction(const TabulatedFunct
return new ReferenceContinuous1DFunction(dynamic_cast<const Continuous1DFunction&>(function));
if (dynamic_cast<const Continuous2DFunction*>(&function) != NULL)
return new ReferenceContinuous2DFunction(dynamic_cast<const Continuous2DFunction&>(function));
if (dynamic_cast<const Continuous3DFunction*>(&function) != NULL)
return new ReferenceContinuous3DFunction(dynamic_cast<const Continuous3DFunction&>(function));
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL)
return new ReferenceDiscrete1DFunction(dynamic_cast<const Discrete1DFunction&>(function));
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL)
......@@ -128,6 +130,62 @@ CustomFunction* ReferenceContinuous2DFunction::clone() const {
return new ReferenceContinuous2DFunction(function);
}
ReferenceContinuous3DFunction::ReferenceContinuous3DFunction(const Continuous3DFunction& function) : function(function) {
function.getFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
x.resize(xsize);
y.resize(ysize);
z.resize(zsize);
for (int i = 0; i < xsize; i++)
x[i] = xmin+i*(xmax-xmin)/(xsize-1);
for (int i = 0; i < ysize; i++)
y[i] = ymin+i*(ymax-ymin)/(ysize-1);
for (int i = 0; i < zsize; i++)
z[i] = zmin+i*(zmax-zmin)/(zsize-1);
SplineFitter::create3DNaturalSpline(x, y, z, values, c);
}
int ReferenceContinuous3DFunction::getNumArguments() const {
return 3;
}
double ReferenceContinuous3DFunction::evaluate(const double* arguments) const {
double u = arguments[0];
if (u < xmin || u > xmax)
return 0.0;
double v = arguments[1];
if (v < ymin || v > ymax)
return 0.0;
double w = arguments[2];
if (w < zmin || w > zmax)
return 0.0;
return SplineFitter::evaluate3DSpline(x, y, z, values, c, u, v, w);
}
double ReferenceContinuous3DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
double u = arguments[0];
if (u < xmin || u > xmax)
return 0.0;
double v = arguments[1];
if (v < ymin || v > ymax)
return 0.0;
double w = arguments[2];
if (w < zmin || w > zmax)
return 0.0;
double dx, dy, dz;
SplineFitter::evaluate3DSplineDerivatives(x, y, z, values, c, u, v, w, dx, dy, dz);
if (derivOrder[0] == 1 && derivOrder[1] == 0 && derivOrder[2] == 0)
return dx;
if (derivOrder[0] == 0 && derivOrder[1] == 1 && derivOrder[2] == 0)
return dy;
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 1)
return dz;
throw OpenMMException("ReferenceContinuous3DFunction: Unsupported derivative order");
}
CustomFunction* ReferenceContinuous3DFunction::clone() const {
return new ReferenceContinuous3DFunction(function);
}
ReferenceDiscrete1DFunction::ReferenceDiscrete1DFunction(const Discrete1DFunction& function) : function(function) {
function.getFunctionParameters(values);
}
......
......@@ -315,6 +315,65 @@ void testContinuous2DFunction() {
}
}
void testContinuous3DFunction() {
const int xsize = 10;
const int ysize = 11;
const int zsize = 12;
const double xmin = 0.4;
const double xmax = 1.1;
const double ymin = 0.0;
const double ymax = 0.9;
const double zmin = 0.2;
const double zmax = 1.3;
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table(xsize*ysize*zsize);
for (int i = 0; i < xsize; i++) {
for (int j = 0; j < ysize; j++) {
for (int k = 0; k < zsize; k++) {
double x = xmin + i*(xmax-xmin)/xsize;
double y = ymin + j*(ymax-ymin)/ysize;
double z = zmin + k*(zmax-zmin)/zsize;
table[i+xsize*j+xsize*ysize*k] = sin(0.25*x)*cos(0.33*y)*(1+z);
}
}
}
forceField->addFunction("fn", new Continuous3DFunction(xsize, ysize, zsize, table, xmin, xmax, ymin, ymax, zmin, zmax));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (double x = xmin-0.15; x < xmax+0.2; x += 0.1) {
for (double y = ymin-0.15; y < ymax+0.2; y += 0.1) {
for (double z = zmin-0.15; z < zmax+0.2; z += 0.1) {
positions[1] = Vec3(x, 0, 0);
context.setParameter("a", y);
context.setParameter("b", z);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
double energy = 1;
double force = 0;
if (x >= xmin && x <= xmax && y >= ymin && y <= ymax && z >= zmin && z <= zmax) {
energy = sin(0.25*x)*cos(0.33*y)*(1.0+z)+1.0;
force = -0.25*cos(0.25*x)*cos(0.33*y)*(1.0+z);
}
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.05);
}
}
}
}
void testDiscrete1DFunction() {
ReferencePlatform platform;
System system;
......@@ -806,6 +865,7 @@ int main() {
testPeriodic();
testContinuous1DFunction();
testContinuous2DFunction();
testContinuous3DFunction();
testDiscrete1DFunction();
testDiscrete2DFunction();
testDiscrete3DFunction();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment