"wrappers/vscode:/vscode.git/clone" did not exist on "eb9f735ae9d1f4bf62f574dbfb7433ef83ae6197"
Commit d0ce27f1 authored by Peter Eastman's avatar Peter Eastman
Browse files

Optimization to evaluating tabulated functions

parent e1a7bc3d
......@@ -48,22 +48,27 @@ static string intToString(int value) {
string OpenCLExpressionUtilities::createExpressions(const map<string, ParsedExpression>& expressions, const map<string, string>& variables,
const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
stringstream out;
vector<ParsedExpression> allExpressions;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter)
allExpressions.push_back(iter->second);
vector<pair<ExpressionTreeNode, string> > temps;
for (map<string, ParsedExpression>::const_iterator iter = expressions.begin(); iter != expressions.end(); ++iter) {
processExpression(out, iter->second.getRootNode(), temps, variables, functions, prefix, functionParams);
processExpression(out, iter->second.getRootNode(), temps, variables, functions, prefix, functionParams, allExpressions);
out << iter->first << getTempName(iter->second.getRootNode(), temps) << ";\n";
}
return out.str();
}
void OpenCLExpressionUtilities::processExpression(stringstream& out, const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& temps,
const map<string, string>& variables, const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams) {
const map<string, string>& variables, const vector<pair<string, string> >& functions, const string& prefix, const string& functionParams,
const vector<ParsedExpression>& allExpressions) {
for (int i = 0; i < (int) temps.size(); i++)
if (temps[i].first == node)
return;
for (int i = 0; i < (int) node.getChildren().size(); i++)
processExpression(out, node.getChildren()[i], temps, variables, functions, prefix, functionParams);
processExpression(out, node.getChildren()[i], temps, variables, functions, prefix, functionParams, allExpressions);
string name = prefix+intToString(temps.size());
bool hasRecordedNode = false;
out << "float " << name << " = ";
switch (node.getOperation().getId()) {
......@@ -85,7 +90,32 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
;
if (i == functions.size())
throw OpenMMException("Unknown function in expression: "+node.getOperation().getName());
bool isDeriv = (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 1);
out << "0.0f;\n";
temps.push_back(make_pair(node, name));
hasRecordedNode = true;
// If both the value and derivative of the function are needed, it's faster to calculate them both
// at once, so check to see if both are needed.
const ExpressionTreeNode* valueNode = NULL;
const ExpressionTreeNode* derivNode = NULL;
for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedTabulatedFunctions(node, allExpressions[j].getRootNode(), valueNode, derivNode);
string valueName = name;
string derivName = name;
if (valueNode != NULL && derivNode != NULL) {
string name2 = prefix+intToString(temps.size());
out << "float " << name2 << " = 0.0f;\n";
if (isDeriv) {
valueName = name2;
temps.push_back(make_pair(*valueNode, name2));
}
else {
derivName = name2;
temps.push_back(make_pair(*derivNode, name2));
}
}
out << "{\n";
out << "float4 params = " << functionParams << "[" << i << "];\n";
out << "float x = " << getTempName(node.getChildren()[0], temps) << ";\n";
......@@ -93,10 +123,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << "int index = (int) (floor((x-params.x)*params.z));\n";
out << "float4 coeff = " << functions[i].second << "[index];\n";
out << "x = (x-params.x)*params.z-index;\n";
if (dynamic_cast<const Operation::Custom*>(&node.getOperation())->getDerivOrder()[0] == 0)
out << name << " = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));\n";
else
out << name << " = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;\n";
if (valueNode != NULL)
out << valueName << " = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));\n";
if (derivNode != NULL)
out << derivName << " = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;\n";
out << "}\n";
out << "}";
break;
......@@ -218,6 +248,7 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
throw OpenMMException("Internal error: Unknown operation in user-defined expression: "+node.getOperation().getName());
}
out << ";\n";
if (!hasRecordedNode)
temps.push_back(make_pair(node, name));
}
......@@ -229,3 +260,16 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
out << "Internal error: No temporary variable for expression node: " << node;
throw OpenMMException(out.str());
}
void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode,
const ExpressionTreeNode*& valueNode, const ExpressionTreeNode*& derivNode) {
if (searchNode.getOperation().getId() == Operation::CUSTOM && node.getChildren()[0] == searchNode.getChildren()[0]) {
if (dynamic_cast<const Operation::Custom*>(&searchNode.getOperation())->getDerivOrder()[0] == 0)
valueNode = &searchNode;
else
derivNode = &searchNode;
}
else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], valueNode, derivNode);
}
......@@ -43,13 +43,25 @@ namespace OpenMM {
class OpenCLExpressionUtilities {
public:
/**
* Generate the source code for calculating a set of expressions.
*
* @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable that may appear in the expressions
* @param functions defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
*/
static std::string createExpressions(const std::map<std::string, Lepton::ParsedExpression>& expressions, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
private:
static void processExpression(std::stringstream& out, const Lepton::ExpressionTreeNode& node,
std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps, const std::map<std::string, std::string>& variables,
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams);
const std::vector<std::pair<std::string, std::string> >& functions, const std::string& prefix, const std::string& functionParams,
const std::vector<Lepton::ParsedExpression>& allExpressions);
static std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
static void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode);
};
} // namespace OpenMM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment