Commit eacb6882 authored by Charlles Abreu's avatar Charlles Abreu
Browse files

Periodic 1D tabulated function via optional argument

parent ab1a3891
......@@ -78,8 +78,9 @@ public:
* The function is assumed to be zero for x < min or x > max.
* @param min the value of x corresponding to the first element of values
* @param max the value of x corresponding to the last element of values
* @param periodic whether the function is periodic with period L = max - min
*/
Continuous1DFunction(const std::vector<double>& values, double min, double max);
Continuous1DFunction(const std::vector<double>& values, double min, double max, bool periodic=false);
/**
* Get the parameters for the tabulated function.
*
......@@ -90,6 +91,12 @@ public:
* @param[out] max the value of x corresponding to the last element of values
*/
void getFunctionParameters(std::vector<double>& values, double& min, double& max) const;
/**
* Get the periodicity status of the tabulated function.
*
* @param periodic whether the function is periodic
*/
void getPeriodicityStatus(bool& periodic) const;
/**
* Set the parameters for the tabulated function.
*
......@@ -109,6 +116,7 @@ public:
private:
std::vector<double> values;
double min, max;
bool periodic;
};
/**
......
......@@ -35,14 +35,9 @@
using namespace OpenMM;
using namespace std;
Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("Continuous1DFunction: a tabulated function must have at least two points");
this->values = values;
this->min = min;
this->max = max;
Continuous1DFunction::Continuous1DFunction(const vector<double>& values, double min, double max, bool periodic) {
this->periodic = periodic;
setFunctionParameters(values, min, max);
}
void Continuous1DFunction::getFunctionParameters(vector<double>& values, double& min, double& max) const {
......@@ -51,11 +46,22 @@ void Continuous1DFunction::getFunctionParameters(vector<double>& values, double&
max = this->max;
}
void Continuous1DFunction::getPeriodicityStatus(bool& periodic) const {
periodic = this->periodic;
}
void Continuous1DFunction::setFunctionParameters(const vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("Continuous1DFunction: max <= min for a tabulated function.");
if (values.size() < 2)
throw OpenMMException("Continuous1DFunction: a tabulated function must have at least two points");
int n = values.size();
if (periodic) {
if (n < 3)
throw OpenMMException("Continuous1DFunction: a periodic tabulated function must have at least three points");
if (values[0] != values[n-1])
throw OpenMMException("Continuous1DFunction: with periodic=true, the first and last points must have the same value");
}
else if (n < 2)
throw OpenMMException("Continuous1DFunction: a non-periodic tabulated function must have at least two points");
this->values = values;
this->min = min;
this->max = max;
......
......@@ -738,10 +738,15 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu
vector<double> values;
double min, max;
fn.getFunctionParameters(values, min, max);
bool periodic;
fn.getPeriodicityStatus(periodic);
int numValues = values.size();
vector<double> x(numValues), derivs;
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
if (periodic)
SplineFitter::createPeriodicSpline(x, values, derivs);
else
SplineFitter::createNaturalSpline(x, values, derivs);
vector<float> f(4*(numValues-1));
for (int i = 0; i < (int) values.size()-1; i++) {
......
......@@ -59,6 +59,7 @@ private:
ReferenceContinuous1DFunction(const ReferenceContinuous1DFunction& other);
const Continuous1DFunction& function;
double min, max;
bool periodic;
std::vector<double> x, values, derivs;
};
......
......@@ -78,15 +78,20 @@ extern "C" OPENMM_EXPORT CustomFunction* createReferenceTabulatedFunction(const
ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const Continuous1DFunction& function) : function(function) {
function.getFunctionParameters(values, min, max);
function.getPeriodicityStatus(periodic);
int numValues = values.size();
x.resize(numValues);
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
if (periodic)
SplineFitter::createPeriodicSpline(x, values, derivs);
else
SplineFitter::createNaturalSpline(x, values, derivs);
}
ReferenceContinuous1DFunction::ReferenceContinuous1DFunction(const ReferenceContinuous1DFunction& other) : function(other.function) {
function.getFunctionParameters(values, min, max);
function.getPeriodicityStatus(periodic);
x = other.x;
values = other.values;
derivs = other.derivs;
......
......@@ -51,6 +51,9 @@ void Continuous1DFunctionProxy::serialize(const void* object, SerializationNode&
SerializationNode& valuesNode = node.createChildNode("Values");
for (auto v : values)
valuesNode.createChildNode("Value").setDoubleProperty("v", v);
bool periodic;
function.getPeriodicityStatus(periodic);
node.setBoolProperty("periodic", periodic);
}
void* Continuous1DFunctionProxy::deserialize(const SerializationNode& node) const {
......@@ -60,7 +63,7 @@ void* Continuous1DFunctionProxy::deserialize(const SerializationNode& node) cons
vector<double> values;
for (auto& child : valuesNode.getChildren())
values.push_back(child.getDoubleProperty("v"));
return new Continuous1DFunction(values, node.getDoubleProperty("min"), node.getDoubleProperty("max"));
return new Continuous1DFunction(values, node.getDoubleProperty("min"), node.getDoubleProperty("max"), node.getBoolProperty("periodic"));
}
ContinuousPeriodic1DFunctionProxy::ContinuousPeriodic1DFunctionProxy() : SerializationProxy("ContinuousPeriodic1DFunction") {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment