Commit 1e52c490 authored by Peter Eastman
Browse files

Optimization to evaluating integral powers in custom expressions

parent d0ce27f1
...@@ -222,20 +222,48 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre ...@@ -222,20 +222,48 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << "1.0f"; out << "1.0f";
else if (exponent == (int) exponent) { else if (exponent == (int) exponent) {
out << "0.0f;\n"; out << "0.0f;\n";
temps.push_back(make_pair(node, name));
hasRecordedNode = true;
// If multiple integral powers of the same base are needed, it's faster to calculate all of them
// at once, so check to see if others are also needed.
map<int, const ExpressionTreeNode*> powers;
powers[(int) exponent] = &node;
for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedPowers(node, allExpressions[j].getRootNode(), powers);
vector<int> exponents;
vector<string> names;
vector<bool> hasAssigned(powers.size(), false);
exponents.push_back((int) fabs(exponent));
names.push_back(name);
for (map<int, const ExpressionTreeNode*>::const_iterator iter = powers.begin(); iter != powers.end(); ++iter) {
if (iter->first != exponent) {
exponents.push_back(abs(iter->first));
string name2 = prefix+intToString(temps.size());
names.push_back(name2);
temps.push_back(make_pair(*iter->second, name2));
out << "float " << name2 << " = 0.0f;\n";
}
}
out << "{\n"; out << "{\n";
out << "float multiplier = " << (exponent < 0.0 ? "1.0f/" : "") << getTempName(node.getChildren()[0], temps) << ";\n"; out << "float multiplier = " << (exponent < 0.0 ? "1.0f/" : "") << getTempName(node.getChildren()[0], temps) << ";\n";
int exp = (int) fabs(exponent); bool done = false;
bool hasAssigned = false; while (!done) {
while (exp != 0) { done = true;
if (exp%2 == 1) { for (int i = 0; i < exponents.size(); i++) {
if (!hasAssigned) if (exponents[i]%2 == 1) {
out << name << " = multiplier;\n"; if (!hasAssigned[i])
out << names[i] << " = multiplier;\n";
else else
out << name << " *= multiplier;\n"; out << names[i] << " *= multiplier;\n";
hasAssigned = true; hasAssigned[i] = true;
}
exponents[i] >>= 1;
if (exponents[i] != 0)
done = false;
} }
exp >>= 1; if (!done)
if (exp != 0)
out << "multiplier *= multiplier;\n"; out << "multiplier *= multiplier;\n";
} }
out << "}"; out << "}";
...@@ -273,3 +301,17 @@ void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTr ...@@ -273,3 +301,17 @@ void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTr
for (int i = 0; i < (int) searchNode.getChildren().size(); i++) for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], valueNode, derivNode); findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], valueNode, derivNode);
} }
/**
 * Recursively search an expression tree for POWER_CONSTANT nodes that raise the same base as
 * "node" to a nonzero integral power, and record them so the caller can evaluate all of those
 * powers together with a single repeated-squaring sequence.
 *
 * @param node        the POWER_CONSTANT node whose base (its first child) is being matched
 * @param searchNode  the root of the (sub)tree to search
 * @param powers      map from integer exponent to the node computing that power; matching nodes
 *                    are added to it.  Only exponents with the same sign as the entries already
 *                    present are added, since the caller uses one shared multiplier for all of them.
 */
void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
    if (searchNode.getOperation().getId() == Operation::POWER_CONSTANT && node.getChildren()[0] == searchNode.getChildren()[0]) {
        double realPower = dynamic_cast<const Operation::PowerConstant*>(&searchNode.getOperation())->getValue();
        int power = (int) realPower;
        if (power != realPower)
            return; // Only integral powers can be built by repeated squaring; do not truncate x^2.5 to x^2.
        if (power == 0)
            return; // A zero exponent has no set bits, so the caller's loop would never assign its temporary.
        if (powers.find(power) != powers.end())
            return; // This power is already in the map.
        // Widen before multiplying so that a pair of large exponents cannot overflow int.
        if (!powers.empty() && (long long) powers.begin()->first*power < 0)
            return; // All powers must have the same sign.
        powers[power] = &searchNode;
    }
    else
        for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
            findRelatedPowers(node, searchNode.getChildren()[i], powers);
}
...@@ -62,6 +62,8 @@ private: ...@@ -62,6 +62,8 @@ private:
static std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); static std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
static void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, static void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode); const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode);
/**
 * Search the expression tree rooted at searchNode for POWER_CONSTANT operations whose first
 * child equals the first child of node, and add each one found to powers, keyed by its
 * exponent as an int.  An exponent that is already present in the map, or whose sign differs
 * from the entries already in the map, is not added.
 *
 * @param node        the power node whose base is being matched
 * @param searchNode  the root of the tree to search (children are searched recursively)
 * @param powers      in/out map from exponent to the node that computes that power
 */
static void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
}; };
} // namespace OpenMM } // namespace OpenMM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment