Commit fbfc55a5 authored by Charlles Abreu's avatar Charlles Abreu
Browse files

Elimination of ContinuousPeriodic1DFunction class

parent 3026c6db
......@@ -119,55 +119,6 @@ private:
bool periodic;
};
/**
* This is a TabulatedFunction that computes a continuous, periodic one dimensional function.
*/
class OPENMM_EXPORT ContinuousPeriodic1DFunction : public TabulatedFunction {
public:
/**
* Create a ContinuousPeriodic1DFunction f(x) based on a set of tabulated values.
*
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min
* and max. A periodic cubic spline is used to interpolate between the tabulated values.
* The function is assumed to be periodic with period L=max-min. The first and last
* elements must have the same value.
* @param min the value of x corresponding to the first element of values
* @param max the value of x corresponding to the last element of values
*/
ContinuousPeriodic1DFunction(const std::vector<double>& values, double min, double max);
/**
* Get the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min
* and max. A periodic cubic spline is used to interpolate between the tabulated values.
* The function is assumed to be periodic with period L=max-min. The first and last
* elements must have the same value.
* @param[out] min the value of x corresponding to the first element of values
* @param[out] max the value of x corresponding to the last element of values
*/
void getFunctionParameters(std::vector<double>& values, double& min, double& max) const;
/**
* Set the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x) at uniformly spaced values of x between min
* and max. A periodic cubic spline is used to interpolate between the tabulated values.
* The function is assumed to be periodic with period L=max-min. The first and last
* elements must have the same value.
* @param min the value of x corresponding to the first element of values
* @param max the value of x corresponding to the last element of values
*/
void setFunctionParameters(const std::vector<double>& values, double min, double max);
/**
* Create a deep copy of the tabulated function.
*
* @deprecated This will be removed in a future release.
*/
ContinuousPeriodic1DFunction* Copy() const;
private:
std::vector<double> values;
double min, max;
};
/**
* This is a TabulatedFunction that computes a continuous two dimensional function.
*/
......
......@@ -74,45 +74,6 @@ Continuous1DFunction* Continuous1DFunction::Copy() const {
return new Continuous1DFunction(new_vec, min, max);
}
ContinuousPeriodic1DFunction::ContinuousPeriodic1DFunction(const vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("ContinuousPeriodic1DFunction: max <= min for a tabulated function.");
int n = values.size();
if (n < 3)
throw OpenMMException("ContinuousPeriodic1DFunction: a periodic tabulated function must have at least three points");
if (values[0] != values[n-1])
throw OpenMMException("ContinuousPeriodic1DFunction: the first and last points must have the same value");
this->values = values;
this->min = min;
this->max = max;
}
void ContinuousPeriodic1DFunction::getFunctionParameters(vector<double>& values, double& min, double& max) const {
values = this->values;
min = this->min;
max = this->max;
}
void ContinuousPeriodic1DFunction::setFunctionParameters(const vector<double>& values, double min, double max) {
if (max <= min)
throw OpenMMException("ContinuousPeriodic1DFunction: max <= min for a tabulated function.");
int n = values.size();
if (n < 3)
throw OpenMMException("ContinuousPeriodic1DFunction: a periodic tabulated function must have at least three points");
if (values[0] != values[n-1])
throw OpenMMException("ContinuousPeriodic1DFunction: the first and last points must have the same value");
this->values = values;
this->min = min;
this->max = max;
}
ContinuousPeriodic1DFunction* ContinuousPeriodic1DFunction::Copy() const {
vector<double> new_vec(values.size());
for (size_t i = 0; i < values.size(); i++)
new_vec[i] = values[i];
return new ContinuousPeriodic1DFunction(new_vec, min, max);
}
Continuous2DFunction::Continuous2DFunction(int xsize, int ysize, const vector<double>& values, double xmin, double xmax, double ymin, double ymax) {
if (xsize < 2 || ysize < 2)
throw OpenMMException("Continuous2DFunction: must have at least two points along each axis");
......
......@@ -255,24 +255,6 @@ void ExpressionUtilities::processExpression(stringstream& out, const ExpressionT
}
out << "}\n";
}
else if (dynamic_cast<const ContinuousPeriodic1DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x - " << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0)
out << nodeNames[j] << suffix << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else
out << nodeNames[j] << suffix << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
}
out << "}\n";
}
else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << suffix << ";\n";
out << "real y = " << getTempName(node.getChildren()[1], temps) << suffix << ";\n";
......@@ -758,28 +740,6 @@ vector<float> ExpressionUtilities::computeFunctionCoefficients(const TabulatedFu
width = 4;
return f;
}
if (dynamic_cast<const ContinuousPeriodic1DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
const ContinuousPeriodic1DFunction& fn = dynamic_cast<const ContinuousPeriodic1DFunction&>(function);
vector<double> values;
double min, max;
fn.getFunctionParameters(values, min, max);
int numValues = values.size();
vector<double> x(numValues), derivs;
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createPeriodicSpline(x, values, derivs);
vector<float> f(4*(numValues-1));
for (int i = 0; i < (int) values.size()-1; i++) {
f[4*i] = (float) values[i];
f[4*i+1] = (float) values[i+1];
f[4*i+2] = (float) (derivs[i]/6.0);
f[4*i+3] = (float) (derivs[i+1]/6.0);
}
width = 4;
return f;
}
if (dynamic_cast<const Continuous2DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
......@@ -885,16 +845,6 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec
params[i].push_back((values.size()-1)/(max-min));
params[i].push_back(values.size()-2);
}
else if (dynamic_cast<const ContinuousPeriodic1DFunction*>(functions[i]) != NULL) {
const ContinuousPeriodic1DFunction& fn = dynamic_cast<const ContinuousPeriodic1DFunction&>(*functions[i]);
vector<double> values;
double min, max;
fn.getFunctionParameters(values, min, max);
params[i].push_back(min);
params[i].push_back(max);
params[i].push_back((values.size()-1)/(max-min));
params[i].push_back(values.size()-2);
}
else if (dynamic_cast<const Continuous2DFunction*>(functions[i]) != NULL) {
const Continuous2DFunction& fn = dynamic_cast<const Continuous2DFunction&>(*functions[i]);
vector<double> values;
......@@ -961,8 +911,6 @@ vector<vector<double> > ExpressionUtilities::computeFunctionParameters(const vec
Lepton::CustomFunction* ExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL)
return &fp1;
if (dynamic_cast<const ContinuousPeriodic1DFunction*>(&function) != NULL)
return &fp1;
if (dynamic_cast<const Continuous2DFunction*>(&function) != NULL)
return &fp2;
if (dynamic_cast<const Continuous3DFunction*>(&function) != NULL)
......
......@@ -63,23 +63,6 @@ private:
std::vector<double> x, values, derivs;
};
/**
* This class adapts a ContinuousPeriodic1DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceContinuousPeriodic1DFunction : public Lepton::CustomFunction {
public:
ReferenceContinuousPeriodic1DFunction(const ContinuousPeriodic1DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
ReferenceContinuousPeriodic1DFunction(const ReferenceContinuousPeriodic1DFunction& other);
const ContinuousPeriodic1DFunction& function;
double min, max;
std::vector<double> x, values, derivs;
};
/**
* This class adapts a Continuous2DFunction into a Lepton::CustomFunction.
*/
......
......@@ -59,8 +59,6 @@ extern "C" OPENMM_EXPORT CustomFunction* createReferenceTabulatedFunction(const
CustomFunction* fn;
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL)
fn = new ReferenceContinuous1DFunction(dynamic_cast<const Continuous1DFunction&>(function));
else if (dynamic_cast<const ContinuousPeriodic1DFunction*>(&function) != NULL)
fn = new ReferenceContinuousPeriodic1DFunction(dynamic_cast<const ContinuousPeriodic1DFunction&>(function));
else if (dynamic_cast<const Continuous2DFunction*>(&function) != NULL)
fn = new ReferenceContinuous2DFunction(dynamic_cast<const Continuous2DFunction&>(function));
else if (dynamic_cast<const Continuous3DFunction*>(&function) != NULL)
......@@ -119,44 +117,6 @@ CustomFunction* ReferenceContinuous1DFunction::clone() const {
return new ReferenceContinuous1DFunction(*this);
}
ReferenceContinuousPeriodic1DFunction::ReferenceContinuousPeriodic1DFunction(const ContinuousPeriodic1DFunction& function) : function(function) {
function.getFunctionParameters(values, min, max);
int numValues = values.size();
x.resize(numValues);
for (int i = 0; i < numValues; i++)
x[i] = min+i*(max-min)/(numValues-1);
SplineFitter::createPeriodicSpline(x, values, derivs);
}
ReferenceContinuousPeriodic1DFunction::ReferenceContinuousPeriodic1DFunction(const ReferenceContinuousPeriodic1DFunction& other) : function(other.function) {
function.getFunctionParameters(values, min, max);
x = other.x;
values = other.values;
derivs = other.derivs;
}
int ReferenceContinuousPeriodic1DFunction::getNumArguments() const {
return 1;
}
double ReferenceContinuousPeriodic1DFunction::evaluate(const double* arguments) const {
double t = arguments[0];
if (t < min || t > max)
return 0.0;
return SplineFitter::evaluateSpline(x, values, derivs, t);
}
double ReferenceContinuousPeriodic1DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
double t = arguments[0];
if (t < min || t > max)
return 0.0;
return SplineFitter::evaluateSplineDerivative(x, values, derivs, t);
}
CustomFunction* ReferenceContinuousPeriodic1DFunction::clone() const {
return new ReferenceContinuousPeriodic1DFunction(*this);
}
ReferenceContinuous2DFunction::ReferenceContinuous2DFunction(const Continuous2DFunction& function) : function(function) {
function.getFunctionParameters(xsize, ysize, values, xmin, xmax, ymin, ymax);
x.resize(xsize);
......
......@@ -48,17 +48,6 @@ public:
void* deserialize(const SerializationNode& node) const;
};
/**
* This is a proxy for serializing ContinuousPeriodic1DFunction objects.
*/
class OPENMM_EXPORT ContinuousPeriodic1DFunctionProxy : public SerializationProxy {
public:
ContinuousPeriodic1DFunctionProxy();
void serialize(const void* object, SerializationNode& node) const;
void* deserialize(const SerializationNode& node) const;
};
/**
* This is a proxy for serializing Continuous2DFunction objects.
*/
......
......@@ -66,32 +66,6 @@ void* Continuous1DFunctionProxy::deserialize(const SerializationNode& node) cons
return new Continuous1DFunction(values, node.getDoubleProperty("min"), node.getDoubleProperty("max"), node.getBoolProperty("periodic"));
}
ContinuousPeriodic1DFunctionProxy::ContinuousPeriodic1DFunctionProxy() : SerializationProxy("ContinuousPeriodic1DFunction") {
}
void ContinuousPeriodic1DFunctionProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const ContinuousPeriodic1DFunction& function = *reinterpret_cast<const ContinuousPeriodic1DFunction*>(object);
double min, max;
vector<double> values;
function.getFunctionParameters(values, min, max);
node.setDoubleProperty("min", min);
node.setDoubleProperty("max", max);
SerializationNode& valuesNode = node.createChildNode("Values");
for (auto v : values)
valuesNode.createChildNode("Value").setDoubleProperty("v", v);
}
void* ContinuousPeriodic1DFunctionProxy::deserialize(const SerializationNode& node) const {
if (node.getIntProperty("version") != 1)
throw OpenMMException("Unsupported version number");
const SerializationNode& valuesNode = node.getChildNode("Values");
vector<double> values;
for (auto& child : valuesNode.getChildren())
values.push_back(child.getDoubleProperty("v"));
return new ContinuousPeriodic1DFunction(values, node.getDoubleProperty("min"), node.getDoubleProperty("max"));
}
Continuous2DFunctionProxy::Continuous2DFunctionProxy() : SerializationProxy("Continuous2DFunction") {
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment