Commit 0a1a011d authored by peastman's avatar peastman
Browse files

Tabulated function parameters are hardcoded in the kernel instead of being...

Tabulated function parameters are hardcoded in the kernel instead of being stored in an array.  This makes the code simpler and may help performance slightly.
parent 5d4dcb42
...@@ -56,12 +56,11 @@ public: ...@@ -56,12 +56,11 @@ public:
* @param functions the tabulated functions that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions * @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real") * @param tempType the type of value to use for temporary variables (defaults to "real")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real"); const std::string& prefix, const std::string& tempType="real");
/** /**
* Generate the source code for calculating a set of expressions. * Generate the source code for calculating a set of expressions.
* *
...@@ -71,12 +70,11 @@ public: ...@@ -71,12 +70,11 @@ public:
* @param functions the tabulated functions that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions * @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real") * @param tempType the type of value to use for temporary variables (defaults to "real")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real"); const std::string& prefix, const std::string& tempType="real");
/** /**
* Calculate the spline coefficients for a tabulated function that appears in expressions. * Calculate the spline coefficients for a tabulated function that appears in expressions.
* *
...@@ -85,13 +83,6 @@ public: ...@@ -85,13 +83,6 @@ public:
* @return the spline coefficients * @return the spline coefficients
*/ */
std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width); std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width);
/**
* Given the list of TabulatedFunctions used by a Force, create the parameter array describing them.
*
* @param functions the list of functions to include in the array
* @return the parameter array
*/
std::vector<float4> computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
/** /**
* Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions. * Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions.
* *
...@@ -121,12 +112,13 @@ private: ...@@ -121,12 +112,13 @@ private:
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node, void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType); const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::vector<const Lepton::ExpressionTreeNode*>& nodes); std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers); std::map<int, const Lepton::ExpressionTreeNode*>& powers);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
CudaContext& context; CudaContext& context;
FunctionPlaceholder fp1, fp2, fp3; FunctionPlaceholder fp1, fp2, fp3;
}; };
......
...@@ -638,7 +638,7 @@ private: ...@@ -638,7 +638,7 @@ private:
class CudaCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel { class CudaCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel {
public: public:
CudaCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomNonbondedForceKernel(name, platform), CudaCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomNonbondedForceKernel(name, platform),
cu(cu), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) { cu(cu), params(NULL), globals(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) {
} }
~CudaCalcCustomNonbondedForceKernel(); ~CudaCalcCustomNonbondedForceKernel();
/** /**
...@@ -669,7 +669,6 @@ private: ...@@ -669,7 +669,6 @@ private:
CudaContext& cu; CudaContext& cu;
CudaParameterSet* params; CudaParameterSet* params;
CudaArray* globals; CudaArray* globals;
CudaArray* tabulatedFunctionParams;
CudaArray* interactionGroupData; CudaArray* interactionGroupData;
CUfunction interactionGroupKernel; CUfunction interactionGroupKernel;
std::vector<void*> interactionGroupArgs; std::vector<void*> interactionGroupArgs;
...@@ -739,7 +738,7 @@ class CudaCalcCustomGBForceKernel : public CalcCustomGBForceKernel { ...@@ -739,7 +738,7 @@ class CudaCalcCustomGBForceKernel : public CalcCustomGBForceKernel {
public: public:
CudaCalcCustomGBForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomGBForceKernel(name, platform), CudaCalcCustomGBForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomGBForceKernel(name, platform),
hasInitializedKernels(false), cu(cu), params(NULL), computedValues(NULL), energyDerivs(NULL), energyDerivChain(NULL), longEnergyDerivs(NULL), globals(NULL), hasInitializedKernels(false), cu(cu), params(NULL), computedValues(NULL), energyDerivs(NULL), energyDerivChain(NULL), longEnergyDerivs(NULL), globals(NULL),
valueBuffers(NULL), tabulatedFunctionParams(NULL), system(system) { valueBuffers(NULL), system(system) {
} }
~CudaCalcCustomGBForceKernel(); ~CudaCalcCustomGBForceKernel();
/** /**
...@@ -776,7 +775,6 @@ private: ...@@ -776,7 +775,6 @@ private:
CudaArray* longEnergyDerivs; CudaArray* longEnergyDerivs;
CudaArray* globals; CudaArray* globals;
CudaArray* valueBuffers; CudaArray* valueBuffers;
CudaArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions; std::vector<CudaArray*> tabulatedFunctions;
...@@ -838,7 +836,7 @@ class CudaCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel { ...@@ -838,7 +836,7 @@ class CudaCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel {
public: public:
CudaCalcCustomHbondForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomHbondForceKernel(name, platform), CudaCalcCustomHbondForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomHbondForceKernel(name, platform),
hasInitializedKernel(false), cu(cu), donorParams(NULL), acceptorParams(NULL), donors(NULL), acceptors(NULL), hasInitializedKernel(false), cu(cu), donorParams(NULL), acceptorParams(NULL), donors(NULL), acceptors(NULL),
globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), tabulatedFunctionParams(NULL), system(system) { globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), system(system) {
} }
~CudaCalcCustomHbondForceKernel(); ~CudaCalcCustomHbondForceKernel();
/** /**
...@@ -875,7 +873,6 @@ private: ...@@ -875,7 +873,6 @@ private:
CudaArray* acceptors; CudaArray* acceptors;
CudaArray* donorExclusions; CudaArray* donorExclusions;
CudaArray* acceptorExclusions; CudaArray* acceptorExclusions;
CudaArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions; std::vector<CudaArray*> tabulatedFunctions;
...@@ -890,7 +887,7 @@ private: ...@@ -890,7 +887,7 @@ private:
class CudaCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondForceKernel { class CudaCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondForceKernel {
public: public:
CudaCalcCustomCompoundBondForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomCompoundBondForceKernel(name, platform), CudaCalcCustomCompoundBondForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomCompoundBondForceKernel(name, platform),
cu(cu), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), system(system) { cu(cu), params(NULL), globals(NULL), system(system) {
} }
~CudaCalcCustomCompoundBondForceKernel(); ~CudaCalcCustomCompoundBondForceKernel();
/** /**
...@@ -922,7 +919,6 @@ private: ...@@ -922,7 +919,6 @@ private:
CudaContext& cu; CudaContext& cu;
CudaParameterSet* params; CudaParameterSet* params;
CudaArray* globals; CudaArray* globals;
CudaArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions; std::vector<CudaArray*> tabulatedFunctions;
......
...@@ -37,22 +37,21 @@ CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context ...@@ -37,22 +37,21 @@ CudaExpressionUtilities::CudaExpressionUtilities(CudaContext& context) : context
} }
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables, string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
const string& functionParams, const string& tempType) {
vector<pair<ExpressionTreeNode, string> > variableNodes; vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter) for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second)); variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, functionNames, prefix, functionParams, tempType); return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
} }
string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables, string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
const string& functionParams, const string& tempType) {
stringstream out; stringstream out;
vector<ParsedExpression> allExpressions; vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second); allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps = variables; vector<pair<ExpressionTreeNode, string> > temps = variables;
vector<vector<double> > functionParams = computeFunctionParameters(functions);
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) { for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType); processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n"; out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
...@@ -61,7 +60,7 @@ string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpres ...@@ -61,7 +60,7 @@ string CudaExpressionUtilities::createExpressions(const map<string, ParsedExpres
} }
void CudaExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps, void CudaExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& functionParams, const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) { const vector<ParsedExpression>& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++) for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node) if (temps[i].first == node)
...@@ -104,22 +103,26 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -104,22 +103,26 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
temps.push_back(make_pair(*nodes[j], name2)); temps.push_back(make_pair(*nodes[j], name2));
} }
out << "{\n"; out << "{\n";
vector<string> paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
paramsInt.push_back(context.intToString((int) functionParams[i][j]));
}
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) { if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n"; out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x-params.x)*params.z;\n"; out << "x = (x-" << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n"; out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) params.w);\n"; out << "index = min(index, (int) " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n"; out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n"; out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n"; out << "real a = 1.0f-b;\n";
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) if (derivOrder[0] == 0)
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n"; out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else else
out << nodeNames[j] << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n"; out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
} }
out << "}\n"; out << "}\n";
} }
...@@ -127,9 +130,8 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -127,9 +130,8 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) { if (derivOrder[0] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < params.x) {\n"; out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
out << "int index = (int) round(x);\n"; out << "int index = (int) round(x);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n"; out << "}\n";
...@@ -140,11 +142,10 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -140,11 +142,10 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) { if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n"; out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n"; out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int xsize = (int) params.x;\n"; out << "int xsize = (int) " << paramsInt[0] << ";\n";
out << "int ysize = (int) params.y;\n"; out << "int ysize = (int) " << paramsInt[1] << ";\n";
out << "int index = x+y*xsize;\n"; out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n"; out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
...@@ -155,13 +156,12 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express ...@@ -155,13 +156,12 @@ void CudaExpressionUtilities::processExpression(stringstream& out, const Express
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) { if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n"; out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n"; out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int z = (int) round(" << getTempName(node.getChildren()[2], temps) << ");\n"; out << "int z = (int) round(" << getTempName(node.getChildren()[2], temps) << ");\n";
out << "int xsize = (int) params.x;\n"; out << "int xsize = (int) " << paramsInt[0] << ";\n";
out << "int ysize = (int) params.y;\n"; out << "int ysize = (int) " << paramsInt[1] << ";\n";
out << "int zsize = (int) params.z;\n"; out << "int zsize = (int) " << paramsInt[2] << ";\n";
out << "int index = x+(y+z*ysize)*xsize;\n"; out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n"; out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
...@@ -452,35 +452,41 @@ vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const Tabulat ...@@ -452,35 +452,41 @@ vector<float> CudaExpressionUtilities::computeFunctionCoefficients(const Tabulat
throw OpenMMException("computeFunctionCoefficients: Unknown function type"); throw OpenMMException("computeFunctionCoefficients: Unknown function type");
} }
vector<float4> CudaExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) { vector<vector<double> > CudaExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
vector<float4> params(functions.size()); vector<vector<double> > params(functions.size());
for (int i = 0; i < (int) functions.size(); i++) { for (int i = 0; i < (int) functions.size(); i++) {
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) { if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(*functions[i]); const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(*functions[i]);
vector<double> values; vector<double> values;
double min, max; double min, max;
fn.getFunctionParameters(values, min, max); fn.getFunctionParameters(values, min, max);
params[i] = make_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); params[i].push_back(min);
params[i].push_back(max);
params[i].push_back((values.size()-1)/(max-min));
params[i].push_back(values.size()-2);
} }
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]); const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
vector<double> values; vector<double> values;
fn.getFunctionParameters(values); fn.getFunctionParameters(values);
params[i] = make_float4((float) values.size(), 0.0f, 0.0f, 0.0f); params[i].push_back(values.size());
} }
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(*functions[i]); const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(*functions[i]);
int xsize, ysize; int xsize, ysize;
vector<double> values; vector<double> values;
fn.getFunctionParameters(xsize, ysize, values); fn.getFunctionParameters(xsize, ysize, values);
params[i] = make_float4(xsize, ysize, 0.0f, 0.0f); params[i].push_back(xsize);
params[i].push_back(ysize);
} }
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(*functions[i]); const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(*functions[i]);
int xsize, ysize, zsize; int xsize, ysize, zsize;
vector<double> values; vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values); fn.getFunctionParameters(xsize, ysize, zsize, values);
params[i] = make_float4(xsize, ysize, zsize, 0.0f); params[i].push_back(xsize);
params[i].push_back(ysize);
params[i].push_back(zsize);
} }
else else
throw OpenMMException("computeFunctionParameters: Unknown function type"); throw OpenMMException("computeFunctionParameters: Unknown function type");
......
This diff is collapsed.
...@@ -56,12 +56,11 @@ public: ...@@ -56,12 +56,11 @@ public:
* @param functions the tabulated functions that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions * @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "real") * @param tempType the type of value to use for temporary variables (defaults to "real")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="real"); const std::string& prefix, const std::string& tempType="real");
/** /**
* Generate the source code for calculating a set of expressions. * Generate the source code for calculating a set of expressions.
* *
...@@ -71,12 +70,11 @@ public: ...@@ -71,12 +70,11 @@ public:
* @param functions the tabulated functions that may appear in the expressions * @param functions the tabulated functions that may appear in the expressions
* @param functionNames defines the variable name for each tabulated function that may appear in the expressions * @param functionNames defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables * @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
* @param tempType the type of value to use for temporary variables (defaults to "float") * @param tempType the type of value to use for temporary variables (defaults to "float")
*/ */
std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables, std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variables,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::string& tempType="float"); const std::string& prefix, const std::string& tempType="float");
/** /**
* Calculate the spline coefficients for a tabulated function that appears in expressions. * Calculate the spline coefficients for a tabulated function that appears in expressions.
* *
...@@ -85,13 +83,6 @@ public: ...@@ -85,13 +83,6 @@ public:
* @return the spline coefficients * @return the spline coefficients
*/ */
std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width); std::vector<float> computeFunctionCoefficients(const TabulatedFunction& function, int& width);
/**
* Given the list of TabulatedFunctions used by a Force, create the parameter array describing them.
*
* @param functions the list of functions to include in the array
* @return the parameter array
*/
std::vector<mm_float4> computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
/** /**
* Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions. * Get a Lepton::CustomFunction that can be used to represent a TabulatedFunction when parsing expressions.
* *
...@@ -121,12 +112,13 @@ private: ...@@ -121,12 +112,13 @@ private:
void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node, void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps,
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::string& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType); const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::vector<const Lepton::ExpressionTreeNode*>& nodes); std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers); std::map<int, const Lepton::ExpressionTreeNode*>& powers);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
OpenCLContext& context; OpenCLContext& context;
FunctionPlaceholder fp1, fp2, fp3; FunctionPlaceholder fp1, fp2, fp3;
}; };
......
...@@ -639,7 +639,7 @@ private: ...@@ -639,7 +639,7 @@ private:
class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel { class OpenCLCalcCustomNonbondedForceKernel : public CalcCustomNonbondedForceKernel {
public: public:
OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomNonbondedForceKernel(name, platform), OpenCLCalcCustomNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomNonbondedForceKernel(name, platform),
cl(cl), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) { cl(cl), params(NULL), globals(NULL), interactionGroupData(NULL), forceCopy(NULL), system(system), hasInitializedKernel(false) {
} }
~OpenCLCalcCustomNonbondedForceKernel(); ~OpenCLCalcCustomNonbondedForceKernel();
/** /**
...@@ -670,7 +670,6 @@ private: ...@@ -670,7 +670,6 @@ private:
OpenCLContext& cl; OpenCLContext& cl;
OpenCLParameterSet* params; OpenCLParameterSet* params;
OpenCLArray* globals; OpenCLArray* globals;
OpenCLArray* tabulatedFunctionParams;
OpenCLArray* interactionGroupData; OpenCLArray* interactionGroupData;
cl::Kernel interactionGroupKernel; cl::Kernel interactionGroupKernel;
std::vector<void*> interactionGroupArgs; std::vector<void*> interactionGroupArgs;
...@@ -742,7 +741,7 @@ class OpenCLCalcCustomGBForceKernel : public CalcCustomGBForceKernel { ...@@ -742,7 +741,7 @@ class OpenCLCalcCustomGBForceKernel : public CalcCustomGBForceKernel {
public: public:
OpenCLCalcCustomGBForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomGBForceKernel(name, platform), OpenCLCalcCustomGBForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomGBForceKernel(name, platform),
hasInitializedKernels(false), cl(cl), params(NULL), computedValues(NULL), energyDerivs(NULL), energyDerivChain(NULL), longEnergyDerivs(NULL), globals(NULL), hasInitializedKernels(false), cl(cl), params(NULL), computedValues(NULL), energyDerivs(NULL), energyDerivChain(NULL), longEnergyDerivs(NULL), globals(NULL),
valueBuffers(NULL), longValueBuffers(NULL), tabulatedFunctionParams(NULL), system(system) { valueBuffers(NULL), longValueBuffers(NULL), system(system) {
} }
~OpenCLCalcCustomGBForceKernel(); ~OpenCLCalcCustomGBForceKernel();
/** /**
...@@ -780,7 +779,6 @@ private: ...@@ -780,7 +779,6 @@ private:
OpenCLArray* globals; OpenCLArray* globals;
OpenCLArray* valueBuffers; OpenCLArray* valueBuffers;
OpenCLArray* longValueBuffers; OpenCLArray* longValueBuffers;
OpenCLArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues; std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray*> tabulatedFunctions; std::vector<OpenCLArray*> tabulatedFunctions;
...@@ -841,8 +839,7 @@ class OpenCLCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel { ...@@ -841,8 +839,7 @@ class OpenCLCalcCustomHbondForceKernel : public CalcCustomHbondForceKernel {
public: public:
OpenCLCalcCustomHbondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomHbondForceKernel(name, platform), OpenCLCalcCustomHbondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomHbondForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), donorParams(NULL), acceptorParams(NULL), donors(NULL), acceptors(NULL), hasInitializedKernel(false), cl(cl), donorParams(NULL), acceptorParams(NULL), donors(NULL), acceptors(NULL),
donorBufferIndices(NULL), acceptorBufferIndices(NULL), globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), donorBufferIndices(NULL), acceptorBufferIndices(NULL), globals(NULL), donorExclusions(NULL), acceptorExclusions(NULL), system(system) {
tabulatedFunctionParams(NULL), system(system) {
} }
~OpenCLCalcCustomHbondForceKernel(); ~OpenCLCalcCustomHbondForceKernel();
/** /**
...@@ -881,7 +878,6 @@ private: ...@@ -881,7 +878,6 @@ private:
OpenCLArray* acceptorBufferIndices; OpenCLArray* acceptorBufferIndices;
OpenCLArray* donorExclusions; OpenCLArray* donorExclusions;
OpenCLArray* acceptorExclusions; OpenCLArray* acceptorExclusions;
OpenCLArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues; std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray*> tabulatedFunctions; std::vector<OpenCLArray*> tabulatedFunctions;
...@@ -895,7 +891,7 @@ private: ...@@ -895,7 +891,7 @@ private:
class OpenCLCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondForceKernel { class OpenCLCalcCustomCompoundBondForceKernel : public CalcCustomCompoundBondForceKernel {
public: public:
OpenCLCalcCustomCompoundBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomCompoundBondForceKernel(name, platform), OpenCLCalcCustomCompoundBondForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcCustomCompoundBondForceKernel(name, platform),
cl(cl), params(NULL), globals(NULL), tabulatedFunctionParams(NULL), system(system) { cl(cl), params(NULL), globals(NULL), system(system) {
} }
~OpenCLCalcCustomCompoundBondForceKernel(); ~OpenCLCalcCustomCompoundBondForceKernel();
/** /**
...@@ -927,7 +923,6 @@ private: ...@@ -927,7 +923,6 @@ private:
OpenCLContext& cl; OpenCLContext& cl;
OpenCLParameterSet* params; OpenCLParameterSet* params;
OpenCLArray* globals; OpenCLArray* globals;
OpenCLArray* tabulatedFunctionParams;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<cl_float> globalParamValues; std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray*> tabulatedFunctions; std::vector<OpenCLArray*> tabulatedFunctions;
......
...@@ -37,22 +37,21 @@ OpenCLExpressionUtilities::OpenCLExpressionUtilities(OpenCLContext& context) : c ...@@ -37,22 +37,21 @@ OpenCLExpressionUtilities::OpenCLExpressionUtilities(OpenCLContext& context) : c
} }
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables, string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
const string& functionParams, const string& tempType) {
vector<pair<ExpressionTreeNode, string> > variableNodes; vector<pair<ExpressionTreeNode, string> > variableNodes;
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter) for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second)); variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return createExpressions(expressions, variableNodes, functions, functionNames, prefix, functionParams, tempType); return createExpressions(expressions, variableNodes, functions, functionNames, prefix, tempType);
} }
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables, string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const vector<pair<ExpressionTreeNode, string> >& variables,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& tempType) {
const string& functionParams, const string& tempType) {
stringstream out; stringstream out;
vector<ParsedExpression> allExpressions; vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second); allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps = variables; vector<pair<ExpressionTreeNode, string> > temps = variables;
vector<vector<double> > functionParams = computeFunctionParameters(functions);
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) { for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType); processExpression(out, iter->second.getRootNode(), temps, functions, functionNames, prefix, functionParams, allExpressions, tempType);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n"; out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
...@@ -61,7 +60,7 @@ string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpr ...@@ -61,7 +60,7 @@ string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpr
} }
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps, void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const string& functionParams, const vector<const TabulatedFunction*>& functions, const vector<pair<string, string> >& functionNames, const string& prefix, const vector<vector<double> >& functionParams,
const vector<ParsedExpression>& allExpressions, const string& tempType) { const vector<ParsedExpression>& allExpressions, const string& tempType) {
for (int i = 0; i < (int) temps.size(); i++) for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node) if (temps[i].first == node)
...@@ -104,22 +103,26 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -104,22 +103,26 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
temps.push_back(make_pair(*nodes[j], name2)); temps.push_back(make_pair(*nodes[j], name2));
} }
out << "{\n"; out << "{\n";
vector<string> paramsFloat, paramsInt;
for (int j = 0; j < (int) functionParams[i].size(); j++) {
paramsFloat.push_back(context.doubleToString(functionParams[i][j]));
paramsInt.push_back(context.intToString((int) functionParams[i][j]));
}
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) { if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= params.x && x <= params.y) {\n"; out << "if (x >= " << paramsFloat[0] << " && x <= " << paramsFloat[1] << ") {\n";
out << "x = (x-params.x)*params.z;\n"; out << "x = (x-" << paramsFloat[0] << ")*" << paramsFloat[2] << ";\n";
out << "int index = (int) (floor(x));\n"; out << "int index = (int) (floor(x));\n";
out << "index = min(index, (int) params.w);\n"; out << "index = min(index, " << paramsInt[3] << ");\n";
out << "float4 coeff = " << functionNames[i].second << "[index];\n"; out << "float4 coeff = " << functionNames[i].second << "[index];\n";
out << "real b = x-index;\n"; out << "real b = x-index;\n";
out << "real a = 1.0f-b;\n"; out << "real a = 1.0f-b;\n";
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) if (derivOrder[0] == 0)
out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(params.z*params.z);\n"; out << nodeNames[j] << " = a*coeff.x+b*coeff.y+((a*a*a-a)*coeff.z+(b*b*b-b)*coeff.w)/(" << paramsFloat[2] << "*" << paramsFloat[2] << ");\n";
else else
out << nodeNames[j] << " = (coeff.y-coeff.x)*params.z+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/params.z;\n"; out << nodeNames[j] << " = (coeff.y-coeff.x)*" << paramsFloat[2] << "+((1.0f-3.0f*a*a)*coeff.z+(3.0f*b*b-1.0f)*coeff.w)/" << paramsFloat[2] << ";\n";
} }
out << "}\n"; out << "}\n";
} }
...@@ -127,9 +130,8 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -127,9 +130,8 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0) { if (derivOrder[0] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n"; out << "real x = " << getTempName(node.getChildren()[0], temps) << ";\n";
out << "if (x >= 0 && x < params.x) {\n"; out << "if (x >= 0 && x < " << paramsInt[0] << ") {\n";
out << "int index = (int) round(x);\n"; out << "int index = (int) round(x);\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
out << "}\n"; out << "}\n";
...@@ -140,11 +142,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -140,11 +142,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0) { if (derivOrder[0] == 0 && derivOrder[1] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n"; out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n"; out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int xsize = (int) params.x;\n"; out << "int xsize = " << paramsInt[0] << ";\n";
out << "int ysize = (int) params.y;\n"; out << "int ysize = " << paramsInt[1] << ";\n";
out << "int index = x+y*xsize;\n"; out << "int index = x+y*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize)\n"; out << "if (index >= 0 && index < xsize*ysize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
...@@ -155,13 +156,12 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -155,13 +156,12 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
for (int j = 0; j < nodes.size(); j++) { for (int j = 0; j < nodes.size(); j++) {
const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder(); const vector<int>& derivOrder = dynamic_cast<const Operation::Custom*>(&nodes[j]->getOperation())->getDerivOrder();
if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) { if (derivOrder[0] == 0 && derivOrder[1] == 0 && derivOrder[2] == 0) {
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n"; out << "int x = (int) round(" << getTempName(node.getChildren()[0], temps) << ");\n";
out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n"; out << "int y = (int) round(" << getTempName(node.getChildren()[1], temps) << ");\n";
out << "int z = (int) round(" << getTempName(node.getChildren()[2], temps) << ");\n"; out << "int z = (int) round(" << getTempName(node.getChildren()[2], temps) << ");\n";
out << "int xsize = (int) params.x;\n"; out << "int xsize = " << paramsInt[0] << ";\n";
out << "int ysize = (int) params.y;\n"; out << "int ysize = " << paramsInt[1] << ";\n";
out << "int zsize = (int) params.z;\n"; out << "int zsize = " << paramsInt[2] << ";\n";
out << "int index = x+(y+z*ysize)*xsize;\n"; out << "int index = x+(y+z*ysize)*xsize;\n";
out << "if (index >= 0 && index < xsize*ysize*zsize)\n"; out << "if (index >= 0 && index < xsize*ysize*zsize)\n";
out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n"; out << nodeNames[j] << " = " << functionNames[i].second << "[index];\n";
...@@ -452,35 +452,41 @@ vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const Tabul ...@@ -452,35 +452,41 @@ vector<float> OpenCLExpressionUtilities::computeFunctionCoefficients(const Tabul
throw OpenMMException("computeFunctionCoefficients: Unknown function type"); throw OpenMMException("computeFunctionCoefficients: Unknown function type");
} }
vector<mm_float4> OpenCLExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) { vector<vector<double> > OpenCLExpressionUtilities::computeFunctionParameters(const vector<const TabulatedFunction*>& functions) {
vector<mm_float4> params(functions.size()); vector<vector<double> > params(functions.size());
for (int i = 0; i < (int) functions.size(); i++) { for (int i = 0; i < (int) functions.size(); i++) {
if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) { if (dynamic_cast<const Continuous1DFunction*>(functions[i]) != NULL) {
const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(*functions[i]); const Continuous1DFunction& fn = dynamic_cast<const Continuous1DFunction&>(*functions[i]);
vector<double> values; vector<double> values;
double min, max; double min, max;
fn.getFunctionParameters(values, min, max); fn.getFunctionParameters(values, min, max);
params[i] = mm_float4((float) min, (float) max, (float) ((values.size()-1)/(max-min)), (float) values.size()-2); params[i].push_back(min);
params[i].push_back(max);
params[i].push_back((values.size()-1)/(max-min));
params[i].push_back(values.size()-2);
} }
else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Discrete1DFunction*>(functions[i]) != NULL) {
const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]); const Discrete1DFunction& fn = dynamic_cast<const Discrete1DFunction&>(*functions[i]);
vector<double> values; vector<double> values;
fn.getFunctionParameters(values); fn.getFunctionParameters(values);
params[i] = mm_float4((float) values.size(), 0.0f, 0.0f, 0.0f); params[i].push_back(values.size());
} }
else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Discrete2DFunction*>(functions[i]) != NULL) {
const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(*functions[i]); const Discrete2DFunction& fn = dynamic_cast<const Discrete2DFunction&>(*functions[i]);
int xsize, ysize; int xsize, ysize;
vector<double> values; vector<double> values;
fn.getFunctionParameters(xsize, ysize, values); fn.getFunctionParameters(xsize, ysize, values);
params[i] = mm_float4(xsize, ysize, 0.0f, 0.0f); params[i].push_back(xsize);
params[i].push_back(ysize);
} }
else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) { else if (dynamic_cast<const Discrete3DFunction*>(functions[i]) != NULL) {
const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(*functions[i]); const Discrete3DFunction& fn = dynamic_cast<const Discrete3DFunction&>(*functions[i]);
int xsize, ysize, zsize; int xsize, ysize, zsize;
vector<double> values; vector<double> values;
fn.getFunctionParameters(xsize, ysize, zsize, values); fn.getFunctionParameters(xsize, ysize, zsize, values);
params[i] = mm_float4(xsize, ysize, zsize, 0.0f); params[i].push_back(xsize);
params[i].push_back(ysize);
params[i].push_back(zsize);
} }
else else
throw OpenMMException("computeFunctionParameters: Unknown function type"); throw OpenMMException("computeFunctionParameters: Unknown function type");
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment