Commit a773952e authored by peastman's avatar peastman
Browse files

Created Discrete2DFunction and Discrete3DFunction

parent 56e36449
...@@ -102,35 +102,115 @@ private: ...@@ -102,35 +102,115 @@ private:
}; };
/** /**
* This is a TabulatedFunction that computes a discrete one dimensional function. * This is a TabulatedFunction that computes a discrete one dimensional function f(x).
* To evaluate it, x is rounded to the nearest integer and the table element with that
* index is returned. If the index is outside the range [0, size), the result is undefined.
*/ */
class OPENMM_EXPORT Discrete1DFunction : public TabulatedFunction { class OPENMM_EXPORT Discrete1DFunction : public TabulatedFunction {
public: public:
/** /**
* Create a Discrete1DFunction f(x) based on a set of tabulated values. * Create a Discrete1DFunction f(x) based on a set of tabulated values.
* *
* @param values the tabulated values of the function f(x). The function is only defined * @param values the tabulated values of the function f(x)
* for integer values of x in the range [0, values.size()].
*/ */
Discrete1DFunction(const std::vector<double>& values); Discrete1DFunction(const std::vector<double>& values);
/** /**
* Get the parameters for the tabulated function. * Get the parameters for the tabulated function.
* *
* @param values the tabulated values of the function f(x). The function is only defined * @param values the tabulated values of the function f(x)
* for integer values of x in the range [0, values.size()].
*/ */
void getFunctionParameters(std::vector<double>& values) const; void getFunctionParameters(std::vector<double>& values) const;
/** /**
* Set the parameters for the tabulated function. * Set the parameters for the tabulated function.
* *
* @param values the tabulated values of the function f(x). The function is only defined * @param values the tabulated values of the function f(x)
* for integer values of x in the range [0, values.size()].
*/ */
void setFunctionParameters(const std::vector<double>& values); void setFunctionParameters(const std::vector<double>& values);
private: private:
std::vector<double> values; std::vector<double> values;
}; };
/**
* This is a TabulatedFunction that computes a discrete two dimensional function f(x,y).
* To evaluate it, x and y are each rounded to the nearest integer and the table element with those
* indices is returned. If either index is outside the range [0, size), the result is undefined.
*/
class OPENMM_EXPORT Discrete2DFunction : public TabulatedFunction {
public:
/**
* Create a Discrete2DFunction f(x,y) based on a set of tabulated values.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param values the tabulated values of the function f(x,y), ordered so that
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
Discrete2DFunction(int xsize, int ysize, const std::vector<double>& values);
/**
* Get the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param values the tabulated values of the function f(x,y), ordered so that
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
void getFunctionParameters(int& xsize, int& ysize, std::vector<double>& values) const;
/**
* Set the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param values the tabulated values of the function f(x,y), ordered so that
* values[i+xsize*j] = f(i,j). This must be of length xsize*ysize.
*/
void setFunctionParameters(int xsize, int ysize, const std::vector<double>& values);
private:
int xsize, ysize;
std::vector<double> values;
};
/**
* This is a TabulatedFunction that computes a discrete three dimensional function f(x,y,z).
* To evaluate it, x, y, and z are each rounded to the nearest integer and the table element with those
* indices is returned. If any index is outside the range [0, size), the result is undefined.
*/
class OPENMM_EXPORT Discrete3DFunction : public TabulatedFunction {
public:
/**
* Create a Discrete3DFunction f(x,y,z) based on a set of tabulated values.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param zsize the number of table elements along the z direction
* @param values the tabulated values of the function f(x,y,z), ordered so that
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
Discrete3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values);
/**
* Get the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param zsize the number of table elements along the z direction
* @param values the tabulated values of the function f(x,y,z), ordered so that
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
void getFunctionParameters(int& xsize, int& ysize, int& zsize, std::vector<double>& values) const;
/**
* Set the parameters for the tabulated function.
*
* @param xsize the number of table elements along the x direction
* @param ysize the number of table elements along the y direction
* @param zsize the number of table elements along the z direction
* @param values the tabulated values of the function f(x,y,z), ordered so that
* values[i+xsize*j+xsize*ysize*k] = f(i,j,k). This must be of length xsize*ysize*zsize.
*/
void setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values);
private:
int xsize, ysize, zsize;
std::vector<double> values;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_TABULATEDFUNCTION_H_*/ #endif /*OPENMM_TABULATEDFUNCTION_H_*/
...@@ -72,3 +72,50 @@ void Discrete1DFunction::getFunctionParameters(std::vector<double>& values) cons ...@@ -72,3 +72,50 @@ void Discrete1DFunction::getFunctionParameters(std::vector<double>& values) cons
void Discrete1DFunction::setFunctionParameters(const std::vector<double>& values) { void Discrete1DFunction::setFunctionParameters(const std::vector<double>& values) {
this->values = values; this->values = values;
} }
Discrete2DFunction::Discrete2DFunction(int xsize, int ysize, const std::vector<double>& values) {
if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values");
this->xsize = xsize;
this->ysize = ysize;
this->values = values;
}
void Discrete2DFunction::getFunctionParameters(int& xsize, int& ysize, std::vector<double>& values) const {
xsize = this->xsize;
ysize = this->ysize;
values = this->values;
}
void Discrete2DFunction::setFunctionParameters(int xsize, int ysize, const std::vector<double>& values) {
if (values.size() != xsize*ysize)
throw OpenMMException("Discrete2DFunction: incorrect number of values");
this->xsize = xsize;
this->ysize = ysize;
this->values = values;
}
Discrete3DFunction::Discrete3DFunction(int xsize, int ysize, int zsize, const std::vector<double>& values) {
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values");
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->values = values;
}
void Discrete3DFunction::getFunctionParameters(int& xsize, int& ysize, int& zsize, std::vector<double>& values) const {
xsize = this->xsize;
ysize = this->ysize;
zsize = this->zsize;
values = this->values;
}
void Discrete3DFunction::setFunctionParameters(int xsize, int ysize, int zsize, const std::vector<double>& values) {
if (values.size() != xsize*ysize*zsize)
throw OpenMMException("Discrete3DFunction: incorrect number of values");
this->xsize = xsize;
this->ysize = ysize;
this->zsize = zsize;
this->values = values;
}
...@@ -46,8 +46,7 @@ namespace OpenMM { ...@@ -46,8 +46,7 @@ namespace OpenMM {
class OPENMM_EXPORT_CUDA CudaExpressionUtilities { class OPENMM_EXPORT_CUDA CudaExpressionUtilities {
public: public:
CudaExpressionUtilities(CudaContext& context) : context(context) { CudaExpressionUtilities(CudaContext& context);
}
/** /**
* Generate the source code for calculating a set of expressions. * Generate the source code for calculating a set of expressions.
* *
...@@ -93,38 +92,43 @@ public: ...@@ -93,38 +92,43 @@ public:
* @return the parameter array * @return the parameter array
*/ */
std::vector<float4> computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions); std::vector<float4> computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
class FunctionPlaceholder; /**
* Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions.
*
* @param function the function for which to get a placeholder
*/
Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function);
private: private:
class FunctionPlaceholder : public Lepton::CustomFunction {
public:
FunctionPlaceholder(int numArgs) : numArgs(numArgs) {
}
int getNumArguments() const {
return numArgs;
}
double evaluate(const double* arguments) const {
return 0.0;
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* clone() const {
return new FunctionPlaceholder(numArgs);
}
private:
int numArgs;
};
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node, void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType); const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode); std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers); std::map<int, const Lepton::ExpressionTreeNode*>& powers);
CudaContext& context; CudaContext& context;
}; FunctionPlaceholder fp1, fp2, fp3;
/**
* This class serves as a placeholder for custom functions in expressions.
*/
class CudaExpressionUtilities::FunctionPlaceholder : public Lepton::CustomFunction {
public:
int getNumArguments() const {
return 1;
}
double evaluate(const double* arguments) const {
return 0.0;
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* clone() const {
return new FunctionPlaceholder();
}
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -33,6 +33,9 @@ using namespace OpenMM; ...@@ -33,6 +33,9 @@ using namespace OpenMM;
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context(context), fp1(1), fp2(2), fp3(3) {
}
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables, string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) { const string& functionParams, const string& tempType) {
...@@ -82,7 +85,6 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -82,7 +85,6 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
; ;
if (i == functionNames.size()) if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1);
out << "0.0f;\n"; out << "0.0f;\n";
temps.push_back(make_pair(node, name)); temps.push_back(make_pair(node, name));
hasRecordedNode = true; hasRecordedNode = true;
...@@ -90,23 +92,16 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -90,23 +92,16 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
// If both the value and derivative of the function are needed, it's faster to calculate them both // If both the value and derivative of the function are needed, it's faster to calculate them both
// at once, so check to see if both are needed. // at once, so check to see if both are needed.
const ExpressionTreeNode* valueNode = NULL; vector<const ExpressionTreeNode*> nodes;
const ExpressionTreeNode* derivNode = NULL;
for (int j = 0; j < (int) allExpressions.size(); j++) for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), valueNode, derivNode); findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), nodes);
string valueName = name; vector<string> nodeNames;
string derivName = name; nodeNames.push_back(name);
if (valueNode != NULL && derivNode != NULL) { for (int j = 1; j < (int) nodes.size(); j++) {
string name2 = prefix+context.intToString(temps.size()); string name2 = prefix+context.intToString(temps.size());
out << tempType << " " << name2 << " = 0.0f;\n"; out << tempType << " " << name2 << " = 0.0f;\n";
if (isDeriv) { nodeNames.push_back(name2);
valueName = name2; temps.push_back(make_pair(*nodes[j], name2));
temps.push_back(make_pair(*valueNode, name2));
}
else {
derivName = name2;
temps.push_back(make_pair(*derivNode, name2));
}
} }
out << "{\n"; out << "{\n";
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) { if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
...@@ -119,20 +114,58 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -119,20 +114,58 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
out << "float4 coeff = " << functionNames[i].second << "[index];\n"; out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n"; out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n"; out << "real a = 1.0f-b;\n";
if (valueNode != NULL) for (int j = 0; j < nodes.size(); j++) {
out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n"; const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivNode != NULL) if (derivOrder[0] == 0)
out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n"; out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
else
out << nodeNames[j] << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n";
}
out << "}\n"; out << "}\n";
} }
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
if (valueNode != NULL) { for (int j = 0; j < nodes.size(); j++) {
out << "float4 params = " << functionParams << "[" << i << "];\n"; const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; if (derivOrder[0] == 0) {
out << "if (x >= 0 && x < params.x) {\n"; out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int index = (int) round(x);\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << valueName << " = " << functionNames[i].second << "[index];\n"; out << "if (x >= 0 && x < params.x) {\n";
out << "}\n"; out << "int index = (int) round(x);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
}
}
}
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int xsize = (int) params.x;\n";
out << "int ysize = (int) params.y;\n";
out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
}
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int z = (int) round(" << getTempName(node.getChildren()[2], temps) << ");\n";
out << "int xsize = (int) params.x;\n";
out << "int ysize = (int) params.y;\n";
out << "int zsize = (int) params.z;\n";
out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
} }
} }
out << "}"; out << "}";
...@@ -327,16 +360,12 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons ...@@ -327,16 +360,12 @@ string CudaExpressionUtilities::getTempName(const ExpressionTreeNode& node, cons
} }
void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, void CudaExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
const ExpressionTreeNode*& valueNode, const ExpressionTreeNode*& derivNode) { vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0]) { if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0])
if (dynamic_cast<const Operation::Custom*>(&searchNode.getOperation())->getDerivOrder()[0] == 0) nodes.push_back(&searchNode);
valueNode = &searchNode;
else
derivNode = &searchNode;
}
else else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++) for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], valueNode, derivNode); findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes);
} }
void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) { void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
...@@ -392,6 +421,34 @@ vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const Tabulat ...@@ -392,6 +421,34 @@ vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const Tabulat
width = 1; width = 1;
return f; return f;
} }
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(function);
int xsize, ysize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
if (dynamic_cast<const Discrete3DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(function);
int xsize, ysize, zsize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
throw OpenMMException("computeFunctionCoefficients: Unknown function type"); throw OpenMMException("computeFunctionCoefficients: Unknown function type");
} }
...@@ -411,8 +468,34 @@ vector<float4> CudaExpressionUtilities::computeFunctionParameters(const vector<c ...@@ -411,8 +468,34 @@ vector<float4> CudaExpressionUtilities::computeFunctionParameters(const vector<c
fn.getFunctionParameters(values); fn.getFunctionParameters(values);
params[i] = make_float4((float) values.size(), 0.0f, 0.0f, 0.0f); params[i] = make_float4((float) values.size(), 0.0f, 0.0f, 0.0f);
} }
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(*functions[i]);
int xsize, ysize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, values);
params[i] = make_float4(xsize, ysize, 0.0f, 0.0f);
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(*functions[i]);
int xsize, ysize, zsize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values);
params[i] = make_float4(xsize, ysize, zsize, 0.0f);
}
else else
throw OpenMMException("computeFunctionParameters: Unknown function type"); throw OpenMMException("computeFunctionParameters: Unknown function type");
} }
return params; return params;
} }
Lepton::CustomFunction* CudaExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL)
return &fp1;
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL)
return &fp1;
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL)
return &fp2;
if (dynamic_cast<const Discrete3DFunction*>(&function) != NULL)
return &fp3;
throw OpenMMException("getFunctionPlaceholder: Unknown function type");
}
...@@ -1958,7 +1958,6 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -1958,7 +1958,6 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
// Record the tabulated functions. // Record the tabulated functions.
CudaExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
...@@ -1967,7 +1966,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -1967,7 +1966,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
string name = force.getFunctionName(i); string name = force.getFunctionName(i);
string arrayName = prefix+"table"+cu.intToString(i); string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
int width; int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
...@@ -2671,7 +2670,6 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2671,7 +2670,6 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
// Record the tabulated functions. // Record the tabulated functions.
CudaExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
...@@ -2681,7 +2679,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2681,7 +2679,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
string name = force.getFunctionName(i); string name = force.getFunctionName(i);
string arrayName = prefix+"table"+cu.intToString(i); string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
int width; int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
...@@ -3786,7 +3784,6 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust ...@@ -3786,7 +3784,6 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
// Record the tabulated functions. // Record the tabulated functions.
CudaExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
...@@ -3796,7 +3793,7 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust ...@@ -3796,7 +3793,7 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
string name = force.getFunctionName(i); string name = force.getFunctionName(i);
string arrayName = "table"+cu.intToString(i); string arrayName = "table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
int width; int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
...@@ -4182,7 +4179,6 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con ...@@ -4182,7 +4179,6 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
// Record the tabulated functions. // Record the tabulated functions.
CudaExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
...@@ -4190,7 +4186,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con ...@@ -4190,7 +4186,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i); string name = force.getFunctionName(i);
functions[name] = &fp; functions[name] = cu.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
int width; int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
CudaArray* array = CudaArray::create<float>(cu, f.size(), "TabulatedFunction"); CudaArray* array = CudaArray::create<float>(cu, f.size(), "TabulatedFunction");
......
...@@ -271,7 +271,7 @@ void testContinuous1DFunction() { ...@@ -271,7 +271,7 @@ void testContinuous1DFunction() {
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0)); forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
...@@ -284,8 +284,8 @@ void testContinuous1DFunction() { ...@@ -284,8 +284,8 @@ void testContinuous1DFunction() {
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy); State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces(); const vector<Vec3>& forces = state.getForces();
double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0)); double force = (x < 1.0 || x > 6.0 ? 0.0 : -cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1); ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
...@@ -295,7 +295,7 @@ void testContinuous1DFunction() { ...@@ -295,7 +295,7 @@ void testContinuous1DFunction() {
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Energy); State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
} }
} }
...@@ -310,7 +310,7 @@ void testDiscrete1DFunction() { ...@@ -310,7 +310,7 @@ void testDiscrete1DFunction() {
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table)); forceField->addFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
...@@ -327,6 +327,74 @@ void testDiscrete1DFunction() { ...@@ -327,6 +327,74 @@ void testDiscrete1DFunction() {
} }
} }
void testDiscrete2DFunction() {
const int xsize = 10;
const int ysize = 5;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
table.push_back(sin(0.25*i)+cos(0.33*j));
forceField->addFunction("fn", new Discrete2DFunction(xsize, ysize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3((i%xsize)+1, 0, 0);
context.setPositions(positions);
context.setParameter("a", i/xsize);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testDiscrete3DFunction() {
const int xsize = 8;
const int ysize = 5;
const int zsize = 6;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k);
forceField->addFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3((i%xsize)+1, 0, 0);
context.setPositions(positions);
context.setParameter("a", (i/xsize)%ysize);
context.setParameter("b", i/(xsize*ysize));
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testCoulombLennardJones() { void testCoulombLennardJones() {
const int numMolecules = 300; const int numMolecules = 300;
const int numParticles = numMolecules*2; const int numParticles = numMolecules*2;
...@@ -754,6 +822,8 @@ int main(int argc, char* argv[]) { ...@@ -754,6 +822,8 @@ int main(int argc, char* argv[]) {
testPeriodic(); testPeriodic();
testContinuous1DFunction(); testContinuous1DFunction();
testDiscrete1DFunction(); testDiscrete1DFunction();
testDiscrete2DFunction();
testDiscrete3DFunction();
testCoulombLennardJones(); testCoulombLennardJones();
testParallelComputation(); testParallelComputation();
testSwitchingFunction(); testSwitchingFunction();
......
...@@ -46,8 +46,7 @@ namespace OpenMM { ...@@ -46,8 +46,7 @@ namespace OpenMM {
class OPENMM_EXPORT_OPENCL OpenCLExpressionUtilities { class OPENMM_EXPORT_OPENCL OpenCLExpressionUtilities {
public: public:
OpenCLExpressionUtilities(OpenCLContext& context) : context(context) { OpenCLExpressionUtilities(OpenCLContext& context);
}
/** /**
* Generate the source code for calculating a set of expressions. * Generate the source code for calculating a set of expressions.
* *
...@@ -93,38 +92,43 @@ public: ...@@ -93,38 +92,43 @@ public:
* @return the parameter array * @return the parameter array
*/ */
std::vector<mm_float4> computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions); std::vector<mm_float4> computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
class FunctionPlaceholder; /**
* Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions.
*
* @param function the function for which to get a placeholder
*/
Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function);
private: private:
class FunctionPlaceholder : public Lepton::CustomFunction {
public:
FunctionPlaceholder(int numArgs) : numArgs(numArgs) {
}
int getNumArguments() const {
return numArgs;
}
double evaluate(const double* arguments) const {
return 0.0;
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* clone() const {
return new FunctionPlaceholder(numArgs);
}
private:
int numArgs;
};
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node, void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType); const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode); std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers); std::map<int, const Lepton::ExpressionTreeNode*>& powers);
OpenCLContext& context; OpenCLContext& context;
}; FunctionPlaceholder fp1, fp2, fp3;
/**
* This class serves as a placeholder for custom functions in expressions.
*/
class OpenCLExpressionUtilities::FunctionPlaceholder : public Lepton::CustomFunction {
public:
int getNumArguments() const {
return 1;
}
double evaluate(const double* arguments) const {
return 0.0;
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* clone() const {
return new FunctionPlaceholder();
}
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -33,6 +33,9 @@ using namespace OpenMM; ...@@ -33,6 +33,9 @@ using namespace OpenMM;
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
OpenCLExpressionUtilities::OpenCLExpressionUtilities(OpenCLContext& context) : context(context), fp1(1), fp2(2), fp3(3) {
}
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables, string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) { const string& functionParams, const string& tempType) {
...@@ -82,7 +85,6 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -82,7 +85,6 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
; ;
if (i == functionNames.size()) if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1);
out << "0.0f;\n"; out << "0.0f;\n";
temps.push_back(make_pair(node, name)); temps.push_back(make_pair(node, name));
hasRecordedNode = true; hasRecordedNode = true;
...@@ -90,23 +92,16 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -90,23 +92,16 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
// If both the value and derivative of the function are needed, it's faster to calculate them both // If both the value and derivative of the function are needed, it's faster to calculate them both
// at once, so check to see if both are needed. // at once, so check to see if both are needed.
const ExpressionTreeNode* valueNode = NULL; vector<const ExpressionTreeNode*> nodes;
const ExpressionTreeNode* derivNode = NULL;
for (int j = 0; j < (int) allExpressions.size(); j++) for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), valueNode, derivNode); findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), nodes);
string valueName = name; vector<string> nodeNames;
string derivName = name; nodeNames.push_back(name);
if (valueNode != NULL && derivNode != NULL) { for (int j = 1; j < (int) nodes.size(); j++) {
string name2 = prefix+context.intToString(temps.size()); string name2 = prefix+context.intToString(temps.size());
out << tempType << " " << name2 << " = 0.0f;\n"; out << tempType << " " << name2 << " = 0.0f;\n";
if (isDeriv) { nodeNames.push_back(name2);
valueName = name2; temps.push_back(make_pair(*nodes[j], name2));
temps.push_back(make_pair(*valueNode, name2));
}
else {
derivName = name2;
temps.push_back(make_pair(*derivNode, name2));
}
} }
out << "{\n"; out << "{\n";
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) { if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
...@@ -119,20 +114,58 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -119,20 +114,58 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << "float4 coeff = " << functionNames[i].second << "[index];\n"; out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n"; out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n"; out << "real a = 1.0f-b;\n";
if (valueNode != NULL) for (int j = 0; j < nodes.size(); j++) {
out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n"; const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivNode != NULL) if (derivOrder[0] == 0)
out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n"; out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
else
out << nodeNames[j] << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n";
}
out << "}\n"; out << "}\n";
} }
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
if (valueNode != NULL) { for (int j = 0; j < nodes.size(); j++) {
out << "float4 params = " << functionParams << "[" << i << "];\n"; const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; if (derivOrder[0] == 0) {
out << "if (x >= 0 && x < params.x) {\n"; out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int index = (int) round(x);\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << valueName << " = " << functionNames[i].second << "[index];\n"; out << "if (x >= 0 && x < params.x) {\n";
out << "}\n"; out << "int index = (int) round(x);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
}
}
}
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int xsize = (int) params.x;\n";
out << "int ysize = (int) params.y;\n";
out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
}
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int z = (int) round(" << getTempName(node.getChildren()[2], temps) << ");\n";
out << "int xsize = (int) params.x;\n";
out << "int ysize = (int) params.y;\n";
out << "int zsize = (int) params.z;\n";
out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
}
} }
} }
out << "}"; out << "}";
...@@ -327,16 +360,12 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co ...@@ -327,16 +360,12 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
} }
void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
const ExpressionTreeNode*& valueNode, const ExpressionTreeNode*& derivNode) { vector<const Lepton::ExpressionTreeNode*>& nodes) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0]) { if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0])
if (dynamic_cast<const Operation::Custom*>(&searchNode.getOperation())->getDerivOrder()[0] == 0) nodes.push_back(&searchNode);
valueNode = &searchNode;
else
derivNode = &searchNode;
}
else else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++) for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], valueNode, derivNode); findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], nodes);
} }
void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) { void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
...@@ -392,6 +421,34 @@ vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const Tabul ...@@ -392,6 +421,34 @@ vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const Tabul
width = 1; width = 1;
return f; return f;
} }
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(function);
int xsize, ysize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
if (dynamic_cast<const Discrete3DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(function);
int xsize, ysize, zsize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
throw OpenMMException("computeFunctionCoefficients: Unknown function type"); throw OpenMMException("computeFunctionCoefficients: Unknown function type");
} }
...@@ -411,8 +468,34 @@ vector<mm_float4> OpenCLExpressionUtilities::computeFunctionParameters(const vec ...@@ -411,8 +468,34 @@ vector<mm_float4> OpenCLExpressionUtilities::computeFunctionParameters(const vec
fn.getFunctionParameters(values); fn.getFunctionParameters(values);
params[i] = mm_float4((float) values.size(), 0.0f, 0.0f, 0.0f); params[i] = mm_float4((float) values.size(), 0.0f, 0.0f, 0.0f);
} }
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(*functions[i]);
int xsize, ysize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, values);
params[i] = mm_float4(xsize, ysize, 0.0f, 0.0f);
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(*functions[i]);
int xsize, ysize, zsize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values);
params[i] = mm_float4(xsize, ysize, zsize, 0.0f);
}
else else
throw OpenMMException("computeFunctionParameters: Unknown function type"); throw OpenMMException("computeFunctionParameters: Unknown function type");
} }
return params; return params;
} }
Lepton::CustomFunction* OpenCLExpressionUtilities::getFunctionPlaceholder(const TabulatedFunction& function) {
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL)
return &fp1;
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL)
return &fp1;
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL)
return &fp2;
if (dynamic_cast<const Discrete3DFunction*>(&function) != NULL)
return &fp3;
throw OpenMMException("getFunctionPlaceholder: Unknown function type");
}
...@@ -1968,7 +1968,6 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -1968,7 +1968,6 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
// Record the tabulated functions. // Record the tabulated functions.
OpenCLExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
...@@ -1977,7 +1976,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -1977,7 +1976,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
string name = force.getFunctionName(i); string name = force.getFunctionName(i);
string arrayName = prefix+"table"+cl.intToString(i); string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
int width; int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
...@@ -2724,7 +2723,6 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2724,7 +2723,6 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
// Record the tabulated functions. // Record the tabulated functions.
OpenCLExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
...@@ -2734,7 +2732,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2734,7 +2732,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
string name = force.getFunctionName(i); string name = force.getFunctionName(i);
string arrayName = prefix+"table"+cl.intToString(i); string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
int width; int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
...@@ -3949,7 +3947,6 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -3949,7 +3947,6 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
// Record the tabulated functions. // Record the tabulated functions.
OpenCLExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
...@@ -3959,7 +3956,7 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -3959,7 +3956,7 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
string name = force.getFunctionName(i); string name = force.getFunctionName(i);
string arrayName = "table"+cl.intToString(i); string arrayName = "table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
int width; int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
...@@ -4347,7 +4344,6 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -4347,7 +4344,6 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
// Record the tabulated functions. // Record the tabulated functions.
OpenCLExpressionUtilities::FunctionPlaceholder fp;
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
...@@ -4355,7 +4351,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -4355,7 +4351,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i); string name = force.getFunctionName(i);
functions[name] = &fp; functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getFunction(i));
int width; int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width); vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
OpenCLArray* array = OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"); OpenCLArray* array = OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction");
......
...@@ -271,7 +271,7 @@ void testContinuous1DFunction() { ...@@ -271,7 +271,7 @@ void testContinuous1DFunction() {
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0)); forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
...@@ -284,8 +284,8 @@ void testContinuous1DFunction() { ...@@ -284,8 +284,8 @@ void testContinuous1DFunction() {
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy); State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces(); const vector<Vec3>& forces = state.getForces();
double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0)); double force = (x < 1.0 || x > 6.0 ? 0.0 : -cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1); ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
...@@ -295,7 +295,7 @@ void testContinuous1DFunction() { ...@@ -295,7 +295,7 @@ void testContinuous1DFunction() {
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Energy); State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
} }
} }
...@@ -310,7 +310,7 @@ void testDiscrete1DFunction() { ...@@ -310,7 +310,7 @@ void testDiscrete1DFunction() {
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table)); forceField->addFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
...@@ -326,6 +326,73 @@ void testDiscrete1DFunction() { ...@@ -326,6 +326,73 @@ void testDiscrete1DFunction() {
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6); ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
} }
} }
void testDiscrete2DFunction() {
const int xsize = 10;
const int ysize = 5;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
table.push_back(sin(0.25*i)+cos(0.33*j));
forceField->addFunction("fn", new Discrete2DFunction(xsize, ysize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3((i%xsize)+1, 0, 0);
context.setPositions(positions);
context.setParameter("a", i/xsize);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testDiscrete3DFunction() {
const int xsize = 8;
const int ysize = 5;
const int zsize = 6;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k);
forceField->addFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3((i%xsize)+1, 0, 0);
context.setPositions(positions);
context.setParameter("a", (i/xsize)%ysize);
context.setParameter("b", i/(xsize*ysize));
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testCoulombLennardJones() { void testCoulombLennardJones() {
const int numMolecules = 300; const int numMolecules = 300;
...@@ -754,6 +821,8 @@ int main(int argc, char* argv[]) { ...@@ -754,6 +821,8 @@ int main(int argc, char* argv[]) {
testPeriodic(); testPeriodic();
testContinuous1DFunction(); testContinuous1DFunction();
testDiscrete1DFunction(); testDiscrete1DFunction();
testDiscrete2DFunction();
testDiscrete3DFunction();
testCoulombLennardJones(); testCoulombLennardJones();
testParallelComputation(); testParallelComputation();
testSwitchingFunction(); testSwitchingFunction();
......
...@@ -75,6 +75,38 @@ private: ...@@ -75,6 +75,38 @@ private:
std::vector<double> values; std::vector<double> values;
}; };
/**
* This class adapts a Discrete2DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceDiscrete2DFunction : public Lepton::CustomFunction {
public:
ReferenceDiscrete2DFunction(const Discrete2DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Discrete2DFunction& function;
int xsize, ysize;
std::vector<double> values;
};
/**
* This class adapts a Discrete3DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceDiscrete3DFunction : public Lepton::CustomFunction {
public:
ReferenceDiscrete3DFunction(const Discrete3DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Discrete3DFunction& function;
int xsize, ysize, zsize;
std::vector<double> values;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_REFERENCETABULATEDFUNCTION_H_*/ #endif /*OPENMM_REFERENCETABULATEDFUNCTION_H_*/
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "ReferenceTabulatedFunction.h" #include "ReferenceTabulatedFunction.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/internal/SplineFitter.h" #include "openmm/internal/SplineFitter.h"
#include <cmath>
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
...@@ -42,6 +43,10 @@ extern "C" CustomFunction* createReferenceTabulatedFunction(const TabulatedFunct ...@@ -42,6 +43,10 @@ extern "C" CustomFunction* createReferenceTabulatedFunction(const TabulatedFunct
return new ReferenceContinuous1DFunction(dynamic_cast<const Continuous1DFunction&>(function)); return new ReferenceContinuous1DFunction(dynamic_cast<const Continuous1DFunction&>(function));
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL) if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL)
return new ReferenceDiscrete1DFunction(dynamic_cast<const Discrete1DFunction&>(function)); return new ReferenceDiscrete1DFunction(dynamic_cast<const Discrete1DFunction&>(function));
if (dynamic_cast<const Discrete2DFunction*>(&function) != NULL)
return new ReferenceDiscrete2DFunction(dynamic_cast<const Discrete2DFunction&>(function));
if (dynamic_cast<const Discrete3DFunction*>(&function) != NULL)
return new ReferenceDiscrete3DFunction(dynamic_cast<const Discrete3DFunction&>(function));
throw OpenMMException("createReferenceTabulatedFunction: Unknown function type"); throw OpenMMException("createReferenceTabulatedFunction: Unknown function type");
} }
...@@ -85,10 +90,10 @@ int ReferenceDiscrete1DFunction::getNumArguments() const { ...@@ -85,10 +90,10 @@ int ReferenceDiscrete1DFunction::getNumArguments() const {
} }
double ReferenceDiscrete1DFunction::evaluate(const double* arguments) const { double ReferenceDiscrete1DFunction::evaluate(const double* arguments) const {
int t = (int) arguments[0]; int i = (int) round(arguments[0]);
if (t < 0 || t >= values.size()) if (i < 0 || i >= values.size())
throw OpenMMException("ReferenceDiscrete1DFunction: argument out of range"); throw OpenMMException("ReferenceDiscrete1DFunction: argument out of range");
return values[t]; return values[i];
} }
double ReferenceDiscrete1DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const { double ReferenceDiscrete1DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
...@@ -98,3 +103,52 @@ double ReferenceDiscrete1DFunction::evaluateDerivative(const double* arguments, ...@@ -98,3 +103,52 @@ double ReferenceDiscrete1DFunction::evaluateDerivative(const double* arguments,
CustomFunction* ReferenceDiscrete1DFunction::clone() const { CustomFunction* ReferenceDiscrete1DFunction::clone() const {
return new ReferenceDiscrete1DFunction(function); return new ReferenceDiscrete1DFunction(function);
} }
ReferenceDiscrete2DFunction::ReferenceDiscrete2DFunction(const Discrete2DFunction& function) : function(function) {
function.getFunctionParameters(xsize, ysize, values);
}
int ReferenceDiscrete2DFunction::getNumArguments() const {
return 2;
}
double ReferenceDiscrete2DFunction::evaluate(const double* arguments) const {
int i = (int) round(arguments[0]);
int j = (int) round(arguments[1]);
if (i < 0 || i >= xsize || j < 0 || j >= ysize)
throw OpenMMException("ReferenceDiscrete2DFunction: argument out of range");
return values[i+j*xsize];
}
double ReferenceDiscrete2DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* ReferenceDiscrete2DFunction::clone() const {
return new ReferenceDiscrete2DFunction(function);
}
ReferenceDiscrete3DFunction::ReferenceDiscrete3DFunction(const Discrete3DFunction& function) : function(function) {
function.getFunctionParameters(xsize, ysize, zsize, values);
}
int ReferenceDiscrete3DFunction::getNumArguments() const {
return 3;
}
double ReferenceDiscrete3DFunction::evaluate(const double* arguments) const {
int i = (int) round(arguments[0]);
int j = (int) round(arguments[1]);
int k = (int) round(arguments[2]);
if (i < 0 || i >= xsize || j < 0 || j >= ysize || k < 0 || k >= zsize)
throw OpenMMException("ReferenceDiscrete3DFunction: argument out of range");
return values[i+(j+k*ysize)*xsize];
}
double ReferenceDiscrete3DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* ReferenceDiscrete3DFunction::clone() const {
return new ReferenceDiscrete3DFunction(function);
}
...@@ -238,7 +238,7 @@ void testContinuous1DFunction() { ...@@ -238,7 +238,7 @@ void testContinuous1DFunction() {
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0)); forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
...@@ -251,8 +251,8 @@ void testContinuous1DFunction() { ...@@ -251,8 +251,8 @@ void testContinuous1DFunction() {
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy); State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces(); const vector<Vec3>& forces = state.getForces();
double force = (x < 1.0 || x > 6.0 ? 0.0 : -std::cos(x-1.0)); double force = (x < 1.0 || x > 6.0 ? 0.0 : -cos(x-1.0));
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1); ASSERT_EQUAL_VEC(Vec3(-force, 0, 0), forces[0], 0.1);
ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1); ASSERT_EQUAL_VEC(Vec3(force, 0, 0), forces[1], 0.1);
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 0.02);
...@@ -262,7 +262,7 @@ void testContinuous1DFunction() { ...@@ -262,7 +262,7 @@ void testContinuous1DFunction() {
positions[1] = Vec3(x, 0, 0); positions[1] = Vec3(x, 0, 0);
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Energy); State state = context.getState(State::Energy);
double energy = (x < 1.0 || x > 6.0 ? 0.0 : std::sin(x-1.0))+1.0; double energy = (x < 1.0 || x > 6.0 ? 0.0 : sin(x-1.0))+1.0;
ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4); ASSERT_EQUAL_TOL(energy, state.getPotentialEnergy(), 1e-4);
} }
} }
...@@ -278,7 +278,7 @@ void testDiscrete1DFunction() { ...@@ -278,7 +278,7 @@ void testDiscrete1DFunction() {
forceField->addParticle(vector<double>()); forceField->addParticle(vector<double>());
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table)); forceField->addFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
...@@ -295,6 +295,76 @@ void testDiscrete1DFunction() { ...@@ -295,6 +295,76 @@ void testDiscrete1DFunction() {
} }
} }
void testDiscrete2DFunction() {
const int xsize = 10;
const int ysize = 5;
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
table.push_back(sin(0.25*i)+cos(0.33*j));
forceField->addFunction("fn", new Discrete2DFunction(xsize, ysize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i%xsize, 0, 0);
context.setPositions(positions);
context.setParameter("a", i/xsize);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL(table[i]+1.0, state.getPotentialEnergy());
}
}
void testDiscrete3DFunction() {
const int xsize = 8;
const int ysize = 5;
const int zsize = 6;
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r,a,b)+1");
forceField->addGlobalParameter("a", 0.0);
forceField->addGlobalParameter("b", 0.0);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < xsize; i++)
for (int j = 0; j < ysize; j++)
for (int k = 0; k < zsize; k++)
table.push_back(sin(0.25*i)+cos(0.33*j)+0.12345*k);
forceField->addFunction("fn", new Discrete3DFunction(xsize, ysize, zsize, table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i%xsize, 0, 0);
context.setPositions(positions);
context.setParameter("a", (i/xsize)%ysize);
context.setParameter("b", i/(xsize*ysize));
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL(table[i]+1.0, state.getPotentialEnergy());
}
}
void testCoulombLennardJones() { void testCoulombLennardJones() {
const int numMolecules = 300; const int numMolecules = 300;
const int numParticles = numMolecules*2; const int numParticles = numMolecules*2;
...@@ -688,6 +758,8 @@ int main() { ...@@ -688,6 +758,8 @@ int main() {
testPeriodic(); testPeriodic();
testContinuous1DFunction(); testContinuous1DFunction();
testDiscrete1DFunction(); testDiscrete1DFunction();
testDiscrete2DFunction();
testDiscrete3DFunction();
testCoulombLennardJones(); testCoulombLennardJones();
testSwitchingFunction(); testSwitchingFunction();
testLongRangeCorrection(); testLongRangeCorrection();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment