"vscode:/vscode.git/clone" did not exist on "cd371d31fc7d62d6d28b5a803a31a3a1accc3d35"
Unverified Commit 340bf28a authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2503 from Colvars/lepton_NaN_ffast-math

Fix Lepton optimization with -ffast-math
parents e936acac 144fa34a
...@@ -116,6 +116,7 @@ private: ...@@ -116,6 +116,7 @@ private:
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node); static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node); static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable); static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable);
static bool isConstant(const ExpressionTreeNode& node);
static double getConstantValue(const ExpressionTreeNode& node); static double getConstantValue(const ExpressionTreeNode& node);
static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements); static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements);
ExpressionTreeNode rootNode; ExpressionTreeNode rootNode;
......
...@@ -121,19 +121,33 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -121,19 +121,33 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
vector<ExpressionTreeNode> children(node.getChildren().size()); vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < (int) children.size(); i++) for (int i = 0; i < (int) children.size(); i++)
children[i] = substituteSimplerExpression(node.getChildren()[i]); children[i] = substituteSimplerExpression(node.getChildren()[i]);
// Collect some info on constant expressions in children
bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant?
bool second_const = children.size() > 1 && isConstant(children[1]); ; // is second child constant?
double first, second; // if yes, value of first and second child
if (first_const)
first = getConstantValue(children[0]);
if (second_const)
second = getConstantValue(children[1]);
switch (node.getOperation().getId()) { switch (node.getOperation().getId()) {
case Operation::ADD: case Operation::ADD:
{ {
double first = getConstantValue(children[0]); if (first_const) {
double second = getConstantValue(children[1]); if (first == 0.0) { // Add 0
if (first == 0.0) // Add 0 return children[1];
return children[1]; } else { // Add a constant
if (second == 0.0) // Add 0 return ExpressionTreeNode(new Operation::AddConstant(first), children[1]);
return children[0]; }
if (first == first) // Add a constant }
return ExpressionTreeNode(new Operation::AddConstant(first), children[1]); if (second_const) {
if (second == second) // Add a constant if (second == 0.0) { // Add 0
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]); return children[0];
} else { // Add a constant
return ExpressionTreeNode(new Operation::AddConstant(second), children[0]);
}
}
if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b if (children[1].getOperation().getId() == Operation::NEGATE) // a+(-b) = a-b
return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::Subtract(), children[0], children[1].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a if (children[0].getOperation().getId() == Operation::NEGATE) // (-a)+b = b-a
...@@ -144,34 +158,35 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -144,34 +158,35 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{ {
if (children[0] == children[1]) if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0 return ExpressionTreeNode(new Operation::Constant(0.0)); // Subtracting anything from itself is 0
double first = getConstantValue(children[0]); if (first_const) {
if (first == 0.0) // Subtract from 0 if (first == 0.0) // Subtract from 0
return ExpressionTreeNode(new Operation::Negate(), children[1]); return ExpressionTreeNode(new Operation::Negate(), children[1]);
double second = getConstantValue(children[1]); }
if (second == 0.0) // Subtract 0 if (second_const) {
return children[0]; if (second == 0.0) { // Subtract 0
if (second == second) // Subtract a constant return children[0];
return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]); } else { // Subtract a constant
return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]);
}
}
if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b if (children[1].getOperation().getId() == Operation::NEGATE) // a-(-b) = a+b
return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::Add(), children[0], children[1].getChildren()[0]);
break; break;
} }
case Operation::MULTIPLY: case Operation::MULTIPLY:
{ {
double first = getConstantValue(children[0]); if ((first_const && first == 0.0) || (second_const && second == 0.0)) // Multiply by 0
double second = getConstantValue(children[1]);
if (first == 0.0 || second == 0.0) // Multiply by 0
return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Constant(0.0));
if (first == 1.0) // Multiply by 1 if (first_const && first == 1.0) // Multiply by 1
return children[1]; return children[1];
if (second == 1.0) // Multiply by 1 if (second_const && second == 1.0) // Multiply by 1
return children[0]; return children[0];
if (children[0].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant if (first_const) { // Multiply by a constant
if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast<const Operation::MultiplyConstant*>(&children[1].getOperation())->getValue()), children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast<const Operation::MultiplyConstant*>(&children[1].getOperation())->getValue()), children[1].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]); return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]);
} }
if (children[1].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant if (second_const) { // Multiply by a constant
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]);
...@@ -202,18 +217,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -202,18 +217,16 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{ {
if (children[0] == children[1]) if (children[0] == children[1])
return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything from itself is 0 return ExpressionTreeNode(new Operation::Constant(1.0)); // Dividing anything from itself is 0
double numerator = getConstantValue(children[0]); if (first_const && first == 0.0) // 0 divided by something
if (numerator == 0.0) // 0 divided by something
return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Constant(0.0));
if (numerator == 1.0) // 1 divided by something if (first_const && first == 1.0) // 1 divided by something
return ExpressionTreeNode(new Operation::Reciprocal(), children[1]); return ExpressionTreeNode(new Operation::Reciprocal(), children[1]);
double denominator = getConstantValue(children[1]); if (second_const && second == 1.0) // Divide by 1
if (denominator == 1.0) // Divide by 1
return children[0]; return children[0];
if (children[1].getOperation().getId() == Operation::CONSTANT) { if (second_const) {
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/denominator), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()/second), children[0].getChildren()[0]);
return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/denominator), children[0]); // Replace a divide with a multiply return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/second), children[0]); // Replace a divide with a multiply
} }
if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel if (children[0].getOperation().getId() == Operation::NEGATE && children[1].getOperation().getId() == Operation::NEGATE) // The two negations cancel
return ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::Divide(), children[0].getChildren()[0], children[1].getChildren()[0]);
...@@ -229,34 +242,34 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -229,34 +242,34 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
} }
case Operation::POWER: case Operation::POWER:
{ {
double base = getConstantValue(children[0]); if (first_const && first == 0.0) // 0 to any power is 0
if (base == 0.0) // 0 to any power is 0
return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Constant(0.0));
if (base == 1.0) // 1 to any power is 1 if (first_const && first == 1.0) // 1 to any power is 1
return ExpressionTreeNode(new Operation::Constant(1.0));
double exponent = getConstantValue(children[1]);
if (exponent == 0.0) // x^0 = 1
return ExpressionTreeNode(new Operation::Constant(1.0)); return ExpressionTreeNode(new Operation::Constant(1.0));
if (exponent == 1.0) // x^1 = x if (second_const) { // Constant exponent
return children[0]; if (second == 0.0) // x^0 = 1
if (exponent == -1.0) // x^-1 = recip(x) return ExpressionTreeNode(new Operation::Constant(1.0));
return ExpressionTreeNode(new Operation::Reciprocal(), children[0]); if (second == 1.0) // x^1 = x
if (exponent == 2.0) // x^2 = square(x) return children[0];
return ExpressionTreeNode(new Operation::Square(), children[0]); if (second == -1.0) // x^-1 = recip(x)
if (exponent == 3.0) // x^3 = cube(x) return ExpressionTreeNode(new Operation::Reciprocal(), children[0]);
return ExpressionTreeNode(new Operation::Cube(), children[0]); if (second == 2.0) // x^2 = square(x)
if (exponent == 0.5) // x^0.5 = sqrt(x) return ExpressionTreeNode(new Operation::Square(), children[0]);
return ExpressionTreeNode(new Operation::Sqrt(), children[0]); if (second == 3.0) // x^3 = cube(x)
if (exponent == exponent) // Constant power return ExpressionTreeNode(new Operation::Cube(), children[0]);
return ExpressionTreeNode(new Operation::PowerConstant(exponent), children[0]); if (second == 0.5) // x^0.5 = sqrt(x)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]);
// Constant power
return ExpressionTreeNode(new Operation::PowerConstant(second), children[0]);
}
break; break;
} }
case Operation::NEGATE: case Operation::NEGATE:
{ {
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::CONSTANT) // Negate a constant if (first_const) // Negate a constant
return ExpressionTreeNode(new Operation::Constant(-getConstantValue(children[0]))); return ExpressionTreeNode(new Operation::Constant(-first));
if (children[0].getOperation().getId() == Operation::NEGATE) // The two negations cancel if (children[0].getOperation().getId() == Operation::NEGATE) // The two negations cancel
return children[0].getChildren()[0]; return children[0].getChildren()[0];
break; break;
...@@ -265,7 +278,7 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -265,7 +278,7 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
{ {
if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one
return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*dynamic_cast<const Operation::MultiplyConstant*>(&children[0].getOperation())->getValue()), children[0].getChildren()[0]);
if (children[0].getOperation().getId() == Operation::CONSTANT) // Multiply two constants if (first_const) // Multiply two constants
return ExpressionTreeNode(new Operation::Constant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*getConstantValue(children[0]))); return ExpressionTreeNode(new Operation::Constant(dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()*getConstantValue(children[0])));
if (children[0].getOperation().getId() == Operation::NEGATE) // Combine a multiply and a negate into a single multiply if (children[0].getOperation().getId() == Operation::NEGATE) // Combine a multiply and a negate into a single multiply
return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast<const Operation::MultiplyConstant*>(&node.getOperation())->getValue()), children[0].getChildren()[0]);
...@@ -303,10 +316,15 @@ ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& nod ...@@ -303,10 +316,15 @@ ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& nod
return node.getOperation().differentiate(node.getChildren(),childDerivs, variable); return node.getOperation().differentiate(node.getChildren(),childDerivs, variable);
} }
bool ParsedExpression::isConstant(const ExpressionTreeNode& node) {
return (node.getOperation().getId() == Operation::CONSTANT);
}
double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) { double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) {
if (node.getOperation().getId() == Operation::CONSTANT) if (node.getOperation().getId() != Operation::CONSTANT) {
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue(); throw Exception("getConstantValue called on a non-constant ExpressionNode");
return numeric_limits<double>::quiet_NaN(); }
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue();
} }
ExpressionProgram ParsedExpression::createProgram() const { ExpressionProgram ParsedExpression::createProgram() const {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment