/* -------------------------------------------------------------------------- * * Lepton * * -------------------------------------------------------------------------- * * This is part of the Lepton expression parser originating from * * Simbios, the NIH National Center for Physics-Based Simulation of * * Biological Structures at Stanford, funded under the NIH Roadmap for * * Medical Research, grant U54 GM072970. See https://simtk.org. * * * * Portions copyright (c) 2009 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * * Permission is hereby granted, free of charge, to any person obtaining a * * copy of this software and associated documentation files (the "Software"), * * to deal in the Software without restriction, including without limitation * * the rights to use, copy, modify, merge, publish, distribute, sublicense, * * and/or sell copies of the Software, and to permit persons to whom the * * Software is furnished to do so, subject to the following conditions: * * * * The above copyright notice and this permission notice shall be included in * * all copies or substantial portions of the Software. * * * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, * * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR * * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE * * USE OR OTHER DEALINGS IN THE SOFTWARE. * * -------------------------------------------------------------------------- */ #include "lepton/ParsedExpression.h" #include "lepton/ExpressionProgram.h" #include "lepton/Operation.h" #include #include using namespace Lepton; using namespace std; ParsedExpression::ParsedExpression(const ExpressionTreeNode& rootNode) : rootNode(rootNode) { } const ExpressionTreeNode& ParsedExpression::getRootNode() const { return rootNode; } double ParsedExpression::evaluate() const { return evaluate(getRootNode(), map()); } double ParsedExpression::evaluate(const std::map& variables) const { return evaluate(getRootNode(), variables); } double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map& variables) { int numArgs = node.getChildren().size(); vector args(max(numArgs, 1)); for (int i = 0; i < numArgs; i++) args[i] = evaluate(node.getChildren()[i], variables); return node.getOperation().evaluate(&args[0], variables); } ParsedExpression ParsedExpression::optimize() const { ExpressionTreeNode result = precalculateConstantSubexpressions(getRootNode()); result = substituteSimplerExpression(result); result = substituteSimplerExpression(result); return ParsedExpression(result); } ParsedExpression ParsedExpression::optimize(const map& variables) const { ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables); result = precalculateConstantSubexpressions(result); result = substituteSimplerExpression(result); result = substituteSimplerExpression(result); return ParsedExpression(result); } ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNode& node, const map& variables) { if (node.getOperation().getId() == Operation::VARIABLE) { const Operation::Variable& var = dynamic_cast(node.getOperation()); map::const_iterator iter = variables.find(var.getName()); if (iter == variables.end()) return node; return ExpressionTreeNode(new Operation::Constant(iter->second)); } vector children(node.getChildren().size()); for (int i = 0; i < (int) children.size(); i++) children[i] = preevaluateVariables(node.getChildren()[i], variables); return ExpressionTreeNode(node.getOperation().clone(), children); } ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) { vector children(node.getChildren().size()); for (int i = 0; i < (int) children.size(); i++) children[i] = precalculateConstantSubexpressions(node.getChildren()[i]); ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children); if (node.getOperation().getId() == Operation::VARIABLE) return result; for (int i = 0; i < (int) children.size(); i++) if (children[i].getOperation().getId() != Operation::CONSTANT) return result; return ExpressionTreeNode(new Operation::Constant(evaluate(result, map()))); } ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) { vector children(node.getChildren().size()); for (int i = 0; i < (int) children.size(); i++) children[i] = substituteSimplerExpression(node.getChildren()[i]); switch (node.getOperation().getId()) { case Operation::ADD: { double first = getConstantValue(children[0]); double second = getConstantValue(children[1]); if (first == 0.0) // Add 0 return children[1]; if (second == 0.0) // Add 0 return children[0]; if (first == first) // Add a constant return ExpressionTreeNode(new Operation::AddConstant(first), children[1]); if (second == second) // Add a constant return ExpressionTreeNode(new Operation::AddConstant(second), children[0]); break; } case Operation::SUBTRACT: { double first = getConstantValue(children[0]); if (first == 0.0) // Subtract from 0 return ExpressionTreeNode(new Operation::Negate(), children[1]); double second = getConstantValue(children[1]); if (second == 0.0) // Subtract 0 return children[0]; if (second == second) // Subtract a constant return ExpressionTreeNode(new Operation::AddConstant(-second), children[0]); break; } case Operation::MULTIPLY: { double first = getConstantValue(children[0]); double second = getConstantValue(children[1]); if (first == 0.0 || second == 0.0) // Multiply by 0 return ExpressionTreeNode(new Operation::Constant(0.0)); if (first == 1.0) // Multiply by 1 return children[1]; if (second == 1.0) // Multiply by 1 return children[0]; if (children[0].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant if (children[1].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one return ExpressionTreeNode(new Operation::MultiplyConstant(first*dynamic_cast(&children[1].getOperation())->getValue()), children[1].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(first), children[1]); } if (children[1].getOperation().getId() == Operation::CONSTANT) { // Multiply by a constant if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one return ExpressionTreeNode(new Operation::MultiplyConstant(second*dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(second), children[0]); } break; } case Operation::DIVIDE: { double numerator = getConstantValue(children[0]); if (numerator == 0.0) // 0 divided by something return ExpressionTreeNode(new Operation::Constant(0.0)); if (numerator == 1.0) // 1 divided by something return ExpressionTreeNode(new Operation::Reciprocal(), children[1]); double denominator = getConstantValue(children[1]); if (denominator == 1.0) // Divide by 1 return children[0]; if (children[1].getOperation().getId() == Operation::CONSTANT) { if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a divide into one multiply return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast(&children[0].getOperation())->getValue()/denominator), children[0].getChildren()[0]); return ExpressionTreeNode(new Operation::MultiplyConstant(1.0/denominator), children[0]); // Replace a divide with a multiply } break; } case Operation::POWER: { double base = getConstantValue(children[0]); if (base == 0.0) // 0 to any power is 0 return ExpressionTreeNode(new Operation::Constant(0.0)); if (base == 1.0) // 1 to any power is 1 return ExpressionTreeNode(new Operation::Constant(1.0)); double exponent = getConstantValue(children[1]); if (exponent == 0.0) // x^0 = 1 return ExpressionTreeNode(new Operation::Constant(1.0)); if (exponent == 1.0) // x^1 = x return children[0]; if (exponent == -1.0) // x^-1 = recip(x) return ExpressionTreeNode(new Operation::Reciprocal(), children[0]); if (exponent == 2.0) // x^2 = square(x) return ExpressionTreeNode(new Operation::Square(), children[0]); if (exponent == 3.0) // x^3 = cube(x) return ExpressionTreeNode(new Operation::Cube(), children[0]); if (exponent == 0.5) // x^0.5 = sqrt(x) return ExpressionTreeNode(new Operation::Sqrt(), children[0]); if (exponent == exponent) // Constant power return ExpressionTreeNode(new Operation::PowerConstant(exponent), children[0]); break; } case Operation::NEGATE: { if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine a multiply and a negate into a single multiply return ExpressionTreeNode(new Operation::MultiplyConstant(-dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); break; } case Operation::MULTIPLY_CONSTANT: { if (children[0].getOperation().getId() == Operation::MULTIPLY_CONSTANT) // Combine two multiplies into a single one return ExpressionTreeNode(new Operation::MultiplyConstant(dynamic_cast(&node.getOperation())->getValue()*dynamic_cast(&children[0].getOperation())->getValue()), children[0].getChildren()[0]); } } return ExpressionTreeNode(node.getOperation().clone(), children); } ParsedExpression ParsedExpression::differentiate(const std::string& variable) const { return differentiate(getRootNode(), variable); } ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const std::string& variable) { vector childDerivs(node.getChildren().size()); for (int i = 0; i < (int) childDerivs.size(); i++) childDerivs[i] = differentiate(node.getChildren()[i], variable); return node.getOperation().differentiate(node.getChildren(),childDerivs, variable); } double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) { if (node.getOperation().getId() == Operation::CONSTANT) return dynamic_cast(node.getOperation()).getValue(); return numeric_limits::quiet_NaN(); } ExpressionProgram ParsedExpression::createProgram() const { return ExpressionProgram(*this); } ostream& Lepton::operator<<(ostream& out, const ExpressionTreeNode& node) { out << node.getOperation().getName(); if (node.getChildren().size() > 0) { out << "("; for (int i = 0; i < (int) node.getChildren().size(); i++) { if (i > 0) out << ", "; out << node.getChildren()[i]; } out << ")"; } return out; } ostream& Lepton::operator<<(ostream& out, const ParsedExpression& exp) { out << exp.getRootNode(); return out; }