Commit 56e36449 authored by peastman's avatar peastman
Browse files

Created Discrete1DFunction

parent 81f26683
......@@ -101,6 +101,36 @@ private:
double min, max;
};
/**
* This is a TabulatedFunction that computes a discrete one dimensional function.
*/
class OPENMM_EXPORT Discrete1DFunction : public TabulatedFunction {
public:
/**
* Create a Discrete1DFunction f(x) based on a set of tabulated values.
*
* @param values the tabulated values of the function f(x). The function is only defined
* for integer values of x in the range [0, values.size()].
*/
Discrete1DFunction(const std::vector<double>& values);
/**
* Get the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x). The function is only defined
* for integer values of x in the range [0, values.size()].
*/
void getFunctionParameters(std::vector<double>& values) const;
/**
* Set the parameters for the tabulated function.
*
* @param values the tabulated values of the function f(x). The function is only defined
* for integer values of x in the range [0, values.size()].
*/
void setFunctionParameters(const std::vector<double>& values);
private:
std::vector<double> values;
};
} // namespace OpenMM
#endif /*OPENMM_TABULATEDFUNCTION_H_*/
......@@ -60,3 +60,15 @@ void Continuous1DFunction::setFunctionParameters(const std::vector<double>& valu
this->min = min;
this->max = max;
}
Discrete1DFunction::Discrete1DFunction(const std::vector<double>& values) {
this->values = values;
}
void Discrete1DFunction::getFunctionParameters(std::vector<double>& values) const {
values = this->values;
}
void Discrete1DFunction::setFunctionParameters(const std::vector<double>& values) {
this->values = values;
}
......@@ -54,33 +54,38 @@ public:
* @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are
* variable names, and the values are the code to generate for them.
* @param functions defines the variable name for each tabulated function that may appear in the expressions
* @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real")
*/
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
/**
* Generate the source code for calculating a set of expressions.
*
* @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions.
* Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears.
* @param functions defines the variable name for each tabulated function that may appear in the expressions
* @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real")
*/
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
/**
* Calculate the spline coefficients for a tabulated function that appears in expressions.
*
* @param function the function for which to compute coefficients
* @param width on output, the number of floats used for each value
* @return the spline coefficients
*/
std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function);
std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width);
/**
* Given the list of TabulatedFunctions used by a Force, create the parameter array describing them.
*
......@@ -92,8 +97,8 @@ public:
private:
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams,
const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode);
......
......@@ -34,34 +34,37 @@ using namespace Lepton;
using namespace std;
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, prefix, functionParams, tempType);
return createExpressions(expressions, variableNodes, functions, functionNames, prefix, functionParams, tempType);
}
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
stringstream out;
vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps = variables;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, prefix, functionParams, allExpressions, tempType);
processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
}
return out.str();
}
void CudaExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const vector<ParsedExpression>& allExpressions, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return;
for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, functions, prefix, functionParams, allExpressions, tempType);
processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
string name = prefix+context.intToString(temps.size());
bool hasRecordedNode = false;
......@@ -75,9 +78,9 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
case Operation::CUSTOM:
{
int i;
for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++)
for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
;
if (i == functions.size())
if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1);
out << "0.0f;\n";
......@@ -106,13 +109,14 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
}
}
out << "{\n";
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n";
out << "x = (x-params.x)*params.z;\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) params.w);\n";
out << "float4 coeff = " << functions[i].second << "[index];\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
if (valueNode != NULL)
......@@ -120,6 +124,17 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
if (derivNode != NULL)
out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n";
out << "}\n";
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
if (valueNode != NULL) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < params.x) {\n";
out << "int index = (int) round(x);\n";
out << valueName << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
}
}
out << "}";
break;
}
......@@ -341,10 +356,10 @@ void CudaExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node,
findRelatedPowers(node, searchNode.getChildren()[i], powers);
}
vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function) {
vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) {
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(function);
vector<double> values;
double min, max;
......@@ -361,6 +376,20 @@ vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const Tabulat
f[4*i+2] = (float) (derivs[i]/6.0);
f[4*i+3] = (float) (derivs[i+1]/6.0);
}
width = 4;
return f;
}
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(function);
vector<double> values;
fn.getFunctionParameters(values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
throw OpenMMException("computeFunctionCoefficients: Unknown function type");
......@@ -376,6 +405,12 @@ vector<float4> CudaExpressionUtilities::computeFunctionParameters(const vector<c
fn.getFunctionParameters(values, min, max);
params[i] = make_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2);
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
vector<double> values;
fn.getFunctionParameters(values);
params[i] = make_float4((float) values.size(), 0.0f, 0.0f, 0.0f);
}
else
throw OpenMMException("computeFunctionParameters: Unknown function type");
}
......
......@@ -595,8 +595,9 @@ void CudaCalcCustomBondForceKernel::initialize(const System& system, const Custo
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
}
vector<pair<string, string> > functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", "");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::bondForce, replacements), force.getForceGroup());
......@@ -830,8 +831,9 @@ void CudaCalcCustomAngleForceKernel::initialize(const System& system, const Cust
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
}
vector<pair<string, string> > functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", "");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::angleForce, replacements), force.getForceGroup());
......@@ -1253,8 +1255,9 @@ void CudaCalcCustomTorsionForceKernel::initialize(const System& system, const Cu
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n";
}
vector<pair<string, string> > functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", "");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::torsionForce, replacements), force.getForceGroup());
......@@ -1965,10 +1968,11 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i));
int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
}
vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList);
if (force.getNumFunctions() > 0) {
......@@ -2013,7 +2017,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
variables.push_back(makeVariable(name, prefix+value));
}
stringstream compute;
compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, prefix+"temp", prefix+"functionParams");
compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, prefix+"temp", prefix+"functionParams");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
......@@ -2678,10 +2682,11 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
string arrayName = prefix+"table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i));
int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
tableArgs << ", const float4* __restrict__ " << arrayName;
}
vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList);
......@@ -2775,7 +2780,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueExpressions["tempValue1 = "] = ex;
n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
n2ValueSource << cu.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
n2ValueSource << cu.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
map<string, string> replacements;
string n2ValueStr = n2ValueSource.str();
replacements["COMPUTE_VALUE"] = n2ValueStr;
......@@ -2853,7 +2858,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1);
map<string, Lepton::ParsedExpression> valueExpressions;
valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
reductionSource << cu.getExpressionUtilities().createExpressions(valueExpressions, variables, functionDefinitions, "value"+cu.intToString(i)+"_temp", prefix+"functionParams");
reductionSource << cu.getExpressionUtilities().createExpressions(valueExpressions, variables, functionList, functionDefinitions, "value"+cu.intToString(i)+"_temp", prefix+"functionParams");
}
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
string valueName = "values"+cu.intToString(i+1);
......@@ -2907,7 +2912,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
}
if (exclude)
n2EnergySource << "if (!isExcluded) {\n";
n2EnergySource << cu.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
n2EnergySource << cu.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
if (exclude)
n2EnergySource << "}\n";
}
......@@ -3056,7 +3061,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
for (int i = 1; i < force.getNumComputedValues(); i++)
for (int j = 0; j < i; j++)
expressions["real dV"+cu.intToString(i)+"dV"+cu.intToString(j)+" = "] = valueDerivExpressions[i][j];
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "temp", prefix+"functionParams");
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
// Record values.
......@@ -3128,7 +3133,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
map<string, Lepton::ParsedExpression> derivExpressions;
string js = cu.intToString(j);
derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
compute << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams");
compute << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams");
compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n";
}
}
......@@ -3139,7 +3144,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << cu.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
compute << cu.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
}
for (int i = 1; i < force.getNumComputedValues(); i++) {
string is = cu.intToString(i);
......@@ -3181,7 +3186,7 @@ void CudaCalcCustomGBForceKernel::initialize(const System& system, const CustomG
Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
derivExpressions["real dV0dR1 = "] = dVdR;
derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename);
chainSource << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams");
chainSource << cu.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, prefix+"temp0_", prefix+"functionParams");
if (needChainForValue[0]) {
if (useExclusionsForValue)
chainSource << "if (!isExcluded) {\n";
......@@ -3539,8 +3544,9 @@ void CudaCalcCustomExternalForceKernel::initialize(const System& system, const C
string argName = cu.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
}
vector<pair<string, string> > functions;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", "");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
compute << cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
cu.getBondedUtilities().addInteraction(atoms, cu.replaceStrings(CudaKernelSources::customExternalForce, replacements), force.getForceGroup());
......@@ -3791,10 +3797,14 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
string arrayName = "table"+cu.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i));
int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", const float4* __restrict__ " << arrayName;
tableArgs << ", const float";
if (width > 1)
tableArgs << width;
tableArgs << "* __restrict__ " << arrayName;
}
vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList);
if (force.getNumFunctions() > 0) {
......@@ -3916,9 +3926,9 @@ void CudaCalcCustomHbondForceKernel::initialize(const System& system, const Cust
// Now evaluate the expressions.
computeAcceptor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams");
computeAcceptor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", "functionParams");
forceExpressions["energy += "] = energyExpression;
computeDonor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams");
computeDonor << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", "functionParams");
// Finally, apply forces to atoms.
......@@ -4181,11 +4191,12 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i);
functions[name] = &fp;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i));
int width;
vector<float> f = cu.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
CudaArray* array = CudaArray::create<float>(cu, f.size(), "TabulatedFunction");
tabulatedFunctions.push_back(array);
array->upload(f);
string arrayName = cu.getBondedUtilities().addArgument(array->getDevicePointer(), "float4");
string arrayName = cu.getBondedUtilities().addArgument(array->getDevicePointer(), width == 1 ? "float" : "float"+cu.intToString(width));
functionDefinitions.push_back(make_pair(name, arrayName));
}
vector<float4> tabulatedFunctionParamsVec = cu.getExpressionUtilities().computeFunctionParameters(functionList);
......@@ -4309,7 +4320,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
}
forceExpressions["energy += "] = energyExpression;
compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", functionParamsName);
compute << cu.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", functionParamsName);
// Finally, apply forces to atoms.
......@@ -4331,7 +4342,7 @@ void CudaCalcCustomCompoundBondForceKernel::initialize(const System& system, con
if (!isZeroExpression(forceExpressionZ))
expressions[forceName+".z -= "] = forceExpressionZ;
if (expressions.size() > 0)
compute<<cu.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "coordtemp", functionParamsName);
compute<<cu.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp", functionParamsName);
compute<<"}\n";
}
index = 0;
......@@ -4941,8 +4952,9 @@ string CudaIntegrateCustomStepKernel::createGlobalComputation(const string& vari
variables[integrator.getGlobalVariableName(i)] = "globals["+cu.intToString(i)+"]";
for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cu.intToString(i)+"]";
vector<pair<string, string> > functions;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", "");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
}
string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) {
......@@ -4978,8 +4990,9 @@ string CudaIntegrateCustomStepKernel::createPerDofComputation(const string& vari
variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i);
for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cu.intToString(i)+"]";
vector<pair<string, string> > functions;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp"+cu.intToString(component)+"_", "", "double");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
return cu.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp"+cu.intToString(component)+"_", "", "double");
}
void CudaIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
......
......@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -261,7 +261,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
}
void testTabulatedFunction() {
void testContinuous1DFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
......@@ -272,7 +272,7 @@ void testTabulatedFunction() {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0);
forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -300,6 +300,33 @@ void testTabulatedFunction() {
}
}
void testDiscrete1DFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i+1, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testCoulombLennardJones() {
const int numMolecules = 300;
const int numParticles = numMolecules*2;
......@@ -725,7 +752,8 @@ int main(int argc, char* argv[]) {
testExclusions();
testCutoff();
testPeriodic();
testTabulatedFunction();
testContinuous1DFunction();
testDiscrete1DFunction();
testCoulombLennardJones();
testParallelComputation();
testSwitchingFunction();
......
......@@ -54,33 +54,38 @@ public:
* @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable that may appear in the expressions. Keys are
* variable names, and the values are the code to generate for them.
* @param functions defines the variable name for each tabulated function that may appear in the expressions
* @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real")
*/
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
/**
* Generate the source code for calculating a set of expressions.
*
* @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable or precomputed sub-expression that may appear in the expressions.
* Each entry is an ExpressionTreeNode, and the code to generate wherever an identical node appears.
* @param functions defines the variable name for each tabulated function that may appear in the expressions
* @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "float")
*/
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams, const std::string& tempType="float");
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="float");
/**
* Calculate the spline coefficients for a tabulated function that appears in expressions.
*
* @param function the function for which to compute coefficients
* @param width on output, the number of floats used for each value
* @return the spline coefficients
*/
std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function);
std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width);
/**
* Given the list of TabulatedFunctions used by a Force, create the parameter array describing them.
*
......@@ -92,8 +97,8 @@ public:
private:
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams,
const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode);
......
......@@ -34,34 +34,37 @@ using namespace Lepton;
using namespace std;
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, prefix, functionParams, tempType);
return createExpressions(expressions, variableNodes, functions, functionNames, prefix, functionParams, tempType);
}
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
stringstream out;
vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps = variables;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, prefix, functionParams, allExpressions, tempType);
processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
}
return out.str();
}
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams, const vector<ParsedExpression>& allExpressions, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return;
for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, functions, prefix, functionParams, allExpressions, tempType);
processExpression(out, node.getChildren()[i], temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
string name = prefix+context.intToString(temps.size());
bool hasRecordedNode = false;
......@@ -75,9 +78,9 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
case Operation::CUSTOM:
{
int i;
for (i = 0; i < (int) functions.size() && functions[i].first != node.getOperation().getName(); i++)
for (i = 0; i < (int) functionNames.size() && functionNames[i].first != node.getOperation().getName(); i++)
;
if (i == functions.size())
if (i == functionNames.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1);
out << "0.0f;\n";
......@@ -106,20 +109,32 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
}
}
out << "{\n";
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n";
out << "x = (x-params.x)*params.z;\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) params.w);\n";
out << "float4 coeff = " << functions[i].second << "[index];\n";
out << "float b = x-index;\n";
out << "float a = 1.0f-b;\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
if (valueNode != NULL)
out << valueName << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
if (derivNode != NULL)
out << derivName << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n";
out << "}\n";
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
if (valueNode != NULL) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < params.x) {\n";
out << "int index = (int) round(x);\n";
out << valueName << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
}
}
out << "}";
break;
}
......@@ -341,10 +356,10 @@ void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node
findRelatedPowers(node, searchNode.getChildren()[i], powers);
}
vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function) {
vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const TabulatedFunction& function, int& width) {
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) {
// Compute the spline coefficients.
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL) {
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(function);
vector<double> values;
double min, max;
......@@ -361,6 +376,20 @@ vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const Tabul
f[4*i+2] = (float) (derivs[i]/6.0);
f[4*i+3] = (float) (derivs[i+1]/6.0);
}
width = 4;
return f;
}
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL) {
// Record the tabulated values.
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(function);
vector<double> values;
fn.getFunctionParameters(values);
int numValues = values.size();
vector<float> f(numValues);
for (int i = 0; i < numValues; i++)
f[i] = (float) values[i];
width = 1;
return f;
}
throw OpenMMException("computeFunctionCoefficients: Unknown function type");
......@@ -376,6 +405,12 @@ vector<mm_float4> OpenCLExpressionUtilities::computeFunctionParameters(const vec
fn.getFunctionParameters(values, min, max);
params[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2);
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
vector<double> values;
fn.getFunctionParameters(values);
params[i] = mm_float4((float) values.size(), 0.0f, 0.0f, 0.0f);
}
else
throw OpenMMException("computeFunctionParameters: Unknown function type");
}
......
......@@ -614,8 +614,9 @@ void OpenCLCalcCustomBondForceKernel::initialize(const System& system, const Cus
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
}
vector<pair<string, string> > functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", "");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup());
......@@ -843,8 +844,9 @@ void OpenCLCalcCustomAngleForceKernel::initialize(const System& system, const Cu
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
}
vector<pair<string, string> > functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", "");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup());
......@@ -1247,8 +1249,9 @@ void OpenCLCalcCustomTorsionForceKernel::initialize(const System& system, const
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n";
}
vector<pair<string, string> > functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", "");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
......@@ -1975,10 +1978,11 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i));
int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
}
vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList);
if (force.getNumFunctions() > 0) {
......@@ -2023,7 +2027,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
variables.push_back(makeVariable(name, prefix+value));
}
stringstream compute;
compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, prefix+"temp", prefix+"functionParams");
compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, prefix+"temp", prefix+"functionParams");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
......@@ -2731,10 +2735,11 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
string arrayName = prefix+"table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i));
int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", 4, sizeof(cl_float4), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
tableArgs << ", __global const float4* restrict " << arrayName;
}
vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList);
......@@ -2834,7 +2839,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueExpressions["tempValue1 = "] = ex;
n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
n2ValueSource << cl.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
n2ValueSource << cl.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
map<string, string> replacements;
string n2ValueStr = n2ValueSource.str();
replacements["COMPUTE_VALUE"] = n2ValueStr;
......@@ -2910,7 +2915,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1);
map<string, Lepton::ParsedExpression> valueExpressions;
valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
reductionSource << cl.getExpressionUtilities().createExpressions(valueExpressions, variables, functionDefinitions, "value"+cl.intToString(i)+"_temp", prefix+"functionParams");
reductionSource << cl.getExpressionUtilities().createExpressions(valueExpressions, variables, functionList, functionDefinitions, "value"+cl.intToString(i)+"_temp", prefix+"functionParams");
}
for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
string valueName = "values"+cl.intToString(i+1);
......@@ -2974,7 +2979,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
}
if (exclude)
n2EnergySource << "if (!isExcluded) {\n";
n2EnergySource << cl.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
n2EnergySource << cl.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
if (exclude)
n2EnergySource << "}\n";
}
......@@ -3145,7 +3150,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
for (int i = 1; i < force.getNumComputedValues(); i++)
for (int j = 0; j < i; j++)
expressions["real dV"+cl.intToString(i)+"dV"+cl.intToString(j)+" = "] = valueDerivExpressions[i][j];
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "temp", prefix+"functionParams");
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
// Record values.
......@@ -3215,7 +3220,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
map<string, Lepton::ParsedExpression> derivExpressions;
string js = cl.intToString(j);
derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
compute << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams");
compute << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js, prefix+"functionParams");
compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n";
}
}
......@@ -3226,7 +3231,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
if (!isZeroExpression(valueGradientExpressions[i][2]))
gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
compute << cl.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
compute << cl.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp", prefix+"functionParams");
}
for (int i = 1; i < force.getNumComputedValues(); i++) {
string is = cl.intToString(i);
......@@ -3267,7 +3272,7 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
derivExpressions["real dV0dR1 = "] = dVdR;
derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename);
chainSource << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams");
chainSource << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, prefix+"temp0_", prefix+"functionParams");
if (needChainForValue[0]) {
if (useExclusionsForValue)
chainSource << "if (!isExcluded) {\n";
......@@ -3677,8 +3682,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const
string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
}
vector<pair<string, string> > functions;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", "");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
map<string, string> replacements;
replacements["COMPUTE_FORCE"] = compute.str();
cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::customExternalForce, replacements), force.getForceGroup());
......@@ -3954,10 +3960,14 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
string arrayName = "table"+cl.intToString(i);
functionDefinitions.push_back(make_pair(name, arrayName));
functions[name] = &fp;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i));
int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
tableArgs << ", __global const float4* restrict " << arrayName;
tableArgs << ", __global const float";
if (width > 1)
tableArgs << width;
tableArgs << "* restrict " << arrayName;
}
vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList);
if (force.getNumFunctions() > 0) {
......@@ -4079,9 +4089,9 @@ void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const Cu
// Now evaluate the expressions.
computeAcceptor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams");
computeAcceptor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", "functionParams");
forceExpressions["energy += "] = energyExpression;
computeDonor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", "functionParams");
computeDonor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", "functionParams");
// Finally, apply forces to atoms.
......@@ -4346,11 +4356,12 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
functionList.push_back(&force.getFunction(i));
string name = force.getFunctionName(i);
functions[name] = &fp;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i));
int width;
vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getFunction(i), width);
OpenCLArray* array = OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction");
tabulatedFunctions.push_back(array);
array->upload(f);
string arrayName = cl.getBondedUtilities().addArgument(array->getDeviceBuffer(), "float4");
string arrayName = cl.getBondedUtilities().addArgument(array->getDeviceBuffer(), width == 1 ? "float" : "float"+cl.intToString(width));
functionDefinitions.push_back(make_pair(name, arrayName));
}
vector<mm_float4> tabulatedFunctionParamsVec = cl.getExpressionUtilities().computeFunctionParameters(functionList);
......@@ -4474,7 +4485,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
}
forceExpressions["energy += "] = energyExpression;
compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionDefinitions, "temp", functionParamsName);
compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp", functionParamsName);
// Finally, apply forces to atoms.
......@@ -4496,7 +4507,7 @@ void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, c
if (!isZeroExpression(forceExpressionZ))
expressions[forceName+".z -= "] = forceExpressionZ;
if (expressions.size() > 0)
compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionDefinitions, "coordtemp", functionParamsName);
compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp", functionParamsName);
compute<<"}\n";
}
index = 0;
......@@ -5186,8 +5197,9 @@ string OpenCLIntegrateCustomStepKernel::createGlobalComputation(const string& va
variables[integrator.getGlobalVariableName(i)] = "globals["+cl.intToString(i)+"]";
for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cl.intToString(i)+"]";
vector<pair<string, string> > functions;
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp", "");
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp", "");
}
string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) {
......@@ -5223,9 +5235,10 @@ string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& va
variables[integrator.getPerDofVariableName(i)] = "perDof"+suffix.substr(1)+perDofValues->getParameterSuffix(i);
for (int i = 0; i < (int) parameterNames.size(); i++)
variables[parameterNames[i]] = "params["+cl.intToString(i)+"]";
vector<pair<string, string> > functions;
vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames;
string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float");
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, "temp"+cl.intToString(component)+"_", "", tempType);
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp"+cl.intToString(component)+"_", "", tempType);
}
void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
......
......@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -261,7 +261,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
}
void testTabulatedFunction() {
void testContinuous1DFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
......@@ -272,7 +272,7 @@ void testTabulatedFunction() {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0);
forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -300,6 +300,33 @@ void testTabulatedFunction() {
}
}
void testDiscrete1DFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i+1, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testCoulombLennardJones() {
const int numMolecules = 300;
const int numParticles = numMolecules*2;
......@@ -725,7 +752,8 @@ int main(int argc, char* argv[]) {
testExclusions();
testCutoff();
testPeriodic();
testTabulatedFunction();
testContinuous1DFunction();
testDiscrete1DFunction();
testCoulombLennardJones();
testParallelComputation();
testSwitchingFunction();
......
......@@ -60,6 +60,21 @@ private:
std::vector<double> x, values, derivs;
};
/**
* This class adapts a Discrete1DFunction into a Lepton::CustomFunction.
*/
class OPENMM_EXPORT ReferenceDiscrete1DFunction : public Lepton::CustomFunction {
public:
ReferenceDiscrete1DFunction(const Discrete1DFunction& function);
int getNumArguments() const;
double evaluate(const double* arguments) const;
double evaluateDerivative(const double* arguments, const int* derivOrder) const;
CustomFunction* clone() const;
private:
const Discrete1DFunction& function;
std::vector<double> values;
};
} // namespace OpenMM
#endif /*OPENMM_REFERENCETABULATEDFUNCTION_H_*/
......@@ -40,6 +40,8 @@ using Lepton::CustomFunction;
extern "C" CustomFunction* createReferenceTabulatedFunction(const TabulatedFunction& function) {
if (dynamic_cast<const Continuous1DFunction*>(&function) != NULL)
return new ReferenceContinuous1DFunction(dynamic_cast<const Continuous1DFunction&>(function));
if (dynamic_cast<const Discrete1DFunction*>(&function) != NULL)
return new ReferenceDiscrete1DFunction(dynamic_cast<const Discrete1DFunction&>(function));
throw OpenMMException("createReferenceTabulatedFunction: Unknown function type");
}
......@@ -73,3 +75,26 @@ double ReferenceContinuous1DFunction::evaluateDerivative(const double* arguments
CustomFunction* ReferenceContinuous1DFunction::clone() const {
return new ReferenceContinuous1DFunction(function);
}
ReferenceDiscrete1DFunction::ReferenceDiscrete1DFunction(const Discrete1DFunction& function) : function(function) {
function.getFunctionParameters(values);
}
int ReferenceDiscrete1DFunction::getNumArguments() const {
return 1;
}
double ReferenceDiscrete1DFunction::evaluate(const double* arguments) const {
int t = (int) arguments[0];
if (t < 0 || t >= values.size())
throw OpenMMException("ReferenceDiscrete1DFunction: argument out of range");
return values[t];
}
double ReferenceDiscrete1DFunction::evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0.0;
}
CustomFunction* ReferenceDiscrete1DFunction::clone() const {
return new ReferenceDiscrete1DFunction(function);
}
......@@ -7,7 +7,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2013 Stanford University and the Authors. *
* Portions copyright (c) 2008-2014 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -227,7 +227,7 @@ void testPeriodic() {
ASSERT_EQUAL_TOL(1.9+1+0.9, state.getPotentialEnergy(), TOL);
}
void testTabulatedFunction() {
void testContinuous1DFunction() {
ReferencePlatform platform;
System system;
system.addParticle(1.0);
......@@ -239,7 +239,7 @@ void testTabulatedFunction() {
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", table, 1.0, 6.0);
forceField->addFunction("fn", new Continuous1DFunction(table, 1.0, 6.0));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
......@@ -267,6 +267,34 @@ void testTabulatedFunction() {
}
}
void testDiscrete1DFunction() {
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r)+1");
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(std::sin(0.25*i));
forceField->addFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL(table[i]+1.0, state.getPotentialEnergy());
}
}
void testCoulombLennardJones() {
const int numMolecules = 300;
const int numParticles = numMolecules*2;
......@@ -658,7 +686,8 @@ int main() {
testExclusions();
testCutoff();
testPeriodic();
testTabulatedFunction();
testContinuous1DFunction();
testDiscrete1DFunction();
testCoulombLennardJones();
testSwitchingFunction();
testLongRangeCorrection();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment