"vscode:/vscode.git/clone" did not exist on "19a599e12dd92b80469b7c626d805765d34e8a49"
Commit 56e36449 authored by peastman's avatar peastman
Browse files

Created Discrete1DFunction

parent 81f26683
...@@ -101,6 +101,36 @@ private: ...@@ -101,6 +101,36 @@ private:
double min, max; double min, max;
}; };
/**
* This is a TabulatedFunction that computes a discrete one dimensional function.
*/
class OPENMM_EXPORT Discrete1DFunction : public TabulatedFunction {
public:
/**
* Create a Discrete1DFunction f(x) based on a set of tabulated values.
*
* @param values the tabulated values of the function f(x). The function is only defined
* for integer values of x in the range [0, values.size()].
*/
Discrete1DFunction(const std::vector<double>& values);
/**
* Get the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x). The function is only defined
* for integer values of x in the range [0, values.size()].
*/
void getFunctionParameters(std::vector<double>& values) const;
/**
* Set the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x). The function is only defined
* for integer values of x in the range [0, values.size()].
*/
void setFunctionParameters(const std::vector<double>& values);
private:
std::vector<double> values;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_TABULATEDFUNCTION_H_*/ #endif /*OPENMM_TABULATEDFUNCTION_H_*/
...@@ -60,3 +60,15 @@ void Continuous1DFunction::setFunctionParameters(const std::vector<double>& valu ...@@ -60,3 +60,15 @@ void Continuous1DFunction::setFunctionParameters(const std::vector<double>& valu
this->min = min; this->min = min;
this->max = max; this->max = max;
} }
Discrete1DFunction::Discrete1DFunction(const std::vector<double>& values) {
this->values = values;
}
void Discrete1DFunction::getFunctionParameters(std::vector<double>& values) const {
values = this->values;
}
void Discrete1DFunction::setFunctionParameters(const std::vector<double>& values) {
this->values = values;
}
...@@ -54,33 +54,38 @@ public: ...@@ -54,33 +54,38 @@ public:
* @param expressions the expressions to generate code for (keys are the variables to store the output values in) * @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are * @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are
* variable names, and the values are the code to generate for them. * variable names, and the values are the code to generate for them.
* @param functions defines the variable name for each tabulated function that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function * @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real") * @param tempType the type of value to use for temporary variables (defaults to "real")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="real"); const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
/** /**
* Generate the source code for calculating a set of expressions. * Generate the source code for calculating a set of expressions.
* *
* @param expressions the expressions to generate code for (keys are the variables to store the output values in) * @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions. * @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions.
* Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears. * Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears.
* @param functions defines the variable name for each tabulated function that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function * @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real") * @param tempType the type of value to use for temporary variables (defaults to "real")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="real"); const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
/** /**
* Calculate the spline coefficients for a tabulated function that appears in expressions. * Calculate the spline coefficients for a tabulated function that appears in expressions.
* *
* @param function the function for which to compute coefficients * @param function the function for which to compute coefficients
* @param width on output, the number of floats used for each value
* @return the spline coefficients * @return the spline coefficients
*/ */
std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function); std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width);
/** /**
* Given the list of TabulatedFunctions used by a Force, create the parameter array describing them. * Given the list of TabulatedFunctions used by a Force, create the parameter array describing them.
* *
...@@ -92,8 +97,8 @@ public: ...@@ -92,8 +97,8 @@ public:
private: private:
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node, void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType); const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode); const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode);
......
...@@ -34,34 +34,37 @@ using namespace Lepton; ...@@ -34,34 +34,37 @@ using namespace Lepton;
using namespace std; using namespace std;
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables, string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
vector<pair<ExpressionTreeNode, string> > variableNodes; vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter) for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second)); variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, prefix, functionParams, tempType); return createExpressions(expressions, variableNodes, functions, functionNames, prefix, functionParams, tempType);
} }
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables, string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
stringstream out; stringstream out;
vector<ParsedExpression> allExpressions; vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second); allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps = variables; vector<pair<ExpressionTreeNode, string> > temps = variables;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) { for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, prefix, functionParams, allExpressions, tempType); processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n"; out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
} }
return out.str(); return out.str();
} }
void CudaExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps, void CudaExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const vector<ParsedExpression>& allExpressions, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++) for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node) if (temps[i].first == node)
return; return;
for (int i = 0; i < (int) node.getChildren().size(); i++) for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, functions, prefix, functionParams, allExpressions, tempType); processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
string name = prefix+context.intToString(temps.size()); string name = prefix+context.intToString(temps.size());
bool hasRecordedNode = false; bool hasRecordedNode = false;
...@@ -75,9 +78,9 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -75,9 +78,9 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
case Operation::CUSTOM: case Operation::CUSTOM:
{ {
int i; int i;
for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++) for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
; ;
if (i == functions.size()) if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1); bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1);
out << "0.0f;\n"; out << "0.0f;\n";
...@@ -106,20 +109,32 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -106,20 +109,32 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
} }
} }
out << "{\n"; out << "{\n";
out << "float4 params = " << functionParams << "[" << i << "];\n"; if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "if (x >= params.x && x <= params.y) {\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "x = (x-params.x)*params.z;\n"; out << "if (x >= params.x && x <= params.y) {\n";
out << "int index = (int) (floor(x));\n"; out << "x = (x-params.x)*params.z;\n";
out << "index = min(index, (int) params.w);\n"; out << "int index = (int) (floor(x));\n";
out << "float4 coeff = " << functions[i].second << "[index];\n"; out << "index = min(index, (int) params.w);\n";
out << "real b = x-index;\n"; out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real a = 1.0f-b;\n"; out << "real b = x-index;\n";
if (valueNode != NULL) out << "real a = 1.0f-b;\n";
out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n"; if (valueNode != NULL)
if (derivNode != NULL) out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n"; if (derivNode != NULL)
out << "}\n"; out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n";
out << "}\n";
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
if (valueNode != NULL) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < params.x) {\n";
out << "int index = (int) round(x);\n";
out << valueName << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
}
}
out << "}"; out << "}";
break; break;
} }
...@@ -341,10 +356,10 @@ void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, ...@@ -341,10 +356,10 @@ void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node,
findRelatedPowers(node, searchNode.getChildren()[i], powers); findRelatedPowers(node, searchNode.getChildren()[i], powers);
} }
vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function) { vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
// Compute the spline coefficients.
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) { if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(function); const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(function);
vector<double> values; vector<double> values;
double min, max; double min, max;
...@@ -361,6 +376,20 @@ vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const Tabulat ...@@ -361,6 +376,20 @@ vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const Tabulat
f[4*i+2] = (float) (derivs[i]/6.0); f[4*i+2] = (float) (derivs[i]/6.0);
f[4*i+3] = (float) (derivs[i+1]/6.0); f[4*i+3] = (float) (derivs[i+1]/6.0);
} }
width = 4;
return f;
}
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(function);
vector<double> values;
fn.getFunctionParameters(values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f; return f;
} }
throw OpenMMException("computeFunctionCoefficients: Unknown function type"); throw OpenMMException("computeFunctionCoefficients: Unknown function type");
...@@ -376,6 +405,12 @@ vector<float4> CudaExpressionUtilities::computeFunctionParameters(const vector<c ...@@ -376,6 +405,12 @@ vector<float4> CudaExpressionUtilities::computeFunctionParameters(const vector<c
fn.getFunctionParameters(values, min, max); fn.getFunctionParameters(values, min, max);
params[i] = make_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); params[i] = make_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2);
} }
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
vector<double> values;
fn.getFunctionParameters(values);
params[i] = make_float4((float) values.size(), 0.0f, 0.0f, 0.0f);
}
else else
throw OpenMMException("computeFunctionParameters: Unknown function type"); throw OpenMMException("computeFunctionParameters: Unknown function type");
} }
......
...@@ -595,8 +595,9 @@ void CudaCalcCustomBondForceKernel::initialize(const System& system, const Custo ...@@ -595,8 +595,9 @@ void CudaCalcCustomBondForceKernel::initialize(const System& system, const Custo
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::bondForce, replacements), force.getForceGroup()); cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::bondForce, replacements), force.getForceGroup());
...@@ -830,8 +831,9 @@ void CudaCalcCustomAngleForceKernel::initialize(const System& system, const Cust ...@@ -830,8 +831,9 @@ void CudaCalcCustomAngleForceKernel::initialize(const System& system, const Cust
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::angleForce, replacements), force.getForceGroup()); cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::angleForce, replacements), force.getForceGroup());
...@@ -1253,8 +1255,9 @@ void CudaCalcCustomTorsionForceKernel::initialize(const System& system, const Cu ...@@ -1253,8 +1255,9 @@ void CudaCalcCustomTorsionForceKernel::initialize(const System& system, const Cu
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::torsionForce, replacements), force.getForceGroup()); cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::torsionForce, replacements), force.getForceGroup());
...@@ -1965,10 +1968,11 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -1965,10 +1968,11 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
string arrayName = prefix+"table"+cu.intToString(i); string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = &fp;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i)); int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer())); cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
} }
vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList); vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList);
if (force.getNumFunctions() > 0) { if (force.getNumFunctions() > 0) {
...@@ -2013,7 +2017,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -2013,7 +2017,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
variables.push_back(makeVariable(name, prefix+value)); variables.push_back(makeVariable(name, prefix+value));
} }
stringstream compute; stringstream compute;
compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, prefix+"temp", prefix+"functionParams"); compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, prefix+"temp", prefix+"functionParams");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0"); replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
...@@ -2678,10 +2682,11 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2678,10 +2682,11 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
string arrayName = prefix+"table"+cu.intToString(i); string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = &fp;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i)); int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer())); cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
tableArgs << ", const float4* __restrict__ " << arrayName; tableArgs << ", const float4* __restrict__ " << arrayName;
} }
vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList); vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList);
...@@ -2775,7 +2780,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2775,7 +2780,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize(); Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueExpressions["tempValue1 = "] = ex; n2ValueExpressions["tempValue1 = "] = ex;
n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename); n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
n2ValueSource << cu.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); n2ValueSource << cu.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
map<string, string> replacements; map<string, string> replacements;
string n2ValueStr = n2ValueSource.str(); string n2ValueStr = n2ValueSource.str();
replacements["COMPUTE_VALUE"] = n2ValueStr; replacements["COMPUTE_VALUE"] = n2ValueStr;
...@@ -2853,7 +2858,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2853,7 +2858,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1); variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1);
map<string, Lepton::ParsedExpression> valueExpressions; map<string, Lepton::ParsedExpression> valueExpressions;
valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize(); valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
reductionSource << cu.getExpressionUtilities().createExpressions(valueExpressions, variables, functionDefinitions, "value"+cu.intToString(i)+"_temp", prefix+"functionParams"); reductionSource << cu.getExpressionUtilities().createExpressions(valueExpressions, variables, functionList, functionDefinitions, "value"+cu.intToString(i)+"_temp", prefix+"functionParams");
} }
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) { for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
string valueName = "values"+cu.intToString(i+1); string valueName = "values"+cu.intToString(i+1);
...@@ -2907,7 +2912,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -2907,7 +2912,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
} }
if (exclude) if (exclude)
n2EnergySource << "if (!isExcluded) {\n"; n2EnergySource << "if (!isExcluded) {\n";
n2EnergySource << cu.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); n2EnergySource << cu.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
if (exclude) if (exclude)
n2EnergySource << "}\n"; n2EnergySource << "}\n";
} }
...@@ -3056,7 +3061,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -3056,7 +3061,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
for (int i = 1; i < force.getNumComputedValues(); i++) for (int i = 1; i < force.getNumComputedValues(); i++)
for (int j = 0; j < i; j++) for (int j = 0; j < i; j++)
expressions["real dV"+cu.intToString(i)+"dV"+cu.intToString(j)+" = "] = valueDerivExpressions[i][j]; expressions["real dV"+cu.intToString(i)+"dV"+cu.intToString(j)+" = "] = valueDerivExpressions[i][j];
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "temp", prefix+"functionParams"); compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
// Record values. // Record values.
...@@ -3128,7 +3133,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -3128,7 +3133,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
map<string, Lepton::ParsedExpression> derivExpressions; map<string, Lepton::ParsedExpression> derivExpressions;
string js = cu.intToString(j); string js = cu.intToString(j);
derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j]; derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
compute << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams"); compute << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams");
compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n"; compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n";
} }
} }
...@@ -3139,7 +3144,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -3139,7 +3144,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1]; gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2])) if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2]; gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << cu.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); compute << cu.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
} }
for (int i = 1; i < force.getNumComputedValues(); i++) { for (int i = 1; i < force.getNumComputedValues(); i++) {
string is = cu.intToString(i); string is = cu.intToString(i);
...@@ -3181,7 +3186,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG ...@@ -3181,7 +3186,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize(); Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
derivExpressions["real dV0dR1 = "] = dVdR; derivExpressions["real dV0dR1 = "] = dVdR;
derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename); derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename);
chainSource << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams"); chainSource << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, prefix+"temp0_", prefix+"functionParams");
if (needChainForValue[0]) { if (needChainForValue[0]) {
if (useExclusionsForValue) if (useExclusionsForValue)
chainSource << "if (!isExcluded) {\n"; chainSource << "if (!isExcluded) {\n";
...@@ -3539,8 +3544,9 @@ void CudaCalcCustomExternalForceKernel::initialize(const System& system, const C ...@@ -3539,8 +3544,9 @@ void CudaCalcCustomExternalForceKernel::initialize(const System& system, const C
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::customExternalForce, replacements), force.getForceGroup()); cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::customExternalForce, replacements), force.getForceGroup());
...@@ -3791,10 +3797,14 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust ...@@ -3791,10 +3797,14 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
string arrayName = "table"+cu.intToString(i); string arrayName = "table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = &fp;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i)); int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", const float4* __restrict__ " << arrayName; tableArgs << ", const float";
if (width > 1)
tableArgs << width;
tableArgs << "* __restrict__ " << arrayName;
} }
vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList); vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList);
if (force.getNumFunctions() > 0) { if (force.getNumFunctions() > 0) {
...@@ -3916,9 +3926,9 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust ...@@ -3916,9 +3926,9 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
// Now evaluate the expressions. // Now evaluate the expressions.
computeAcceptor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams"); computeAcceptor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", "functionParams");
forceExpressions["energy += "] = energyExpression; forceExpressions["energy += "] = energyExpression;
computeDonor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams"); computeDonor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", "functionParams");
// Finally, apply forces to atoms. // Finally, apply forces to atoms.
...@@ -4181,11 +4191,12 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con ...@@ -4181,11 +4191,12 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i); string name = force.getFunctionName(i);
functions[name] = &fp; functions[name] = &fp;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i)); int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
CudaArray* array = CudaArray::create<float>(cu, f.size(), "TabulatedFunction"); CudaArray* array = CudaArray::create<float>(cu, f.size(), "TabulatedFunction");
tabulatedFunctions.push_back(array); tabulatedFunctions.push_back(array);
array->upload(f); array->upload(f);
string arrayName = cu.getBondedUtilities().addArgument(array->getDevicePointer(), "float4"); string arrayName = cu.getBondedUtilities().addArgument(array->getDevicePointer(), width == 1 ? "float" : "float"+cu.intToString(width));
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
} }
vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList); vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList);
...@@ -4309,7 +4320,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con ...@@ -4309,7 +4320,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
forceExpressions["energy += "] = energyExpression; forceExpressions["energy += "] = energyExpression;
compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", functionParamsName); compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", functionParamsName);
// Finally, apply forces to atoms. // Finally, apply forces to atoms.
...@@ -4331,7 +4342,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con ...@@ -4331,7 +4342,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
if (!isZeroExpression(forceExpressionZ)) if (!isZeroExpression(forceExpressionZ))
expressions[forceName+".z -= "] = forceExpressionZ; expressions[forceName+".z -= "] = forceExpressionZ;
if (expressions.size() > 0) if (expressions.size() > 0)
compute<<cu.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "coordtemp", functionParamsName); compute<<cu.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp", functionParamsName);
compute<<"}\n"; compute<<"}\n";
} }
index = 0; index = 0;
...@@ -4941,8 +4952,9 @@ string CudaIntegrateCustomStepKernel::createGlobalComputation(const string& vari ...@@ -4941,8 +4952,9 @@ string CudaIntegrateCustomStepKernel::createGlobalComputation(const string& vari
variables[integrator.getGlobalVariableName(i)] = "globals["+cu.intToString(i)+"]"; variables[integrator.getGlobalVariableName(i)] = "globals["+cu.intToString(i)+"]";
for (int i = 0; i < (int) parameterNames.size(); i++) for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cu.intToString(i)+"]"; variables[parameterNames[i]] = "params["+cu.intToString(i)+"]";
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
} }
string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) { string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) {
...@@ -4978,8 +4990,9 @@ string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& vari ...@@ -4978,8 +4990,9 @@ string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& vari
variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i); variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i);
for (int i = 0; i < (int) parameterNames.size(); i++) for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cu.intToString(i)+"]"; variables[parameterNames[i]] = "params["+cu.intToString(i)+"]";
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp"+cu.intToString(component)+"_", "", "double"); vector<pair<string, string> > functionNames;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp"+cu.intToString(component)+"_", "", "double");
} }
void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) { void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -261,7 +261,7 @@ void testPeriodic() { ...@@ -261,7 +261,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
} }
void testTabulatedFunction() { void testContinuous1DFunction() {
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
system.addParticle(1.0); system.addParticle(1.0);
...@@ -272,7 +272,7 @@ void testTabulatedFunction() { ...@@ -272,7 +272,7 @@ void testTabulatedFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0); forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -300,6 +300,33 @@ void testTabulatedFunction() { ...@@ -300,6 +300,33 @@ void testTabulatedFunction() {
} }
} }
void testDiscrete1DFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i+1, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testCoulombLennardJones() { void testCoulombLennardJones() {
const int numMolecules = 300; const int numMolecules = 300;
const int numParticles = numMolecules*2; const int numParticles = numMolecules*2;
...@@ -725,7 +752,8 @@ int main(int argc, char* argv[]) { ...@@ -725,7 +752,8 @@ int main(int argc, char* argv[]) {
testExclusions(); testExclusions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
testTabulatedFunction(); testContinuous1DFunction();
testDiscrete1DFunction();
testCoulombLennardJones(); testCoulombLennardJones();
testParallelComputation(); testParallelComputation();
testSwitchingFunction(); testSwitchingFunction();
......
...@@ -54,33 +54,38 @@ public: ...@@ -54,33 +54,38 @@ public:
* @param expressions the expressions to generate code for (keys are the variables to store the output values in) * @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are * @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are
* variable names, and the values are the code to generate for them. * variable names, and the values are the code to generate for them.
* @param functions defines the variable name for each tabulated function that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function * @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real") * @param tempType the type of value to use for temporary variables (defaults to "real")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="real"); const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
/** /**
* Generate the source code for calculating a set of expressions. * Generate the source code for calculating a set of expressions.
* *
* @param expressions the expressions to generate code for (keys are the variables to store the output values in) * @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions. * @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions.
* Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears. * Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears.
* @param functions defines the variable name for each tabulated function that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function * @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "float") * @param tempType the type of value to use for temporary variables (defaults to "float")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="float"); const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="float");
/** /**
* Calculate the spline coefficients for a tabulated function that appears in expressions. * Calculate the spline coefficients for a tabulated function that appears in expressions.
* *
* @param function the function for which to compute coefficients * @param function the function for which to compute coefficients
* @param width on output, the number of floats used for each value
* @return the spline coefficients * @return the spline coefficients
*/ */
std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function); std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width);
/** /**
* Given the list of TabulatedFunctions used by a Force, create the parameter array describing them. * Given the list of TabulatedFunctions used by a Force, create the parameter array describing them.
* *
...@@ -92,8 +97,8 @@ public: ...@@ -92,8 +97,8 @@ public:
private: private:
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node, void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType); const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode); const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode);
......
...@@ -34,34 +34,37 @@ using namespace Lepton; ...@@ -34,34 +34,37 @@ using namespace Lepton;
using namespace std; using namespace std;
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables, string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
vector<pair<ExpressionTreeNode, string> > variableNodes; vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter) for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second)); variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, prefix, functionParams, tempType); return createExpressions(expressions, variableNodes, functions, functionNames, prefix, functionParams, tempType);
} }
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables, string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
stringstream out; stringstream out;
vector<ParsedExpression> allExpressions; vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second); allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps = variables; vector<pair<ExpressionTreeNode, string> > temps = variables;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) { for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, prefix, functionParams, allExpressions, tempType); processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n"; out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
} }
return out.str(); return out.str();
} }
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps, void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const vector<ParsedExpression>& allExpressions, const string& tempType) { const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++) for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node) if (temps[i].first == node)
return; return;
for (int i = 0; i < (int) node.getChildren().size(); i++) for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, functions, prefix, functionParams, allExpressions, tempType); processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
string name = prefix+context.intToString(temps.size()); string name = prefix+context.intToString(temps.size());
bool hasRecordedNode = false; bool hasRecordedNode = false;
...@@ -75,9 +78,9 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -75,9 +78,9 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
case Operation::CUSTOM: case Operation::CUSTOM:
{ {
int i; int i;
for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++) for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
; ;
if (i == functions.size()) if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName()); throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1); bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1);
out << "0.0f;\n"; out << "0.0f;\n";
...@@ -106,20 +109,32 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -106,20 +109,32 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
} }
} }
out << "{\n"; out << "{\n";
out << "float4 params = " << functionParams << "[" << i << "];\n"; if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "if (x >= params.x && x <= params.y) {\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "x = (x-params.x)*params.z;\n"; out << "if (x >= params.x && x <= params.y) {\n";
out << "int index = (int) (floor(x));\n"; out << "x = (x-params.x)*params.z;\n";
out << "index = min(index, (int) params.w);\n"; out << "int index = (int) (floor(x));\n";
out << "float4 coeff = " << functions[i].second << "[index];\n"; out << "index = min(index, (int) params.w);\n";
out << "float b = x-index;\n"; out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "float a = 1.0f-b;\n"; out << "real b = x-index;\n";
if (valueNode != NULL) out << "real a = 1.0f-b;\n";
out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n"; if (valueNode != NULL)
if (derivNode != NULL) out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n"; if (derivNode != NULL)
out << "}\n"; out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n";
out << "}\n";
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
if (valueNode != NULL) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < params.x) {\n";
out << "int index = (int) round(x);\n";
out << valueName << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
}
}
out << "}"; out << "}";
break; break;
} }
...@@ -341,10 +356,10 @@ void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node ...@@ -341,10 +356,10 @@ void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node
findRelatedPowers(node, searchNode.getChildren()[i], powers); findRelatedPowers(node, searchNode.getChildren()[i], powers);
} }
vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function) { vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
// Compute the spline coefficients.
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) { if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(function); const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(function);
vector<double> values; vector<double> values;
double min, max; double min, max;
...@@ -361,6 +376,20 @@ vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const Tabul ...@@ -361,6 +376,20 @@ vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const Tabul
f[4*i+2] = (float) (derivs[i]/6.0); f[4*i+2] = (float) (derivs[i]/6.0);
f[4*i+3] = (float) (derivs[i+1]/6.0); f[4*i+3] = (float) (derivs[i+1]/6.0);
} }
width = 4;
return f;
}
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(function);
vector<double> values;
fn.getFunctionParameters(values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f; return f;
} }
throw OpenMMException("computeFunctionCoefficients: Unknown function type"); throw OpenMMException("computeFunctionCoefficients: Unknown function type");
...@@ -376,6 +405,12 @@ vector<mm_float4> OpenCLExpressionUtilities::computeFunctionParameters(const vec ...@@ -376,6 +405,12 @@ vector<mm_float4> OpenCLExpressionUtilities::computeFunctionParameters(const vec
fn.getFunctionParameters(values, min, max); fn.getFunctionParameters(values, min, max);
params[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); params[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2);
} }
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
vector<double> values;
fn.getFunctionParameters(values);
params[i] = mm_float4((float) values.size(), 0.0f, 0.0f, 0.0f);
}
else else
throw OpenMMException("computeFunctionParameters: Unknown function type"); throw OpenMMException("computeFunctionParameters: Unknown function type");
} }
......
...@@ -614,8 +614,9 @@ void OpenCLCalcCustomBondForceKernel::initialize(const System& system, const Cus ...@@ -614,8 +614,9 @@ void OpenCLCalcCustomBondForceKernel::initialize(const System& system, const Cus
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup());
...@@ -843,8 +844,9 @@ void OpenCLCalcCustomAngleForceKernel::initialize(const System& system, const Cu ...@@ -843,8 +844,9 @@ void OpenCLCalcCustomAngleForceKernel::initialize(const System& system, const Cu
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup());
...@@ -1247,8 +1249,9 @@ void OpenCLCalcCustomTorsionForceKernel::initialize(const System& system, const ...@@ -1247,8 +1249,9 @@ void OpenCLCalcCustomTorsionForceKernel::initialize(const System& system, const
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
...@@ -1975,10 +1978,11 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -1975,10 +1978,11 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
string arrayName = prefix+"table"+cl.intToString(i); string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = &fp;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i)); int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
} }
vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList); vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList);
if (force.getNumFunctions() > 0) { if (force.getNumFunctions() > 0) {
...@@ -2023,7 +2027,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -2023,7 +2027,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
variables.push_back(makeVariable(name, prefix+value)); variables.push_back(makeVariable(name, prefix+value));
} }
stringstream compute; stringstream compute;
compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, prefix+"temp", prefix+"functionParams"); compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, prefix+"temp", prefix+"functionParams");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0"); replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
...@@ -2731,10 +2735,11 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2731,10 +2735,11 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
string arrayName = prefix+"table"+cl.intToString(i); string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = &fp;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i)); int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
tableArgs << ", __global const float4* restrict " << arrayName; tableArgs << ", __global const float4* restrict " << arrayName;
} }
vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList); vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList);
...@@ -2834,7 +2839,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2834,7 +2839,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize(); Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueExpressions["tempValue1 = "] = ex; n2ValueExpressions["tempValue1 = "] = ex;
n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename); n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
n2ValueSource << cl.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); n2ValueSource << cl.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
map<string, string> replacements; map<string, string> replacements;
string n2ValueStr = n2ValueSource.str(); string n2ValueStr = n2ValueSource.str();
replacements["COMPUTE_VALUE"] = n2ValueStr; replacements["COMPUTE_VALUE"] = n2ValueStr;
...@@ -2910,7 +2915,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2910,7 +2915,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1); variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1);
map<string, Lepton::ParsedExpression> valueExpressions; map<string, Lepton::ParsedExpression> valueExpressions;
valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize(); valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
reductionSource << cl.getExpressionUtilities().createExpressions(valueExpressions, variables, functionDefinitions, "value"+cl.intToString(i)+"_temp", prefix+"functionParams"); reductionSource << cl.getExpressionUtilities().createExpressions(valueExpressions, variables, functionList, functionDefinitions, "value"+cl.intToString(i)+"_temp", prefix+"functionParams");
} }
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) { for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
string valueName = "values"+cl.intToString(i+1); string valueName = "values"+cl.intToString(i+1);
...@@ -2974,7 +2979,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -2974,7 +2979,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
} }
if (exclude) if (exclude)
n2EnergySource << "if (!isExcluded) {\n"; n2EnergySource << "if (!isExcluded) {\n";
n2EnergySource << cl.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); n2EnergySource << cl.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
if (exclude) if (exclude)
n2EnergySource << "}\n"; n2EnergySource << "}\n";
} }
...@@ -3145,7 +3150,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3145,7 +3150,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
for (int i = 1; i < force.getNumComputedValues(); i++) for (int i = 1; i < force.getNumComputedValues(); i++)
for (int j = 0; j < i; j++) for (int j = 0; j < i; j++)
expressions["real dV"+cl.intToString(i)+"dV"+cl.intToString(j)+" = "] = valueDerivExpressions[i][j]; expressions["real dV"+cl.intToString(i)+"dV"+cl.intToString(j)+" = "] = valueDerivExpressions[i][j];
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "temp", prefix+"functionParams"); compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
// Record values. // Record values.
...@@ -3215,7 +3220,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3215,7 +3220,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
map<string, Lepton::ParsedExpression> derivExpressions; map<string, Lepton::ParsedExpression> derivExpressions;
string js = cl.intToString(j); string js = cl.intToString(j);
derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j]; derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
compute << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams"); compute << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams");
compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n"; compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n";
} }
} }
...@@ -3226,7 +3231,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3226,7 +3231,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1]; gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2])) if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2]; gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << cl.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionDefinitions, "temp", prefix+"functionParams"); compute << cl.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
} }
for (int i = 1; i < force.getNumComputedValues(); i++) { for (int i = 1; i < force.getNumComputedValues(); i++) {
string is = cl.intToString(i); string is = cl.intToString(i);
...@@ -3267,7 +3272,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo ...@@ -3267,7 +3272,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize(); Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
derivExpressions["real dV0dR1 = "] = dVdR; derivExpressions["real dV0dR1 = "] = dVdR;
derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename); derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename);
chainSource << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams"); chainSource << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, prefix+"temp0_", prefix+"functionParams");
if (needChainForValue[0]) { if (needChainForValue[0]) {
if (useExclusionsForValue) if (useExclusionsForValue)
chainSource << "if (!isExcluded) {\n"; chainSource << "if (!isExcluded) {\n";
...@@ -3677,8 +3682,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const ...@@ -3677,8 +3682,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType()); string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str(); replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::customExternalForce, replacements), force.getForceGroup()); cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::customExternalForce, replacements), force.getForceGroup());
...@@ -3954,10 +3960,14 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -3954,10 +3960,14 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
string arrayName = "table"+cl.intToString(i); string arrayName = "table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp; functions[name] = &fp;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i)); int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", __global const float4* restrict " << arrayName; tableArgs << ", __global const float";
if (width > 1)
tableArgs << width;
tableArgs << "* restrict " << arrayName;
} }
vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList); vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList);
if (force.getNumFunctions() > 0) { if (force.getNumFunctions() > 0) {
...@@ -4079,9 +4089,9 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu ...@@ -4079,9 +4089,9 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
// Now evaluate the expressions. // Now evaluate the expressions.
computeAcceptor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams"); computeAcceptor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", "functionParams");
forceExpressions["energy += "] = energyExpression; forceExpressions["energy += "] = energyExpression;
computeDonor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams"); computeDonor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", "functionParams");
// Finally, apply forces to atoms. // Finally, apply forces to atoms.
...@@ -4346,11 +4356,12 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -4346,11 +4356,12 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
functionList.push_back(&force.getFunction(i)); functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i); string name = force.getFunctionName(i);
functions[name] = &fp; functions[name] = &fp;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i)); int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
OpenCLArray* array = OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"); OpenCLArray* array = OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction");
tabulatedFunctions.push_back(array); tabulatedFunctions.push_back(array);
array->upload(f); array->upload(f);
string arrayName = cl.getBondedUtilities().addArgument(array->getDeviceBuffer(), "float4"); string arrayName = cl.getBondedUtilities().addArgument(array->getDeviceBuffer(), width == 1 ? "float" : "float"+cl.intToString(width));
functionDefinitions.push_back(make_pair(name, arrayName)); functionDefinitions.push_back(make_pair(name, arrayName));
} }
vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList); vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList);
...@@ -4474,7 +4485,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -4474,7 +4485,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n"; compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
} }
forceExpressions["energy += "] = energyExpression; forceExpressions["energy += "] = energyExpression;
compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", functionParamsName); compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", functionParamsName);
// Finally, apply forces to atoms. // Finally, apply forces to atoms.
...@@ -4496,7 +4507,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c ...@@ -4496,7 +4507,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
if (!isZeroExpression(forceExpressionZ)) if (!isZeroExpression(forceExpressionZ))
expressions[forceName+".z -= "] = forceExpressionZ; expressions[forceName+".z -= "] = forceExpressionZ;
if (expressions.size() > 0) if (expressions.size() > 0)
compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "coordtemp", functionParamsName); compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp", functionParamsName);
compute<<"}\n"; compute<<"}\n";
} }
index = 0; index = 0;
...@@ -5186,8 +5197,9 @@ string OpenCLIntegrateCustomStepKernel::createGlobalComputation(const string& va ...@@ -5186,8 +5197,9 @@ string OpenCLIntegrateCustomStepKernel::createGlobalComputation(const string& va
variables[integrator.getGlobalVariableName(i)] = "globals["+cl.intToString(i)+"]"; variables[integrator.getGlobalVariableName(i)] = "globals["+cl.intToString(i)+"]";
for (int i = 0; i < (int) parameterNames.size(); i++) for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cl.intToString(i)+"]"; variables[parameterNames[i]] = "params["+cl.intToString(i)+"]";
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", ""); vector<pair<string, string> > functionNames;
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
} }
string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) { string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) {
...@@ -5223,9 +5235,10 @@ string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& va ...@@ -5223,9 +5235,10 @@ string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& va
variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i); variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i);
for (int i = 0; i < (int) parameterNames.size(); i++) for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cl.intToString(i)+"]"; variables[parameterNames[i]] = "params["+cl.intToString(i)+"]";
vector<pair<string, string> > functions; vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float"); string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float");
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp"+cl.intToString(component)+"_", "", tempType); return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp"+cl.intToString(component)+"_", "", tempType);
} }
void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) { void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -261,7 +261,7 @@ void testPeriodic() { ...@@ -261,7 +261,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
} }
void testTabulatedFunction() { void testContinuous1DFunction() {
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
system.addParticle(1.0); system.addParticle(1.0);
...@@ -272,7 +272,7 @@ void testTabulatedFunction() { ...@@ -272,7 +272,7 @@ void testTabulatedFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0); forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -300,6 +300,33 @@ void testTabulatedFunction() { ...@@ -300,6 +300,33 @@ void testTabulatedFunction() {
} }
} }
void testDiscrete1DFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i+1, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testCoulombLennardJones() { void testCoulombLennardJones() {
const int numMolecules = 300; const int numMolecules = 300;
const int numParticles = numMolecules*2; const int numParticles = numMolecules*2;
...@@ -725,7 +752,8 @@ int main(int argc, char* argv[]) { ...@@ -725,7 +752,8 @@ int main(int argc, char* argv[]) {
testExclusions(); testExclusions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
testTabulatedFunction(); testContinuous1DFunction();
testDiscrete1DFunction();
testCoulombLennardJones(); testCoulombLennardJones();
testParallelComputation(); testParallelComputation();
testSwitchingFunction(); testSwitchingFunction();
......
...@@ -60,6 +60,21 @@ private: ...@@ -60,6 +60,21 @@ private:
std::vector<double> x, values, derivs; std::vector<double> x, values, derivs;
}; };
/**
* This class adapts a Discrete1DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceDiscrete1DFunction : public Lepton::CustomFunction {
public:
ReferenceDiscrete1DFunction(const Discrete1DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Discrete1DFunction& function;
std::vector<double> values;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_REFERENCETABULATEDFUNCTION_H_*/ #endif /*OPENMM_REFERENCETABULATEDFUNCTION_H_*/
...@@ -40,6 +40,8 @@ using Lepton::CustomFunction; ...@@ -40,6 +40,8 @@ using Lepton::CustomFunction;
extern "C" CustomFunction* createReferenceTabulatedFunction(const TabulatedFunction& function) { extern "C" CustomFunction* createReferenceTabulatedFunction(const TabulatedFunction& function) {
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL)
return new ReferenceContinuous1DFunction(dynamic_cast<const Continuous1DFunction&>(function)); return new ReferenceContinuous1DFunction(dynamic_cast<const Continuous1DFunction&>(function));
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL)
return new ReferenceDiscrete1DFunction(dynamic_cast<const Discrete1DFunction&>(function));
throw OpenMMException("createReferenceTabulatedFunction: Unknown function type"); throw OpenMMException("createReferenceTabulatedFunction: Unknown function type");
} }
...@@ -73,3 +75,26 @@ double ReferenceContinuous1DFunction::evaluateDerivative(const double* arguments ...@@ -73,3 +75,26 @@ double ReferenceContinuous1DFunction::evaluateDerivative(const double* arguments
CustomFunction* ReferenceContinuous1DFunction::clone() const { CustomFunction* ReferenceContinuous1DFunction::clone() const {
return new ReferenceContinuous1DFunction(function); return new ReferenceContinuous1DFunction(function);
} }
ReferenceDiscrete1DFunction::ReferenceDiscrete1DFunction(const Discrete1DFunction& function) : function(function) {
function.getFunctionParameters(values);
}
int ReferenceDiscrete1DFunction::getNumArguments() const {
return 1;
}
double ReferenceDiscrete1DFunction::evaluate(const double* arguments) const {
int t = (int) arguments[0];
if (t < 0 || t >= values.size())
throw OpenMMException("ReferenceDiscrete1DFunction: argument out of range");
return values[t];
}
double ReferenceDiscrete1DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* ReferenceDiscrete1DFunction::clone() const {
return new ReferenceDiscrete1DFunction(function);
}
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. * * Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -227,7 +227,7 @@ void testPeriodic() { ...@@ -227,7 +227,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL); ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
} }
void testTabulatedFunction() { void testContinuous1DFunction() {
ReferencePlatform platform; ReferencePlatform platform;
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -239,7 +239,7 @@ void testTabulatedFunction() { ...@@ -239,7 +239,7 @@ void testTabulatedFunction() {
vector<double> table; vector<double> table;
for (int i = 0; i < 21; i++) for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i)); table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0); forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField); system.addForce(forceField);
Context context(system, integrator, platform); Context context(system, integrator, platform);
vector<Vec3> positions(2); vector<Vec3> positions(2);
...@@ -267,6 +267,34 @@ void testTabulatedFunction() { ...@@ -267,6 +267,34 @@ void testTabulatedFunction() {
} }
} }
void testDiscrete1DFunction() {
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL(table[i]+1.0, state.getPotentialEnergy());
}
}
void testCoulombLennardJones() { void testCoulombLennardJones() {
const int numMolecules = 300; const int numMolecules = 300;
const int numParticles = numMolecules*2; const int numParticles = numMolecules*2;
...@@ -658,7 +686,8 @@ int main() { ...@@ -658,7 +686,8 @@ int main() {
testExclusions(); testExclusions();
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
testTabulatedFunction(); testContinuous1DFunction();
testDiscrete1DFunction();
testCoulombLennardJones(); testCoulombLennardJones();
testSwitchingFunction(); testSwitchingFunction();
testLongRangeCorrection(); testLongRangeCorrection();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment