Commit 144fa34a authored by Jérôme Hénin's avatar Jérôme Hénin
Browse files

Fix Lepton optimization with -ffast-math

Lepton::ParsedExpression::getConstantValue() used to return NaN to indicate non-constant
expressions. This can break with -ffast-math optimization where NaN comparisons are
not defined.

This patch defines bool isConstant() to test whether an expression is constant before
calling getConstantValue(). getConstantValue() now throws an exception if called on a
non-constant expression.
parent e936acac
......@@ -116,6 +116,7 @@ private:
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable);
static bool isConstant(const ExpressionTreeNode& node);
static double getConstantValue(const ExpressionTreeNode& node);
static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements);
ExpressionTreeNode rootNode;
......
......@@ -121,19 +121,33 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < (int) children.size(); i++)
children[i] = substituteSimplerExpression(node.getChildren()[i]);
// Collect some info on constant expressions in children
bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant?
bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant?
double first, second; // if yes, value of first and second child
if (first_const)
first = getConstantValue(children[0]);
if (second_const)
second = getConstantValue(children[1]);
switch (node.getOperation().getId()) {
case Operation::ADD:
{
double first = getConstantValue(children[0]);
double second = getConstantValue(children[1]);
if (first == 0.0) // Add 0
if (first_const) {
if (first == 0.0) { // Add 0
return children[1];
if (second == 0.0) // Add 0
return children[0];
if (first == first) // Add a constant
} else { // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(first), children[1]);
if (second == second) // Add a constant
}
}
if (second_const) {
if (second == 0.0) { // Add 0
return children[0];
} else { // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]);
}
}
if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b
return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a
......@@ -144,34 +158,35 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{
if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0
double first = getConstantValue(children[0]);
if (first_const) {
if (first == 0.0) // Subtract from 0
return ExpressionTreeNode(new Operation::Negate(), children[1]);
double second = getConstantValue(children[1]);
if (second == 0.0) // Subtract 0
}
if (second_const) {
if (second == 0.0) { // Subtract 0
return children[0];
if (second == second) // Subtract a constant
} else { // Subtract a constant
return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]);
}
}
if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b
return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]);
break;
}
case Operation::MULTIPLY:
{
double first = getConstantValue(children[0]);
double second = getConstantValue(children[1]);
if (first == 0.0 || second == 0.0) // Multiply by 0
if ((first_const && first == 0.0) || (second_const && second == 0.0)) // Multiply by 0
return ExpressionTreeNode(new Operation::Constant(0.0));
if (first == 1.0) // Multiply by 1
if (first_const && first == 1.0) // Multiply by 1
return children[1];
if (second == 1.0) // Multiply by 1
if (second_const && second == 1.0) // Multiply by 1
return children[0];
if (children[0].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant
if (first_const) { // Multiply by a constant
if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast<const Operation::MultiplyConstant*>(&children[1].getOperation())->getValue()), children[1].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]);
}
if (children[1].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant
if (second_const) { // Multiply by a constant
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]);
......@@ -202,18 +217,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{
if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything from itself is 0
double numerator = getConstantValue(children[0]);
if (numerator == 0.0) // 0 divided by something
if (first_const && first == 0.0) // 0 divided by something
return ExpressionTreeNode(new Operation::Constant(0.0));
if (numerator == 1.0) // 1 divided by something
if (first_const && first == 1.0) // 1 divided by something
return ExpressionTreeNode(new Operation::Reciprocal(), children[1]);
double denominator = getConstantValue(children[1]);
if (denominator == 1.0) // Divide by 1
if (second_const && second == 1.0) // Divide by 1
return children[0];
if (children[1].getOperation().getId() == Operation::CONSTANT) {
if (second_const) {
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/denominator), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/denominator), children[0]); // Replace a divide with a multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/second), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/second), children[0]); // Replace a divide with a multiply
}
if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel
return ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1].getChildren()[0]);
......@@ -229,34 +242,34 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
}
case Operation::POWER:
{
double base = getConstantValue(children[0]);
if (base == 0.0) // 0 to any power is 0
if (first_const && first == 0.0) // 0 to any power is 0
return ExpressionTreeNode(new Operation::Constant(0.0));
if (base == 1.0) // 1 to any power is 1
if (first_const && first == 1.0) // 1 to any power is 1
return ExpressionTreeNode(new Operation::Constant(1.0));
double exponent = getConstantValue(children[1]);
if (exponent == 0.0) // x^0 = 1
if (second_const) { // Constant exponent
if (second == 0.0) // x^0 = 1
return ExpressionTreeNode(new Operation::Constant(1.0));
if (exponent == 1.0) // x^1 = x
if (second == 1.0) // x^1 = x
return children[0];
if (exponent == -1.0) // x^-1 = recip(x)
if (second == -1.0) // x^-1 = recip(x)
return ExpressionTreeNode(new Operation::Reciprocal(), children[0]);
if (exponent == 2.0) // x^2 = square(x)
if (second == 2.0) // x^2 = square(x)
return ExpressionTreeNode(new Operation::Square(), children[0]);
if (exponent == 3.0) // x^3 = cube(x)
if (second == 3.0) // x^3 = cube(x)
return ExpressionTreeNode(new Operation::Cube(), children[0]);
if (exponent == 0.5) // x^0.5 = sqrt(x)
if (second == 0.5) // x^0.5 = sqrt(x)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]);
if (exponent == exponent) // Constant power
return ExpressionTreeNode(new Operation::PowerConstant(exponent), children[0]);
// Constant power
return ExpressionTreeNode(new Operation::PowerConstant(second), children[0]);
}
break;
}
case Operation::NEGATE:
{
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::CONSTANT) // Negate a constant
return ExpressionTreeNode(new Operation::Constant(-getConstantValue(children[0])));
if (first_const) // Negate a constant
return ExpressionTreeNode(new Operation::Constant(-first));
if (children[0].getOperation().getId() == Operation::NEGATE) // The two negations cancel
return children[0].getChildren()[0];
break;
......@@ -265,7 +278,7 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::CONSTANT) // Multiply two constants
if (first_const) // Multiply two constants
return ExpressionTreeNode(new Operation::Constant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*getConstantValue(children[0])));
if (children[0].getOperation().getId() == Operation::NEGATE) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()), children[0].getChildren()[0]);
......@@ -303,10 +316,15 @@ ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& nod
return node.getOperation().differentiate(node.getChildren(),childDerivs, variable);
}
bool ParsedExpression::isConstant(const ExpressionTreeNode& node) {
return (node.getOperation().getId() == Operation::CONSTANT);
}
double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) {
if (node.getOperation().getId() == Operation::CONSTANT)
if (node.getOperation().getId() != Operation::CONSTANT) {
throw Exception("getConstantValue called on a non-constant ExpressionNode");
}
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue();
return numeric_limits<double>::quiet_NaN();
}
ExpressionProgram ParsedExpression::createProgram() const {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment