Commit 0a1a011d authored by peastman's avatar peastman
Browse files

Tabulated function parameters are hardcoded in the kernel instead of being...

Tabulated function parameters are hardcoded in the kernel instead of being stored in an array.  This makes the code simpler and may help performance slightly.
parent 5d4dcb42
......@@ -56,12 +56,11 @@ public:
* @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real")
*/
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
const std::string& prefix, const std::string& tempType="real");
/**
* Generate the source code for calculating a set of expressions.
*
......@@ -71,12 +70,11 @@ public:
* @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real")
*/
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
const std::string& prefix, const std::string& tempType="real");
/**
* Calculate the spline coefficients for a tabulated function that appears in expressions.
*
......@@ -85,13 +83,6 @@ public:
* @return the spline coefficients
*/
std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width);
/**
* Given the list of TabulatedFunctions used by a Force, create the parameter array describing them.
*
* @param functions the list of functions to include in the array
* @return the parameter array
*/
std::vector<float4> computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
/**
* Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions.
*
......@@ -121,12 +112,13 @@ private:
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
CudaContext& context;
FunctionPlaceholder fp1, fp2, fp3;
};
......
......@@ -638,7 +638,7 @@ private:
class CudaCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel {
public:
CudaCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomNonbondedForceKernel(name, platform),
cu(cu), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) {
cu(cu), params(NULL), globals(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) {
}
~CudaCalcCustomNonbondedForceKernel();
/**
......@@ -669,7 +669,6 @@ private:
CudaContext& cu;
CudaParameterSet* params;
CudaArray* globals;
CudaArray* tabulatedFunctionParams;
CudaArray* interactionGroupData;
CUfunction interactionGroupKernel;
std::vector<void*> interactionGroupArgs;
......@@ -739,7 +738,7 @@ class CudaCalcCustomGBForceKernel : public CalcCustomGBForceKernel {
public:
CudaCalcCustomGBForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomGBForceKernel(name, platform),
hasInitializedKernels(false), cu(cu), params(NULL), computedValues(NULL), energyDerivs(NULL), energyDerivChain(NULL), longEnergyDerivs(NULL), globals(NULL),
valueBuffers(NULL), tabulatedFunctionParams(NULL), system(system) {
valueBuffers(NULL), system(system) {
}
~CudaCalcCustomGBForceKernel();
/**
......@@ -776,7 +775,6 @@ private:
CudaArray* longEnergyDerivs;
CudaArray* globals;
CudaArray* valueBuffers;
CudaArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions;
......@@ -838,7 +836,7 @@ class CudaCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel {
public:
CudaCalcCustomHbondForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomHbondForceKernel(name, platform),
hasInitializedKernel(false), cu(cu), donorParams(NULL), acceptorParams(NULL), donors(NULL), acceptors(NULL),
globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), tabulatedFunctionParams(NULL), system(system) {
globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), system(system) {
}
~CudaCalcCustomHbondForceKernel();
/**
......@@ -875,7 +873,6 @@ private:
CudaArray* acceptors;
CudaArray* donorExclusions;
CudaArray* acceptorExclusions;
CudaArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions;
......@@ -890,7 +887,7 @@ private:
class CudaCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondForceKernel {
public:
CudaCalcCustomCompoundBondForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomCompoundBondForceKernel(name, platform),
cu(cu), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), system(system) {
cu(cu), params(NULL), globals(NULL), system(system) {
}
~CudaCalcCustomCompoundBondForceKernel();
/**
......@@ -922,7 +919,6 @@ private:
CudaContext& cu;
CudaParameterSet* params;
CudaArray* globals;
CudaArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions;
......
......@@ -37,22 +37,21 @@ CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context
}
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, functionNames, prefix, functionParams, tempType);
return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
}
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
stringstream out;
vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps = variables;
vector<vector<double> > functionParams = computeFunctionParameters(functions);
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
......@@ -61,7 +60,7 @@ string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpres
}
void CudaExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& functionParams,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
......@@ -104,22 +103,26 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
temps.push_back(make_pair(*nodes[j], name2));
}
out << "{\n";
vector<string> paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
paramsInt.push_back(context.intToString((int) functionParams[i][j]));
}
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n";
out << "x = (x-params.x)*params.z;\n";
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x-" << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) params.w);\n";
out << "index = min(index, (int) " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0)
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else
out << nodeNames[j] << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n";
out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
}
out << "}\n";
}
......@@ -127,9 +130,8 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < params.x) {\n";
out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
out << "int index = (int) round(x);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
......@@ -140,11 +142,10 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int xsize = (int) params.x;\n";
out << "int ysize = (int) params.y;\n";
out << "int xsize = (int) " << paramsInt[0] << ";\n";
out << "int ysize = (int) " << paramsInt[1] << ";\n";
out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
......@@ -155,13 +156,12 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int z = (int) round(" << getTempName(node.getChildren()[2], temps) << ");\n";
out << "int xsize = (int) params.x;\n";
out << "int ysize = (int) params.y;\n";
out << "int zsize = (int) params.z;\n";
out << "int xsize = (int) " << paramsInt[0] << ";\n";
out << "int ysize = (int) " << paramsInt[1] << ";\n";
out << "int zsize = (int) " << paramsInt[2] << ";\n";
out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
......@@ -452,35 +452,41 @@ vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const Tabulat
throw OpenMMException("computeFunctionCoefficients: Unknown function type");
}
vector<float4> CudaExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
vector<float4> params(functions.size());
vector<vector<double> > CudaExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
vector<vector<double> > params(functions.size());
for (int i = 0; i < (int) functions.size(); i++) {
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(*functions[i]);
vector<double> values;
double min, max;
fn.getFunctionParameters(values, min, max);
params[i] = make_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2);
params[i].push_back(min);
params[i].push_back(max);
params[i].push_back((values.size()-1)/(max-min));
params[i].push_back(values.size()-2);
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
vector<double> values;
fn.getFunctionParameters(values);
params[i] = make_float4((float) values.size(), 0.0f, 0.0f, 0.0f);
params[i].push_back(values.size());
}
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(*functions[i]);
int xsize, ysize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, values);
params[i] = make_float4(xsize, ysize, 0.0f, 0.0f);
params[i].push_back(xsize);
params[i].push_back(ysize);
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(*functions[i]);
int xsize, ysize, zsize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values);
params[i] = make_float4(xsize, ysize, zsize, 0.0f);
params[i].push_back(xsize);
params[i].push_back(ysize);
params[i].push_back(zsize);
}
else
throw OpenMMException("computeFunctionParameters: Unknown function type");
......
This diff is collapsed.
......@@ -56,12 +56,11 @@ public:
* @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real")
*/
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real");
const std::string& prefix, const std::string& tempType="real");
/**
* Generate the source code for calculating a set of expressions.
*
......@@ -71,12 +70,11 @@ public:
* @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "float")
*/
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="float");
const std::string& prefix, const std::string& tempType="float");
/**
* Calculate the spline coefficients for a tabulated function that appears in expressions.
*
......@@ -85,13 +83,6 @@ public:
* @return the spline coefficients
*/
std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width);
/**
* Given the list of TabulatedFunctions used by a Force, create the parameter array describing them.
*
* @param functions the list of functions to include in the array
* @return the parameter array
*/
std::vector<mm_float4> computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
/**
* Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions.
*
......@@ -121,12 +112,13 @@ private:
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
OpenCLContext& context;
FunctionPlaceholder fp1, fp2, fp3;
};
......
......@@ -639,7 +639,7 @@ private:
class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel {
public:
OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomNonbondedForceKernel(name, platform),
cl(cl), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) {
cl(cl), params(NULL), globals(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) {
}
~OpenCLCalcCustomNonbondedForceKernel();
/**
......@@ -670,7 +670,6 @@ private:
OpenCLContext& cl;
OpenCLParameterSet* params;
OpenCLArray* globals;
OpenCLArray* tabulatedFunctionParams;
OpenCLArray* interactionGroupData;
cl::Kernel interactionGroupKernel;
std::vector<void*> interactionGroupArgs;
......@@ -742,7 +741,7 @@ class OpenCLCalcCustomGBForceKernel : public CalcCustomGBForceKernel {
public:
OpenCLCalcCustomGBForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomGBForceKernel(name, platform),
hasInitializedKernels(false), cl(cl), params(NULL), computedValues(NULL), energyDerivs(NULL), energyDerivChain(NULL), longEnergyDerivs(NULL), globals(NULL),
valueBuffers(NULL), longValueBuffers(NULL), tabulatedFunctionParams(NULL), system(system) {
valueBuffers(NULL), longValueBuffers(NULL), system(system) {
}
~OpenCLCalcCustomGBForceKernel();
/**
......@@ -780,7 +779,6 @@ private:
OpenCLArray* globals;
OpenCLArray* valueBuffers;
OpenCLArray* longValueBuffers;
OpenCLArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray*> tabulatedFunctions;
......@@ -841,8 +839,7 @@ class OpenCLCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel {
public:
OpenCLCalcCustomHbondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomHbondForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), donorParams(NULL), acceptorParams(NULL), donors(NULL), acceptors(NULL),
donorBufferIndices(NULL), acceptorBufferIndices(NULL), globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL),
tabulatedFunctionParams(NULL), system(system) {
donorBufferIndices(NULL), acceptorBufferIndices(NULL), globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), system(system) {
}
~OpenCLCalcCustomHbondForceKernel();
/**
......@@ -881,7 +878,6 @@ private:
OpenCLArray* acceptorBufferIndices;
OpenCLArray* donorExclusions;
OpenCLArray* acceptorExclusions;
OpenCLArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray*> tabulatedFunctions;
......@@ -895,7 +891,7 @@ private:
class OpenCLCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondForceKernel {
public:
OpenCLCalcCustomCompoundBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomCompoundBondForceKernel(name, platform),
cl(cl), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), system(system) {
cl(cl), params(NULL), globals(NULL), system(system) {
}
~OpenCLCalcCustomCompoundBondForceKernel();
/**
......@@ -927,7 +923,6 @@ private:
OpenCLContext& cl;
OpenCLParameterSet* params;
OpenCLArray* globals;
OpenCLArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray*> tabulatedFunctions;
......
......@@ -37,22 +37,21 @@ OpenCLExpressionUtilities::OpenCLExpressionUtilities(OpenCLContext& context) : c
}
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, functionNames, prefix, functionParams, tempType);
return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
}
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix,
const string& functionParams, const string& tempType) {
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
stringstream out;
vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps = variables;
vector<vector<double> > functionParams = computeFunctionParameters(functions);
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
......@@ -61,7 +60,7 @@ string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpr
}
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& functionParams,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
......@@ -104,22 +103,26 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
temps.push_back(make_pair(*nodes[j], name2));
}
out << "{\n";
vector<string> paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
paramsInt.push_back(context.intToString((int) functionParams[i][j]));
}
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n";
out << "x = (x-params.x)*params.z;\n";
out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x-" << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) params.w);\n";
out << "index = min(index, " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n";
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0)
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n";
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else
out << nodeNames[j] << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n";
out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
}
out << "}\n";
}
......@@ -127,9 +130,8 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < params.x) {\n";
out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
out << "int index = (int) round(x);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n";
......@@ -140,11 +142,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int xsize = (int) params.x;\n";
out << "int ysize = (int) params.y;\n";
out << "int xsize = " << paramsInt[0] << ";\n";
out << "int ysize = " << paramsInt[1] << ";\n";
out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
......@@ -155,13 +156,12 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int z = (int) round(" << getTempName(node.getChildren()[2], temps) << ");\n";
out << "int xsize = (int) params.x;\n";
out << "int ysize = (int) params.y;\n";
out << "int zsize = (int) params.z;\n";
out << "int xsize = " << paramsInt[0] << ";\n";
out << "int ysize = " << paramsInt[1] << ";\n";
out << "int zsize = " << paramsInt[2] << ";\n";
out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
......@@ -452,35 +452,41 @@ vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const Tabul
throw OpenMMException("computeFunctionCoefficients: Unknown function type");
}
vector<mm_float4> OpenCLExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
vector<mm_float4> params(functions.size());
vector<vector<double> > OpenCLExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
vector<vector<double> > params(functions.size());
for (int i = 0; i < (int) functions.size(); i++) {
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(*functions[i]);
vector<double> values;
double min, max;
fn.getFunctionParameters(values, min, max);
params[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2);
params[i].push_back(min);
params[i].push_back(max);
params[i].push_back((values.size()-1)/(max-min));
params[i].push_back(values.size()-2);
}
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
vector<double> values;
fn.getFunctionParameters(values);
params[i] = mm_float4((float) values.size(), 0.0f, 0.0f, 0.0f);
params[i].push_back(values.size());
}
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(*functions[i]);
int xsize, ysize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, values);
params[i] = mm_float4(xsize, ysize, 0.0f, 0.0f);
params[i].push_back(xsize);
params[i].push_back(ysize);
}
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(*functions[i]);
int xsize, ysize, zsize;
vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values);
params[i] = mm_float4(xsize, ysize, zsize, 0.0f);
params[i].push_back(xsize);
params[i].push_back(ysize);
params[i].push_back(zsize);
}
else
throw OpenMMException("computeFunctionParameters: Unknown function type");
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment