Commit ce15d729 authored by Charlles Abreu's avatar Charlles Abreu
Browse files

Periodic 2D and 3D tabulated functions included in API

parent a8148ba2
...@@ -63,6 +63,19 @@ public: ...@@ -63,6 +63,19 @@ public:
* @deprecated This will be removed in a future release. * @deprecated This will be removed in a future release.
*/ */
virtual TabulatedFunction* Copy() const = 0; virtual TabulatedFunction* Copy() const = 0;
/**
* Get the periodicity status of the tabulated function.
*
*/
bool getPeriodic() const;
/**
* Set the periodicity status for the tabulated function.
*
* @param periodic whether the function is periodic with period L = max - min
*/
void setPeriodic(bool periodic);
protected:
bool periodic;
}; };
/** /**
...@@ -78,7 +91,7 @@ public: ...@@ -78,7 +91,7 @@ public:
* The function is assumed to be zero for x < min or x > max. * The function is assumed to be zero for x < min or x > max.
* @param min the value of x corresponding to the first element of values * @param min the value of x corresponding to the first element of values
* @param max the value of x corresponding to the last element of values * @param max the value of x corresponding to the last element of values
* @param periodic whether the function is periodic with period L = max - min * @param periodic whether the interpolated function is periodic
*/ */
Continuous1DFunction(const std::vector<double>& values, double min, double max, bool periodic=false); Continuous1DFunction(const std::vector<double>& values, double min, double max, bool periodic=false);
/** /**
...@@ -91,11 +104,6 @@ public: ...@@ -91,11 +104,6 @@ public:
* @param[out] max the value of x corresponding to the last element of values * @param[out] max the value of x corresponding to the last element of values
*/ */
void getFunctionParameters(std::vector<double>& values, double& min, double& max) const; void getFunctionParameters(std::vector<double>& values, double& min, double& max) const;
/**
* Get the periodicity status of the tabulated function.
*
*/
bool getPeriodic() const;
/** /**
* Set the parameters for the tabulated function. * Set the parameters for the tabulated function.
* *
...@@ -106,12 +114,6 @@ public: ...@@ -106,12 +114,6 @@ public:
* @param max the value of x corresponding to the last element of values * @param max the value of x corresponding to the last element of values
*/ */
void setFunctionParameters(const std::vector<double>& values, double min, double max); void setFunctionParameters(const std::vector<double>& values, double min, double max);
/**
* Set the periodicity status for the tabulated function.
*
* @param periodic whether the function is periodic with period L = max - min
*/
void setPeriodic(bool periodic);
/** /**
* Create a deep copy of the tabulated function. * Create a deep copy of the tabulated function.
* *
...@@ -121,7 +123,6 @@ public: ...@@ -121,7 +123,6 @@ public:
private: private:
std::vector<double> values; std::vector<double> values;
double min, max; double min, max;
bool periodic;
}; };
/** /**
...@@ -142,8 +143,9 @@ public: ...@@ -142,8 +143,9 @@ public:
* @param xmax the value of x corresponding to the last element of values * @param xmax the value of x corresponding to the last element of values
* @param ymin the value of y corresponding to the first element of values * @param ymin the value of y corresponding to the first element of values
* @param ymax the value of y corresponding to the last element of values * @param ymax the value of y corresponding to the last element of values
* @param periodic whether the interpolated function is periodic
*/ */
Continuous2DFunction(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax); Continuous2DFunction(int xsize, int ysize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, bool periodic=false);
/** /**
* Get the parameters for the tabulated function. * Get the parameters for the tabulated function.
* *
...@@ -209,8 +211,9 @@ public: ...@@ -209,8 +211,9 @@ public:
* @param ymax the value of y corresponding to the last element of values * @param ymax the value of y corresponding to the last element of values
* @param zmin the value of z corresponding to the first element of values * @param zmin the value of z corresponding to the first element of values
* @param zmax the value of z corresponding to the last element of values * @param zmax the value of z corresponding to the last element of values
* @param periodic whether the interpolated function is periodic
*/ */
Continuous3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax); Continuous3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic=false);
/** /**
* Get the parameters for the tabulated function. * Get the parameters for the tabulated function.
* *
......
...@@ -35,6 +35,14 @@ ...@@ -35,6 +35,14 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
bool TabulatedFunction::getPeriodic() const {
return periodic;
}
void TabulatedFunction::setPeriodic(bool periodic) {
this->periodic = periodic;
}
Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max, bool periodic) { Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max, bool periodic) {
this->periodic = periodic; this->periodic = periodic;
setFunctionParameters(values, min, max); setFunctionParameters(values, min, max);
...@@ -46,10 +54,6 @@ void Continuous1DFunction::getFunctionParameters(vector<double>& values, double& ...@@ -46,10 +54,6 @@ void Continuous1DFunction::getFunctionParameters(vector<double>& values, double&
max = this->max; max = this->max;
} }
bool Continuous1DFunction::getPeriodic() const {
return periodic;
}
void Continuous1DFunction::setFunctionParameters(const vector<double>& values, double min, double max) { void Continuous1DFunction::setFunctionParameters(const vector<double>& values, double min, double max) {
if (max <= min) if (max <= min)
throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function."); throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function.");
...@@ -67,10 +71,6 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d ...@@ -67,10 +71,6 @@ void Continuous1DFunction::setFunctionParameters(const vector<double>& values, d
this->max = max; this->max = max;
} }
void Continuous1DFunction::setPeriodic(bool periodic) {
this->periodic = periodic;
}
Continuous1DFunction* Continuous1DFunction::Copy() const { Continuous1DFunction* Continuous1DFunction::Copy() const {
vector<double> new_vec(values.size()); vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++) for (size_t i = 0; i < values.size(); i++)
...@@ -78,22 +78,9 @@ Continuous1DFunction* Continuous1DFunction::Copy() const { ...@@ -78,22 +78,9 @@ Continuous1DFunction* Continuous1DFunction::Copy() const {
return new Continuous1DFunction(new_vec, min, max); return new Continuous1DFunction(new_vec, min, max);
} }
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) { Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, bool periodic) {
if (xsize < 2 || ysize < 2) this->periodic = periodic;
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis"); setFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
if (values.size() != xsize*ysize)
throw OpenMMException("Continuous2DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous2DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous2DFunction: ymax <= ymin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
} }
void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax) const { void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax) const {
...@@ -107,7 +94,13 @@ void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector< ...@@ -107,7 +94,13 @@ void Continuous2DFunction::getFunctionParameters(int& xsize, int& ysize, vector<
} }
void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) { void Continuous2DFunction::setFunctionParameters(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2) if (periodic) {
if (xsize < 3 || ysize < 3)
throw OpenMMException("Continuous2DFunction: must have at least three points along each axis if periodic");
// Note: value-matching at boundary is eventually checked at 2D-spline creation time.
}
else if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis"); throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize) if (values.size() != xsize*ysize)
throw OpenMMException("Continuous2DFunction: incorrect number of values"); throw OpenMMException("Continuous2DFunction: incorrect number of values");
...@@ -131,27 +124,9 @@ Continuous2DFunction* Continuous2DFunction::Copy() const { ...@@ -131,27 +124,9 @@ Continuous2DFunction* Continuous2DFunction::Copy() const {
return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax); return new Continuous2DFunction(xsize, ysize, new_vec, xmin, xmax, ymin, ymax);
} }
Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) { Continuous3DFunction::Continuous3DFunction(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax, bool periodic) {
if (xsize < 2 || ysize < 2 || zsize < 2) this->periodic = periodic;
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis"); setFunctionParameters(xsize, ysize, zsize, values, xmin, xmax, ymin, ymax, zmin, zmax);
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Continuous3DFunction: incorrect number of values");
if (xmax <= xmin)
throw OpenMMException("Continuous3DFunction: xmax <= xmin for a tabulated function.");
if (ymax <= ymin)
throw OpenMMException("Continuous3DFunction: ymax <= ymin for a tabulated function.");
if (zmax <= zmin)
throw OpenMMException("Continuous3DFunction: zmax <= zmin for a tabulated function.");
this->values = values;
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->xmin = xmin;
this->xmax = xmax;
this->ymin = ymin;
this->ymax = ymax;
this->zmin = zmin;
this->zmax = zmax;
} }
void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zsize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax, double& zmin, double& zmax) const { void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zsize, vector<double>& values, double& xmin, double& xmax, double& ymin, double& ymax, double& zmin, double& zmax) const {
...@@ -168,7 +143,13 @@ void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zs ...@@ -168,7 +143,13 @@ void Continuous3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zs
} }
void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) { void Continuous3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax, double zmin, double zmax) {
if (xsize < 2 || ysize < 2 || zsize < 2) if (periodic) {
if (xsize < 3 || ysize < 3 || zsize < 3)
throw OpenMMException("Continuous3DFunction: must have at least three points along each axis if periodic");
// Note: value-matching at boundary is eventually checked at 3D-spline creation time.
}
else if (xsize < 2 || ysize < 2 || zsize < 2)
throw OpenMMException("Continuous3DFunction: must have at least two points along each axis"); throw OpenMMException("Continuous3DFunction: must have at least two points along each axis");
if (values.size() != xsize*ysize*zsize) if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Continuous3DFunction: incorrect number of values"); throw OpenMMException("Continuous3DFunction: incorrect number of values");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment