Commit eacb6882 authored by Charlles Abreu's avatar Charlles Abreu
Browse files

Periodic 1D tabulated function via optional argument

parent ab1a3891
...@@ -78,8 +78,9 @@ public: ...@@ -78,8 +78,9 @@ public:
* The function is assumed to be zero for x < min or x > max. * The function is assumed to be zero for x < min or x > max.
* @param min the value of x corresponding to the first element of values * @param min the value of x corresponding to the first element of values
* @param max the value of x corresponding to the last element of values * @param max the value of x corresponding to the last element of values
* @param periodic whether the function is periodic with period L = max - min
*/ */
Continuous1DFunction(const std::vector<double>& values, double min, double max); Continuous1DFunction(const std::vector<double>& values, double min, double max, bool periodic=false);
/** /**
* Get the parameters for the tabulated function. * Get the parameters for the tabulated function.
* *
...@@ -90,6 +91,12 @@ public: ...@@ -90,6 +91,12 @@ public:
* @param[out] max the value of x corresponding to the last element of values * @param[out] max the value of x corresponding to the last element of values
*/ */
void getFunctionParameters(std::vector<double>& values, double& min, double& max) const; void getFunctionParameters(std::vector<double>& values, double& min, double& max) const;
/**
* Get the periodicity status of the tabulated function.
*
* @param periodic whether the function is periodic
*/
void getPeriodicityStatus(bool& periodic) const;
/** /**
* Set the parameters for the tabulated function. * Set the parameters for the tabulated function.
* *
...@@ -109,6 +116,7 @@ public: ...@@ -109,6 +116,7 @@ public:
private: private:
std::vector<double> values; std::vector<double> values;
double min, max; double min, max;
bool periodic;
}; };
/** /**
......
...@@ -35,14 +35,9 @@ ...@@ -35,14 +35,9 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max) { Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max, bool periodic) {
if (max <= min) this->periodic = periodic;
throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function."); setFunctionParameters(values, min, max);
if (values.size() < 2)
throw OpenMMException("Continuous1DFunction: a tabulated function must have at least two points");
this->values = values;
this->min = min;
this->max = max;
} }
void Continuous1DFunction::getFunctionParameters(vector<double>& values, double& min, double& max) const { void Continuous1DFunction::getFunctionParameters(vector<double>& values, double& min, double& max) const {
...@@ -51,11 +46,22 @@ void Continuous1DFunction::getFunctionParameters(vector<double>& values, double& ...@@ -51,11 +46,22 @@ void Continuous1DFunction::getFunctionParameters(vector<double>& values, double&
max = this->max; max = this->max;
} }
void Continuous1DFunction::getPeriodicityStatus(bool& periodic) const {
periodic = this->periodic;
}
void Continuous1DFunction::setFunctionParameters(const vector<double>& values, double min, double max) { void Continuous1DFunction::setFunctionParameters(const vector<double>& values, double min, double max) {
if (max <= min) if (max <= min)
throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function."); throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function.");
if (values.size() < 2) int n = values.size();
throw OpenMMException("Continuous1DFunction: a tabulated function must have at least two points"); if (periodic) {
if (n < 3)
throw OpenMMException("Continuous1DFunction: a periodic tabulated function must have at least three points");
if (values[0] != values[n-1])
throw OpenMMException("Continuous1DFunction: with periodic=true, the first and last points must have the same value");
}
else if (n < 2)
throw OpenMMException("Continuous1DFunction: a non-periodic tabulated function must have at least two points");
this->values = values; this->values = values;
this->min = min; this->min = min;
this->max = max; this->max = max;
......
...@@ -738,11 +738,16 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu ...@@ -738,11 +738,16 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu
vector<double> values; vector<double> values;
double min, max; double min, max;
fn.getFunctionParameters(values, min, max); fn.getFunctionParameters(values, min, max);
bool periodic;
fn.getPeriodicityStatus(periodic);
int numValues = values.size(); int numValues = values.size();
vector<double> x(numValues), derivs; vector<double> x(numValues), derivs;
for (int i = 0; i < numValues; i++) for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1); x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs); if (periodic)
SplineFitter::createPeriodicSpline(x, values, derivs);
else
SplineFitter::createNaturalSpline(x, values, derivs);
vector<float> f(4*(numValues-1)); vector<float> f(4*(numValues-1));
for (int i = 0; i < (int) values.size()-1; i++) { for (int i = 0; i < (int) values.size()-1; i++) {
f[4*i] = (float) values[i]; f[4*i] = (float) values[i];
......
...@@ -59,6 +59,7 @@ private: ...@@ -59,6 +59,7 @@ private:
ReferenceContinuous1DFunction(const ReferenceContinuous1DFunction& other); ReferenceContinuous1DFunction(const ReferenceContinuous1DFunction& other);
const Continuous1DFunction& function; const Continuous1DFunction& function;
double min, max; double min, max;
bool periodic;
std::vector<double> x, values, derivs; std::vector<double> x, values, derivs;
}; };
......
...@@ -78,15 +78,20 @@ extern "C" OPENMM_EXPORT CustomFunction* createReferenceTabulatedFunction(const ...@@ -78,15 +78,20 @@ extern "C" OPENMM_EXPORT CustomFunction* createReferenceTabulatedFunction(const
ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const Continuous1DFunction& function) : function(function) { ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const Continuous1DFunction& function) : function(function) {
function.getFunctionParameters(values, min, max); function.getFunctionParameters(values, min, max);
function.getPeriodicityStatus(periodic);
int numValues = values.size(); int numValues = values.size();
x.resize(numValues); x.resize(numValues);
for (int i = 0; i < numValues; i++) for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1); x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createNaturalSpline(x, values, derivs); if (periodic)
SplineFitter::createPeriodicSpline(x, values, derivs);
else
SplineFitter::createNaturalSpline(x, values, derivs);
} }
ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const ReferenceContinuous1DFunction& other) : function(other.function) { ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const ReferenceContinuous1DFunction& other) : function(other.function) {
function.getFunctionParameters(values, min, max); function.getFunctionParameters(values, min, max);
function.getPeriodicityStatus(periodic);
x = other.x; x = other.x;
values = other.values; values = other.values;
derivs = other.derivs; derivs = other.derivs;
......
...@@ -51,6 +51,9 @@ void Continuous1DFunctionProxy::serialize(const void* object, SerializationNode& ...@@ -51,6 +51,9 @@ void Continuous1DFunctionProxy::serialize(const void* object, SerializationNode&
SerializationNode& valuesNode = node.createChildNode("Values"); SerializationNode& valuesNode = node.createChildNode("Values");
for (auto v : values) for (auto v : values)
valuesNode.createChildNode("Value").setDoubleProperty("v", v); valuesNode.createChildNode("Value").setDoubleProperty("v", v);
bool periodic;
function.getPeriodicityStatus(periodic);
node.setBoolProperty("periodic", periodic);
} }
void* Continuous1DFunctionProxy::deserialize(const SerializationNode& node) const { void* Continuous1DFunctionProxy::deserialize(const SerializationNode& node) const {
...@@ -60,7 +63,7 @@ void* Continuous1DFunctionProxy::deserialize(const SerializationNode& node) cons ...@@ -60,7 +63,7 @@ void* Continuous1DFunctionProxy::deserialize(const SerializationNode& node) cons
vector<double> values; vector<double> values;
for (auto& child : valuesNode.getChildren()) for (auto& child : valuesNode.getChildren())
values.push_back(child.getDoubleProperty("v")); values.push_back(child.getDoubleProperty("v"));
return new Continuous1DFunction(values, node.getDoubleProperty("min"), node.getDoubleProperty("max")); return new Continuous1DFunction(values, node.getDoubleProperty("min"), node.getDoubleProperty("max"), node.getBoolProperty("periodic"));
} }
ContinuousPeriodic1DFunctionProxy::ContinuousPeriodic1DFunctionProxy() : SerializationProxy("ContinuousPeriodic1DFunction") { ContinuousPeriodic1DFunctionProxy::ContinuousPeriodic1DFunctionProxy() : SerializationProxy("ContinuousPeriodic1DFunction") {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment