Commit 803b8678 authored by Peter Eastman's avatar Peter Eastman
Browse files

Optimizations to expression processing in CustomGBForce

parent f11a6aea
......@@ -102,6 +102,13 @@ public:
virtual bool isInfixOperator() const {
return false;
}
/**
* Get whether this is a symmetric binary operation, such that exchanging its arguments
* does not affect the result.
*/
virtual bool isSymmetric() const {
return false;
}
virtual bool operator!=(const Operation& op) const {
return op.getId() != getId();
}
......@@ -273,6 +280,9 @@ public:
bool isInfixOperator() const {
return true;
}
bool isSymmetric() const {
return true;
}
};
class Operation::Subtract : public Operation {
......@@ -323,6 +333,9 @@ public:
bool isInfixOperator() const {
return true;
}
bool isSymmetric() const {
return true;
}
};
class Operation::Divide : public Operation {
......@@ -910,6 +923,9 @@ public:
const PowerConstant* o = dynamic_cast<const PowerConstant*>(&op);
return (o == NULL || o->value != value);
}
bool isInfixOperator() const {
return true;
}
private:
double value;
};
......
......@@ -97,6 +97,14 @@ public:
* Create an ExpressionProgram that represents the same calculation as this expression.
*/
ExpressionProgram createProgram() const;
/**
* Create a new ParsedExpression which is identical to this one, except that the names of some
* variables have been changed.
*
* @param replacements a map whose keys are the names of variables, and whose values are the
* new names to replace them with
*/
ParsedExpression renameVariables(const std::map<std::string, std::string>& replacements) const;
private:
static double evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
......@@ -104,6 +112,7 @@ private:
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable);
static double getConstantValue(const ExpressionTreeNode& node);
static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements);
ExpressionTreeNode rootNode;
};
......
......@@ -73,6 +73,13 @@ ExpressionTreeNode::~ExpressionTreeNode() {
bool ExpressionTreeNode::operator!=(const ExpressionTreeNode& node) const {
if (node.getOperation() != getOperation())
return true;
if (getOperation().isSymmetric() && getChildren().size() == 2) {
if (getChildren()[0] == node.getChildren()[0] && getChildren()[1] == node.getChildren()[1])
return false;
if (getChildren()[0] == node.getChildren()[1] && getChildren()[1] == node.getChildren()[0])
return false;
return true;
}
for (int i = 0; i < (int) getChildren().size(); i++)
if (getChildren()[i] != node.getChildren()[i])
return true;
......
......@@ -54,7 +54,7 @@ double ParsedExpression::evaluate() const {
return evaluate(getRootNode(), map<string, double>());
}
double ParsedExpression::evaluate(const std::map<std::string, double>& variables) const {
double ParsedExpression::evaluate(const map<string, double>& variables) const {
return evaluate(getRootNode(), variables);
}
......@@ -133,10 +133,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
return ExpressionTreeNode(new Operation::AddConstant(first), children[1]);
if (second == second) // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]);
if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b
return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a
return ExpressionTreeNode(new Operation::Subtract(), children[1], children[0].getChildren()[0]);
break;
}
case Operation::SUBTRACT:
{
if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0
double first = getConstantValue(children[0]);
if (first == 0.0) // Subtract from 0
return ExpressionTreeNode(new Operation::Negate(), children[1]);
......@@ -145,6 +151,8 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
return children[0];
if (second == second) // Subtract a constant
return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]);
if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b
return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]);
break;
}
case Operation::MULTIPLY:
......@@ -177,10 +185,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0].getChildren()[0], children[1]));
if (children[1].getOperation().getId() == Operation::NEGATE) // Pull the negation out so it can possibly be optimized further
return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0], children[1].getChildren()[0]));
if (children[1].getOperation().getId() == Operation::RECIPROCAL) // a*(1/b) = a/b
return ExpressionTreeNode(new Operation::Divide(), children[0], children[1].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::RECIPROCAL) // (1/a)*b = b/a
return ExpressionTreeNode(new Operation::Divide(), children[1], children[0].getChildren()[0]);
break;
}
case Operation::DIVIDE:
{
if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything from itself is 0
double numerator = getConstantValue(children[0]);
if (numerator == 0.0) // 0 divided by something
return ExpressionTreeNode(new Operation::Constant(0.0));
......@@ -202,6 +216,8 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1]));
if (children[1].getOperation().getId() == Operation::NEGATE) // Pull the negation out so it can possibly be optimized further
return ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Divide(), children[0], children[1].getChildren()[0]));
if (children[1].getOperation().getId() == Operation::RECIPROCAL) // a/(1/b) = a*b
return ExpressionTreeNode(new Operation::Multiply(), children[0], children[1].getChildren()[0]);
break;
}
case Operation::POWER:
......@@ -251,11 +267,11 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
return ExpressionTreeNode(node.getOperation().clone(), children);
}
ParsedExpression ParsedExpression::differentiate(const std::string& variable) const {
ParsedExpression ParsedExpression::differentiate(const string& variable) const {
return differentiate(getRootNode(), variable);
}
ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const std::string& variable) {
ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable) {
vector<ExpressionTreeNode> childDerivs(node.getChildren().size());
for (int i = 0; i < (int) childDerivs.size(); i++)
childDerivs[i] = differentiate(node.getChildren()[i], variable);
......@@ -272,10 +288,29 @@ ExpressionProgram ParsedExpression::createProgram() const {
return ExpressionProgram(*this);
}
ParsedExpression ParsedExpression::renameVariables(const map<string, string>& replacements) const {
return ParsedExpression(renameNodeVariables(getRootNode(), replacements));
}
ExpressionTreeNode ParsedExpression::renameNodeVariables(const ExpressionTreeNode& node, const map<string, string>& replacements) {
if (node.getOperation().getId() == Operation::VARIABLE) {
map<string, string>::const_iterator replace = replacements.find(node.getOperation().getName());
if (replace != replacements.end())
return ExpressionTreeNode(new Operation::Variable(replace->second));
}
vector<ExpressionTreeNode> children;
for (int i = 0; i < (int) node.getChildren().size(); i++)
children.push_back(renameNodeVariables(node.getChildren()[i], replacements));
return ExpressionTreeNode(node.getOperation().clone(), children);
}
ostream& Lepton::operator<<(ostream& out, const ExpressionTreeNode& node) {
if (node.getOperation().isInfixOperator() && node.getChildren().size() == 2) {
out << "(" << node.getChildren()[0] << ")" << node.getOperation().getName() << "(" << node.getChildren()[1] << ")";
}
else if (node.getOperation().isInfixOperator() && node.getChildren().size() == 1) {
out << "(" << node.getChildren()[0] << ")" << node.getOperation().getName();
}
else {
out << node.getOperation().getName();
if (node.getChildren().size() > 0) {
......
......@@ -1356,30 +1356,27 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
{
// Create the N2 value kernel.
map<string, string> variables1;
map<string, string> variables2;
variables1["r"] = "r";
variables2["r"] = "r";
map<string, string> variables;
map<string, string> rename;
variables["r"] = "r";
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
const string& name = force.getPerParticleParameterName(i);
variables1[name+"1"] = "params"+params->getParameterSuffix(i, "1");
variables1[name+"2"] = "params"+params->getParameterSuffix(i, "2");
variables2[name+"2"] = "params"+params->getParameterSuffix(i, "1");
variables2[name+"1"] = "params"+params->getParameterSuffix(i, "2");
variables[name+"1"] = "params"+params->getParameterSuffix(i, "1");
variables[name+"2"] = "params"+params->getParameterSuffix(i, "2");
rename[name+"1"] = name+"2";
rename[name+"2"] = name+"1";
}
for (int i = 0; i < force.getNumGlobalParameters(); i++) {
const string& name = force.getGlobalParameterName(i);
string value = "globals["+intToString(i)+"]";
variables1[name] = value;
variables2[name] = value;
variables[name] = value;
}
map<string, Lepton::ParsedExpression> n2ValueExpressions;
stringstream n2ValueSource;
n2ValueExpressions["tempValue1 = "] = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueSource << OpenCLExpressionUtilities::createExpressions(n2ValueExpressions, variables1, functionDefinitions, "tempA", prefix+"functionParams");
n2ValueExpressions.clear();
n2ValueExpressions["tempValue2 = "] = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueSource << OpenCLExpressionUtilities::createExpressions(n2ValueExpressions, variables2, functionDefinitions, "tempB", prefix+"functionParams");
Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
n2ValueExpressions["tempValue1 = "] = ex;
n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
n2ValueSource << OpenCLExpressionUtilities::createExpressions(n2ValueExpressions, variables, functionDefinitions, "temp", prefix+"functionParams");
map<string, string> replacements;
replacements["COMPUTE_VALUE"] = n2ValueSource.str();
stringstream extraArgs, loadLocal1, loadLocal2, load1, load2;
......@@ -1627,48 +1624,48 @@ void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const Custo
string value = "globals["+intToString(i)+"]";
globalVariables[name] = prefix+value;
}
map<string, string> variables1 = globalVariables;
map<string, string> variables2 = globalVariables;
variables1["r"] = "r";
variables2["r"] = "r";
map<string, string> variables = globalVariables;
map<string, string> rename;
variables["r"] = "r";
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
const string& name = force.getPerParticleParameterName(i);
variables1[name+"1"] = prefix+"params"+params->getParameterSuffix(i, "1");
variables1[name+"2"] = prefix+"params"+params->getParameterSuffix(i, "2");
variables2[name+"2"] = prefix+"params"+params->getParameterSuffix(i, "1");
variables2[name+"1"] = prefix+"params"+params->getParameterSuffix(i, "2");
variables[name+"1"] = prefix+"params"+params->getParameterSuffix(i, "1");
variables[name+"2"] = prefix+"params"+params->getParameterSuffix(i, "2");
rename[name+"1"] = name+"2";
rename[name+"2"] = name+"1";
}
map<string, Lepton::ParsedExpression> derivExpressions;
stringstream chainSource;
Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
derivExpressions["float dVdR1 = "] = dVdR;
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables1, functionDefinitions, prefix+"tempA0_", prefix+"functionParams");
derivExpressions.clear();
derivExpressions["float dVdR2 = "] = dVdR;
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables2, functionDefinitions, prefix+"tempB0_", prefix+"functionParams");
derivExpressions["float dVdR2 = "] = dVdR.renameVariables(rename);
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp0_", prefix+"functionParams");
chainSource << "tempForce -= dVdR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "1") << ";\n";
chainSource << "tempForce -= dVdR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "2") << ";\n";
variables1 = globalVariables;
variables2 = globalVariables;
variables = globalVariables;
map<string, string> rename1;
map<string, string> rename2;
for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
const string& name = force.getPerParticleParameterName(i);
variables1[name] = prefix+"params"+params->getParameterSuffix(i, "1");
variables2[name] = prefix+"params"+params->getParameterSuffix(i, "2");
variables[name+"1"] = prefix+"params"+params->getParameterSuffix(i, "1");
variables[name+"2"] = prefix+"params"+params->getParameterSuffix(i, "2");
rename1[name] = name+"1";
rename2[name] = name+"2";
}
for (int i = 0; i < force.getNumComputedValues(); i++) {
const string& name = computedValueNames[i];
variables1[name] = prefix+"values"+computedValues->getParameterSuffix(i, "1");
variables2[name] = prefix+"values"+computedValues->getParameterSuffix(i, "2");
variables[name+"1"] = prefix+"values"+computedValues->getParameterSuffix(i, "1");
variables[name+"2"] = prefix+"values"+computedValues->getParameterSuffix(i, "2");
rename1[name] = name+"1";
rename2[name] = name+"2";
if (i == 0)
continue;
Lepton::ParsedExpression dVdV = Lepton::Parser::parse(computedValueExpressions[1], functions).differentiate(computedValueNames[i-1]).optimize();
string var = "dV"+intToString(i+1)+"dV"+intToString(i)+"_";
derivExpressions.clear();
derivExpressions["float "+var+"1 = "] = dVdV;
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables1, functionDefinitions, prefix+"tempA"+intToString(i)+"_", prefix+"functionParams");
derivExpressions.clear();
derivExpressions["float "+var+"2 = "] = dVdV;
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables2, functionDefinitions, prefix+"tempB"+intToString(i)+"_", prefix+"functionParams");
derivExpressions["float "+var+"1 = "] = dVdV.renameVariables(rename1);
derivExpressions["float "+var+"2 = "] = dVdV.renameVariables(rename2);
chainSource << OpenCLExpressionUtilities::createExpressions(derivExpressions, variables, functionDefinitions, prefix+"temp"+intToString(i)+"_", prefix+"functionParams");
chainSource << "dVdR1 *= "+var+"1;\n";
chainSource << "dVdR2 *= "+var+"2;\n";
chainSource << "tempForce -= dVdR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "1") << ";\n";
......
......@@ -85,6 +85,16 @@ void verifyEvaluation(const string& expression, double x, double y, double expec
ExpressionProgram program = parsed.createProgram();
value = program.evaluate(variables);
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
// Make sure that variable renaming works.
variables.clear();
variables["w"] = x;
variables["y"] = y;
map<string, string> replacements;
replacements["x"] = "w";
value = parsed.renameVariables(replacements).evaluate(variables);
ASSERT_EQUAL_TOL(expectedValue, value, 1e-10);
}
/**
......@@ -184,6 +194,9 @@ int main() {
verifyEvaluation("5*(-x)/(-y)", 1.0, 4.0, 1.25);
verifyEvaluation("5*(-x)/(y)", 1.0, 4.0, -1.25);
verifyEvaluation("5*(x)/(-y)", 1.0, 4.0, -1.25);
verifyEvaluation("x+(-y)", 1.0, 4.0, -3.0);
verifyEvaluation("(-x)+y", 1.0, 4.0, 3.0);
verifyEvaluation("x/(1/y)", 1.0, 4.0, 4.0);
verifyEvaluation("x*w; w = 5", 3.0, 1.0, 15.0);
verifyEvaluation("a+b^2;a=x-b;b=3*y", 2.0, 3.0, 74.0);
verifyInvalidExpression("1..2");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment