Commit 144fa34a authored by Jérôme Hénin's avatar Jérôme Hénin
Browse files

Fix Lepton optimization with -ffast-math

Lepton::ParsedExpression::getConstantValue() used to return NaN to indicate non-constant
expressions. This can break with -ffast-math optimization where NaN comparisons are
not defined.

This patch defines bool isConstant() to test whether an expression is constant before
calling getConstantValue(). getConstantValue() now throws an exception if called on a
non-constant expression.
parent e936acac
...@@ -116,6 +116,7 @@ private: ...@@ -116,6 +116,7 @@ private:
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node); static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node); static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable); static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable);
static bool isConstant(const ExpressionTreeNode& node);
static double getConstantValue(const ExpressionTreeNode& node); static double getConstantValue(const ExpressionTreeNode& node);
static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements); static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements);
ExpressionTreeNode rootNode; ExpressionTreeNode rootNode;
......
...@@ -121,19 +121,33 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -121,19 +121,33 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
vector<ExpressionTreeNode> children(node.getChildren().size()); vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < (int) children.size(); i++) for (int i = 0; i < (int) children.size(); i++)
children[i] = substituteSimplerExpression(node.getChildren()[i]); children[i] = substituteSimplerExpression(node.getChildren()[i]);
// Collect some info on constant expressions in children
bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant?
bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant?
double first, second; // if yes, value of first and second child
if (first_const)
first = getConstantValue(children[0]);
if (second_const)
second = getConstantValue(children[1]);
switch (node.getOperation().getId()) { switch (node.getOperation().getId()) {
case Operation::ADD: case Operation::ADD:
{ {
double first = getConstantValue(children[0]); if (first_const) {
double second = getConstantValue(children[1]); if (first == 0.0) { // Add 0
if (first == 0.0) // Add 0 return children[1];
return children[1]; } else { // Add a constant
if (second == 0.0) // Add 0 return ExpressionTreeNode(new Operation::AddConstant(first), children[1]);
return children[0]; }
if (first == first) // Add a constant }
return ExpressionTreeNode(new Operation::AddConstant(first), children[1]); if (second_const) {
if (second == second) // Add a constant if (second == 0.0) { // Add 0
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]); return children[0];
} else { // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]);
}
}
if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b
return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a
...@@ -144,34 +158,35 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -144,34 +158,35 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{ {
if (children[0] == children[1]) if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0 return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0
double first = getConstantValue(children[0]); if (first_const) {
if (first == 0.0) // Subtract from 0 if (first == 0.0) // Subtract from 0
return ExpressionTreeNode(new Operation::Negate(), children[1]); return ExpressionTreeNode(new Operation::Negate(), children[1]);
double second = getConstantValue(children[1]); }
if (second == 0.0) // Subtract 0 if (second_const) {
return children[0]; if (second == 0.0) { // Subtract 0
if (second == second) // Subtract a constant return children[0];
return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]); } else { // Subtract a constant
return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]);
}
}
if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b
return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]);
break; break;
} }
case Operation::MULTIPLY: case Operation::MULTIPLY:
{ {
double first = getConstantValue(children[0]); if ((first_const && first == 0.0) || (second_const && second == 0.0)) // Multiply by 0
double second = getConstantValue(children[1]);
if (first == 0.0 || second == 0.0) // Multiply by 0
return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Constant(0.0));
if (first == 1.0) // Multiply by 1 if (first_const && first == 1.0) // Multiply by 1
return children[1]; return children[1];
if (second == 1.0) // Multiply by 1 if (second_const && second == 1.0) // Multiply by 1
return children[0]; return children[0];
if (children[0].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant if (first_const) { // Multiply by a constant
if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast<const Operation::MultiplyConstant*>(&children[1].getOperation())->getValue()), children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast<const Operation::MultiplyConstant*>(&children[1].getOperation())->getValue()), children[1].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]); return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]);
} }
if (children[1].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant if (second_const) { // Multiply by a constant
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]);
...@@ -202,18 +217,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -202,18 +217,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{ {
if (children[0] == children[1]) if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything from itself is 0 return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything from itself is 0
double numerator = getConstantValue(children[0]); if (first_const && first == 0.0) // 0 divided by something
if (numerator == 0.0) // 0 divided by something
return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Constant(0.0));
if (numerator == 1.0) // 1 divided by something if (first_const && first == 1.0) // 1 divided by something
return ExpressionTreeNode(new Operation::Reciprocal(), children[1]); return ExpressionTreeNode(new Operation::Reciprocal(), children[1]);
double denominator = getConstantValue(children[1]); if (second_const && second == 1.0) // Divide by 1
if (denominator == 1.0) // Divide by 1
return children[0]; return children[0];
if (children[1].getOperation().getId() == Operation::CONSTANT) { if (second_const) {
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/denominator), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/second), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/denominator), children[0]); // Replace a divide with a multiply return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/second), children[0]); // Replace a divide with a multiply
} }
if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel
return ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1].getChildren()[0]);
...@@ -229,34 +242,34 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -229,34 +242,34 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
} }
case Operation::POWER: case Operation::POWER:
{ {
double base = getConstantValue(children[0]); if (first_const && first == 0.0) // 0 to any power is 0
if (base == 0.0) // 0 to any power is 0
return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Constant(0.0));
if (base == 1.0) // 1 to any power is 1 if (first_const && first == 1.0) // 1 to any power is 1
return ExpressionTreeNode(new Operation::Constant(1.0));
double exponent = getConstantValue(children[1]);
if (exponent == 0.0) // x^0 = 1
return ExpressionTreeNode(new Operation::Constant(1.0)); return ExpressionTreeNode(new Operation::Constant(1.0));
if (exponent == 1.0) // x^1 = x if (second_const) { // Constant exponent
return children[0]; if (second == 0.0) // x^0 = 1
if (exponent == -1.0) // x^-1 = recip(x) return ExpressionTreeNode(new Operation::Constant(1.0));
return ExpressionTreeNode(new Operation::Reciprocal(), children[0]); if (second == 1.0) // x^1 = x
if (exponent == 2.0) // x^2 = square(x) return children[0];
return ExpressionTreeNode(new Operation::Square(), children[0]); if (second == -1.0) // x^-1 = recip(x)
if (exponent == 3.0) // x^3 = cube(x) return ExpressionTreeNode(new Operation::Reciprocal(), children[0]);
return ExpressionTreeNode(new Operation::Cube(), children[0]); if (second == 2.0) // x^2 = square(x)
if (exponent == 0.5) // x^0.5 = sqrt(x) return ExpressionTreeNode(new Operation::Square(), children[0]);
return ExpressionTreeNode(new Operation::Sqrt(), children[0]); if (second == 3.0) // x^3 = cube(x)
if (exponent == exponent) // Constant power return ExpressionTreeNode(new Operation::Cube(), children[0]);
return ExpressionTreeNode(new Operation::PowerConstant(exponent), children[0]); if (second == 0.5) // x^0.5 = sqrt(x)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]);
// Constant power
return ExpressionTreeNode(new Operation::PowerConstant(second), children[0]);
}
break; break;
} }
case Operation::NEGATE: case Operation::NEGATE:
{ {
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::CONSTANT) // Negate a constant if (first_const) // Negate a constant
return ExpressionTreeNode(new Operation::Constant(-getConstantValue(children[0]))); return ExpressionTreeNode(new Operation::Constant(-first));
if (children[0].getOperation().getId() == Operation::NEGATE) // The two negations cancel if (children[0].getOperation().getId() == Operation::NEGATE) // The two negations cancel
return children[0].getChildren()[0]; return children[0].getChildren()[0];
break; break;
...@@ -265,7 +278,7 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -265,7 +278,7 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{ {
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::CONSTANT) // Multiply two constants if (first_const) // Multiply two constants
return ExpressionTreeNode(new Operation::Constant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*getConstantValue(children[0]))); return ExpressionTreeNode(new Operation::Constant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*getConstantValue(children[0])));
if (children[0].getOperation().getId() == Operation::NEGATE) // Combine a multiply and a negate into a single multiply if (children[0].getOperation().getId() == Operation::NEGATE) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()), children[0].getChildren()[0]);
...@@ -303,10 +316,15 @@ ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& nod ...@@ -303,10 +316,15 @@ ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& nod
return node.getOperation().differentiate(node.getChildren(),childDerivs, variable); return node.getOperation().differentiate(node.getChildren(),childDerivs, variable);
} }
bool ParsedExpression::isConstant(const ExpressionTreeNode& node) {
return (node.getOperation().getId() == Operation::CONSTANT);
}
double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) { double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) {
if (node.getOperation().getId() == Operation::CONSTANT) if (node.getOperation().getId() != Operation::CONSTANT) {
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue(); throw Exception("getConstantValue called on a non-constant ExpressionNode");
return numeric_limits<double>::quiet_NaN(); }
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue();
} }
ExpressionProgram ParsedExpression::createProgram() const { ExpressionProgram ParsedExpression::createProgram() const {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment