Commit 1e52c490 authored by Peter Eastman's avatar Peter Eastman
Browse files

Optimization to evaluating integral powers in custom expressions

parent d0ce27f1
......@@ -222,20 +222,48 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out << "1.0f";
else if (exponent == (int) exponent) {
out << "0.0f;\n";
temps.push_back(make_pair(node, name));
hasRecordedNode = true;
// If multiple integral powers of the same base are needed, it's faster to calculate all of them
// at once, so check to see if others are also needed.
map<int, const ExpressionTreeNode*> powers;
powers[(int) exponent] = &node;
for (int j = 0; j < (int) allExpressions.size(); j++)
findRelatedPowers(node, allExpressions[j].getRootNode(), powers);
vector<int> exponents;
vector<string> names;
vector<bool> hasAssigned(powers.size(), false);
exponents.push_back((int) fabs(exponent));
names.push_back(name);
for (map<int, const ExpressionTreeNode*>::const_iterator iter = powers.begin(); iter != powers.end(); ++iter) {
if (iter->first != exponent) {
exponents.push_back(abs(iter->first));
string name2 = prefix+intToString(temps.size());
names.push_back(name2);
temps.push_back(make_pair(*iter->second, name2));
out << "float " << name2 << " = 0.0f;\n";
}
}
out << "{\n";
out << "float multiplier = " << (exponent < 0.0 ? "1.0f/" : "") << getTempName(node.getChildren()[0], temps) << ";\n";
int exp = (int) fabs(exponent);
bool hasAssigned = false;
while (exp != 0) {
if (exp%2 == 1) {
if (!hasAssigned)
out << name << " = multiplier;\n";
else
out << name << " *= multiplier;\n";
hasAssigned = true;
bool done = false;
while (!done) {
done = true;
for (int i = 0; i < exponents.size(); i++) {
if (exponents[i]%2 == 1) {
if (!hasAssigned[i])
out << names[i] << " = multiplier;\n";
else
out << names[i] << " *= multiplier;\n";
hasAssigned[i] = true;
}
exponents[i] >>= 1;
if (exponents[i] != 0)
done = false;
}
exp >>= 1;
if (exp != 0)
if (!done)
out << "multiplier *= multiplier;\n";
}
out << "}";
......@@ -273,3 +301,17 @@ void OpenCLExpressionUtilities::findRelatedTabulatedFunctions(const ExpressionTr
for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedTabulatedFunctions(node, searchNode.getChildren()[i], valueNode, derivNode);
}
void OpenCLExpressionUtilities::findRelatedPowers(const ExpressionTreeNode& node, const ExpressionTreeNode& searchNode, map<int, const ExpressionTreeNode*>& powers) {
if (searchNode.getOperation().getId() == Operation::POWER_CONSTANT && node.getChildren()[0] == searchNode.getChildren()[0]) {
int power = dynamic_cast<const Operation::PowerConstant*>(&searchNode.getOperation())->getValue();
if (powers.find(power) != powers.end())
return; // This power is already in the map.
if (powers.begin()->first*power < 0)
return; // All powers must have the same sign.
powers[power] = &searchNode;
}
else
for (int i = 0; i < (int) searchNode.getChildren().size(); i++)
findRelatedPowers(node, searchNode.getChildren()[i], powers);
}
......@@ -62,6 +62,8 @@ private:
static std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
static void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
const Lepton::ExpressionTreeNode*& valueNode, const Lepton::ExpressionTreeNode*& derivNode);
static void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers);
};
} // namespace OpenMM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment