Commit ce15d729 authored by Charlles Abreu's avatar Charlles Abreu
Browse files

Periodic 2D and 3D tabulated functions included in API

parent a8148ba2
......@@ -63,6 +63,19 @@ public:
* @deprecated This will be removed in a future release.
*/
virtual TabulatedFunction* Copy() const = 0;
/**
* Get the periodicity status of the tabulated function.
*
*/
bool getPeriodic() const;
/**
* Set the periodicity status for the tabulated function.
*
* @param periodic whether the function is periodic with period L = max - min
*/
void setPeriodic(bool periodic);
protected:
bool periodic;
};
/**
......@@ -78,7 +91,7 @@ public:
* The function is assumed to be zero for x < min or x > max.
* @param min the value of x corresponding to the first element of values
* @param max the value of x corresponding to the last element of values
* @param periodic whether the function is periodic with period L = max - min
* @param periodic whether the interpolated function is periodic
*/
Continuous1DFunction(const std::vector<double>& values, double min, double max, bool periodic=false);
/**
......@@ -91,11 +104,6 @@ public:
* @param[out] max the value of x corresponding to the last element of values
*/
void getFunctionParameters(std::vector<double>& values, double& min, double& max) const;
/**
* Get the periodicity status of the tabulated function.
*
*/
bool getPeriodic() const;
/**
* Set the parameters for the tabulated function.
*
......@@ -106,12 +114,6 @@ public:
* @param max the value of x corresponding to the last element of values
*/
void setFunctionParameters(const std::vector<double>& values, double min, double max);
/**
* Set the periodicity status for the tabulated function.
*
* @param periodic whether the function is periodic with period L = max - min
*/
void setPeriodic(bool periodic);
/**
* Create a deep copy of the tabulated function.
*
......@@ -121,7 +123,6 @@ public:
private:
std::vector<double> values;
double min, max;
bool periodic;
};
/**
......@@ -142,8 +143,9 @@ public:
* @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values
* @param periodic whether the interpolated function is periodic
*/
Continuous2DFunction(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax);
Continuous2DFunction(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, bool periodic=false);
/**
* Get the parameters for the tabulated function.
*
......@@ -209,8 +211,9 @@ public:
* @param ymax the value of y corresponding to the last element of values
* @param zmin the value of z corresponding to the first element of values
* @param zmax the value of z corresponding to the last element of values
* @param periodic whether the interpolated function is periodic
*/
Continuous3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax);
Continuous3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic=false);
/**
* Get the parameters for the tabulated function.
*
......
......@@ -35,6 +35,14 @@
using namespace OpenMM;
using namespace std;
bool TabulatedFunction::getPeriodic() const {
return periodic;
}
void TabulatedFunction::setPeriodic(bool periodic) {
this->periodic = periodic;
}
Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max, bool periodic) {
this->periodic = periodic;
setFunctionParameters(values, min, max);
......@@ -46,10 +54,6 @@ void Continuous1DFunction::getFunctionParameters(vector<double>& values, double&
max = this->max;
}
bool Continuous1DFunction::getPeriodic() const {
return periodic;
}
void Continuous1DFunction::setFunctionParameters(const vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function.");
......@@ -67,10 +71,6 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
this->max = max;
}
void Continuous1DFunction::setPeriodic(bool periodic) {
this->periodic = periodic;
}
Continuous1DFunction* Continuous1DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
......@@ -78,22 +78,9 @@ Continuous1DFunction* Continuous1DFunction::Copy() const {
return new Continuous1DFunction(new_vec, min, max);
}
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize)
throw OpenMMException("Continuous2DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous2DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous2DFunction: ymax <= ymin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, bool periodic) {
this->periodic = periodic;
setFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
}
void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax) const {
......@@ -107,7 +94,13 @@ void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector<
}
void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2)
if (periodic) {
if (xsize < 3 || ysize < 3)
throw OpenMMException("Continuous2DFunction: must have at least three points along each axis if periodic");
// Note: value-matching at boundary is eventually checked at 2D-spline creation time.
}
else if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize)
throw OpenMMException("Continuous2DFunction: incorrect number of values");
......@@ -131,27 +124,9 @@ Continuous2DFunction* Continuous2DFunction::Copy() const {
return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax);
}
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Continuous3DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous3DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous3DFunction: ymax <= ymin for a tabulated function.");
if (zmax <= zmin)
throw OpenMMException("Continuous3DFunction: zmax <= zmin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
this->zmin = zmin;
this->zmax = zmax;
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic) {
this->periodic = periodic;
setFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
}
void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zsize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax, double& zmin, double& zmax) const {
......@@ -168,7 +143,13 @@ void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zs
}
void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2)
if (periodic) {
if (xsize < 3 || ysize < 3 || zsize < 3)
throw OpenMMException("Continuous3DFunction: must have at least three points along each axis if periodic");
// Note: value-matching at boundary is eventually checked at 3D-spline creation time.
}
else if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Continuous3DFunction: incorrect number of values");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment