Unverified Commit 340bf28a authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2503 from Colvars/lepton_NaN_ffast-math

Fix Lepton optimization with -ffast-math
parents e936acac 144fa34a
......@@ -116,6 +116,7 @@ private:
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable);
static bool isConstant(const ExpressionTreeNode& node);
static double getConstantValue(const ExpressionTreeNode& node);
static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements);
ExpressionTreeNode rootNode;
......
......@@ -121,19 +121,33 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < (int) children.size(); i++)
children[i] = substituteSimplerExpression(node.getChildren()[i]);
// Collect some info on constant expressions in children
bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant?
bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant?
double first, second; // if yes, value of first and second child
if (first_const)
first = getConstantValue(children[0]);
if (second_const)
second = getConstantValue(children[1]);
switch (node.getOperation().getId()) {
case Operation::ADD:
{
double first = getConstantValue(children[0]);
double second = getConstantValue(children[1]);
if (first == 0.0) // Add 0
return children[1];
if (second == 0.0) // Add 0
return children[0];
if (first == first) // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(first), children[1]);
if (second == second) // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]);
if (first_const) {
if (first == 0.0) { // Add 0
return children[1];
} else { // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(first), children[1]);
}
}
if (second_const) {
if (second == 0.0) { // Add 0
return children[0];
} else { // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]);
}
}
if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b
return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a
......@@ -144,34 +158,35 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{
if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0
double first = getConstantValue(children[0]);
if (first == 0.0) // Subtract from 0
return ExpressionTreeNode(new Operation::Negate(), children[1]);
double second = getConstantValue(children[1]);
if (second == 0.0) // Subtract 0
return children[0];
if (second == second) // Subtract a constant
return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]);
if (first_const) {
if (first == 0.0) // Subtract from 0
return ExpressionTreeNode(new Operation::Negate(), children[1]);
}
if (second_const) {
if (second == 0.0) { // Subtract 0
return children[0];
} else { // Subtract a constant
return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]);
}
}
if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b
return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]);
break;
}
case Operation::MULTIPLY:
{
double first = getConstantValue(children[0]);
double second = getConstantValue(children[1]);
if (first == 0.0 || second == 0.0) // Multiply by 0
{
if ((first_const && first == 0.0) || (second_const && second == 0.0)) // Multiply by 0
return ExpressionTreeNode(new Operation::Constant(0.0));
if (first == 1.0) // Multiply by 1
if (first_const && first == 1.0) // Multiply by 1
return children[1];
if (second == 1.0) // Multiply by 1
if (second_const && second == 1.0) // Multiply by 1
return children[0];
if (children[0].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant
if (first_const) { // Multiply by a constant
if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast<const Operation::MultiplyConstant*>(&children[1].getOperation())->getValue()), children[1].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]);
}
if (children[1].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant
if (second_const) { // Multiply by a constant
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]);
......@@ -202,18 +217,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{
if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything from itself is 0
double numerator = getConstantValue(children[0]);
if (numerator == 0.0) // 0 divided by something
if (first_const && first == 0.0) // 0 divided by something
return ExpressionTreeNode(new Operation::Constant(0.0));
if (numerator == 1.0) // 1 divided by something
if (first_const && first == 1.0) // 1 divided by something
return ExpressionTreeNode(new Operation::Reciprocal(), children[1]);
double denominator = getConstantValue(children[1]);
if (denominator == 1.0) // Divide by 1
if (second_const && second == 1.0) // Divide by 1
return children[0];
if (children[1].getOperation().getId() == Operation::CONSTANT) {
if (second_const) {
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/denominator), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/denominator), children[0]); // Replace a divide with a multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/second), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/second), children[0]); // Replace a divide with a multiply
}
if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel
return ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1].getChildren()[0]);
......@@ -229,34 +242,34 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
}
case Operation::POWER:
{
double base = getConstantValue(children[0]);
if (base == 0.0) // 0 to any power is 0
if (first_const && first == 0.0) // 0 to any power is 0
return ExpressionTreeNode(new Operation::Constant(0.0));
if (base == 1.0) // 1 to any power is 1
return ExpressionTreeNode(new Operation::Constant(1.0));
double exponent = getConstantValue(children[1]);
if (exponent == 0.0) // x^0 = 1
if (first_const && first == 1.0) // 1 to any power is 1
return ExpressionTreeNode(new Operation::Constant(1.0));
if (exponent == 1.0) // x^1 = x
return children[0];
if (exponent == -1.0) // x^-1 = recip(x)
return ExpressionTreeNode(new Operation::Reciprocal(), children[0]);
if (exponent == 2.0) // x^2 = square(x)
return ExpressionTreeNode(new Operation::Square(), children[0]);
if (exponent == 3.0) // x^3 = cube(x)
return ExpressionTreeNode(new Operation::Cube(), children[0]);
if (exponent == 0.5) // x^0.5 = sqrt(x)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]);
if (exponent == exponent) // Constant power
return ExpressionTreeNode(new Operation::PowerConstant(exponent), children[0]);
if (second_const) { // Constant exponent
if (second == 0.0) // x^0 = 1
return ExpressionTreeNode(new Operation::Constant(1.0));
if (second == 1.0) // x^1 = x
return children[0];
if (second == -1.0) // x^-1 = recip(x)
return ExpressionTreeNode(new Operation::Reciprocal(), children[0]);
if (second == 2.0) // x^2 = square(x)
return ExpressionTreeNode(new Operation::Square(), children[0]);
if (second == 3.0) // x^3 = cube(x)
return ExpressionTreeNode(new Operation::Cube(), children[0]);
if (second == 0.5) // x^0.5 = sqrt(x)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]);
// Constant power
return ExpressionTreeNode(new Operation::PowerConstant(second), children[0]);
}
break;
}
case Operation::NEGATE:
{
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::CONSTANT) // Negate a constant
return ExpressionTreeNode(new Operation::Constant(-getConstantValue(children[0])));
if (first_const) // Negate a constant
return ExpressionTreeNode(new Operation::Constant(-first));
if (children[0].getOperation().getId() == Operation::NEGATE) // The two negations cancel
return children[0].getChildren()[0];
break;
......@@ -265,7 +278,7 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::CONSTANT) // Multiply two constants
if (first_const) // Multiply two constants
return ExpressionTreeNode(new Operation::Constant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*getConstantValue(children[0])));
if (children[0].getOperation().getId() == Operation::NEGATE) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()), children[0].getChildren()[0]);
......@@ -303,10 +316,15 @@ ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& nod
return node.getOperation().differentiate(node.getChildren(),childDerivs, variable);
}
bool ParsedExpression::isConstant(const ExpressionTreeNode& node) {
return (node.getOperation().getId() == Operation::CONSTANT);
}
double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) {
if (node.getOperation().getId() == Operation::CONSTANT)
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue();
return numeric_limits<double>::quiet_NaN();
if (node.getOperation().getId() != Operation::CONSTANT) {
throw Exception("getConstantValue called on a non-constant ExpressionNode");
}
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue();
}
ExpressionProgram ParsedExpression::createProgram() const {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment