Commit 4aabb908 authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented differentiation. Also added lots more optimizations.

parent 6f67100a
...@@ -41,6 +41,8 @@ ...@@ -41,6 +41,8 @@
namespace Lepton { namespace Lepton {
class ExpressionTreeNode;
/** /**
* An Operation represents a single step in the evaluation of an expression, such as a function, * An Operation represents a single step in the evaluation of an expression, such as a function,
* an operator, or a constant value. Each Operation takes some number of values as arguments * an operator, or a constant value. Each Operation takes some number of values as arguments
...@@ -56,7 +58,7 @@ public: ...@@ -56,7 +58,7 @@ public:
* can be used when processing or analyzing parsed expressions. * can be used when processing or analyzing parsed expressions.
*/ */
enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG, enum Id {CONSTANT, VARIABLE, CUSTOM, ADD, SUBTRACT, MULTIPLY, DIVIDE, POWER, NEGATE, SQRT, EXP, LOG,
SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SQUARE, CUBE, RECIPROCAL}; SIN, COS, SEC, CSC, TAN, COT, ASIN, ACOS, ATAN, SQUARE, CUBE, RECIPROCAL, INCREMENT, DECREMENT};
/** /**
* Get the name of this Operation. * Get the name of this Operation.
*/ */
...@@ -81,6 +83,14 @@ public: ...@@ -81,6 +83,14 @@ public:
* @return the result of performing the computation. * @return the result of performing the computation.
*/ */
virtual double evaluate(double* args, const std::map<std::string, double>& variables) const = 0; virtual double evaluate(double* args, const std::map<std::string, double>& variables) const = 0;
/**
* Return an ExpressionTreeNode which represents the analytic derivative of this Operation with respect to a variable.
*
* @param children the child nodes
* @param childDerivs the derivatives of the child nodes with respect to the variable
* @param variable the variable with respect to which the derivate should be taken
*/
virtual ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const = 0;
class Constant; class Constant;
class Variable; class Variable;
class Custom; class Custom;
...@@ -105,6 +115,8 @@ public: ...@@ -105,6 +115,8 @@ public:
class Square; class Square;
class Cube; class Cube;
class Reciprocal; class Reciprocal;
class Increment;
class Decrement;
}; };
class Operation::Constant : public Operation { class Operation::Constant : public Operation {
...@@ -128,6 +140,7 @@ public: ...@@ -128,6 +140,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return value; return value;
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
double getValue() const { double getValue() const {
return value; return value;
} }
...@@ -157,6 +170,7 @@ public: ...@@ -157,6 +170,7 @@ public:
throw std::exception(); throw std::exception();
return iter->second; return iter->second;
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
private: private:
std::string name; std::string name;
}; };
...@@ -180,6 +194,7 @@ public: ...@@ -180,6 +194,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return 0.0; return 0.0;
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
private: private:
std::string name; std::string name;
int arguments; int arguments;
...@@ -204,6 +219,7 @@ public: ...@@ -204,6 +219,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]+args[1]; return args[0]+args[1];
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Subtract : public Operation { class Operation::Subtract : public Operation {
...@@ -225,6 +241,7 @@ public: ...@@ -225,6 +241,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]-args[1]; return args[0]-args[1];
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Multiply : public Operation { class Operation::Multiply : public Operation {
...@@ -246,6 +263,7 @@ public: ...@@ -246,6 +263,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]*args[1]; return args[0]*args[1];
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Divide : public Operation { class Operation::Divide : public Operation {
...@@ -267,6 +285,7 @@ public: ...@@ -267,6 +285,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]/args[1]; return args[0]/args[1];
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Power : public Operation { class Operation::Power : public Operation {
...@@ -288,6 +307,7 @@ public: ...@@ -288,6 +307,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::pow(args[0], args[1]); return std::pow(args[0], args[1]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Negate : public Operation { class Operation::Negate : public Operation {
...@@ -309,6 +329,7 @@ public: ...@@ -309,6 +329,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return -args[0]; return -args[0];
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Sqrt : public Operation { class Operation::Sqrt : public Operation {
...@@ -330,6 +351,7 @@ public: ...@@ -330,6 +351,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::sqrt(args[0]); return std::sqrt(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Exp : public Operation { class Operation::Exp : public Operation {
...@@ -351,6 +373,7 @@ public: ...@@ -351,6 +373,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::exp(args[0]); return std::exp(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Log : public Operation { class Operation::Log : public Operation {
...@@ -372,6 +395,7 @@ public: ...@@ -372,6 +395,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::log(args[0]); return std::log(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Sin : public Operation { class Operation::Sin : public Operation {
...@@ -393,6 +417,7 @@ public: ...@@ -393,6 +417,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::sin(args[0]); return std::sin(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Cos : public Operation { class Operation::Cos : public Operation {
...@@ -414,6 +439,7 @@ public: ...@@ -414,6 +439,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::cos(args[0]); return std::cos(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Sec : public Operation { class Operation::Sec : public Operation {
...@@ -435,6 +461,7 @@ public: ...@@ -435,6 +461,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return 1.0/std::cos(args[0]); return 1.0/std::cos(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Csc : public Operation { class Operation::Csc : public Operation {
...@@ -456,6 +483,7 @@ public: ...@@ -456,6 +483,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return 1.0/std::sin(args[0]); return 1.0/std::sin(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Tan : public Operation { class Operation::Tan : public Operation {
...@@ -477,6 +505,7 @@ public: ...@@ -477,6 +505,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::tan(args[0]); return std::tan(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Cot : public Operation { class Operation::Cot : public Operation {
...@@ -498,6 +527,7 @@ public: ...@@ -498,6 +527,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return 1.0/std::tan(args[0]); return 1.0/std::tan(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Asin : public Operation { class Operation::Asin : public Operation {
...@@ -519,6 +549,7 @@ public: ...@@ -519,6 +549,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::asin(args[0]); return std::asin(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Acos : public Operation { class Operation::Acos : public Operation {
...@@ -540,6 +571,7 @@ public: ...@@ -540,6 +571,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::acos(args[0]); return std::acos(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Atan : public Operation { class Operation::Atan : public Operation {
...@@ -561,6 +593,7 @@ public: ...@@ -561,6 +593,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return std::atan(args[0]); return std::atan(args[0]);
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Square : public Operation { class Operation::Square : public Operation {
...@@ -582,6 +615,7 @@ public: ...@@ -582,6 +615,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]*args[0]; return args[0]*args[0];
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Cube : public Operation { class Operation::Cube : public Operation {
...@@ -603,6 +637,7 @@ public: ...@@ -603,6 +637,7 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]*args[0]*args[0]; return args[0]*args[0]*args[0];
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
class Operation::Reciprocal : public Operation { class Operation::Reciprocal : public Operation {
...@@ -624,6 +659,51 @@ public: ...@@ -624,6 +659,51 @@ public:
double evaluate(double* args, const std::map<std::string, double>& variables) const { double evaluate(double* args, const std::map<std::string, double>& variables) const {
return 1.0/args[0]; return 1.0/args[0];
} }
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class Operation::Increment : public Operation {
public:
Increment() {
}
std::string getName() const {
return "increment";
}
Id getId() const {
return INCREMENT;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Increment();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]+1.0;
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
};
class Operation::Decrement : public Operation {
public:
Decrement() {
}
std::string getName() const {
return "decrement";
}
Id getId() const {
return DECREMENT;
}
int getNumArguments() const {
return 1;
}
Operation* clone() const {
return new Decrement();
}
double evaluate(double* args, const std::map<std::string, double>& variables) const {
return args[0]-1.0;
}
ExpressionTreeNode differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const;
}; };
} // namespace Lepton } // namespace Lepton
......
...@@ -79,14 +79,27 @@ public: ...@@ -79,14 +79,27 @@ public:
* specified. * specified.
*/ */
ParsedExpression optimize(const std::map<std::string, double>& variables) const; ParsedExpression optimize(const std::map<std::string, double>& variables) const;
/**
* Create a new ParsedExpression which is the analytic derivative of this expression with respect to a
* particular variable.
*
* @param variable the variable with respect to which the derivate should be taken
*/
ParsedExpression differentiate(const std::string& variable) const;
private: private:
static double evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables); static double evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map<std::string, double>& variables); static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node); static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node); static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable);
static double getConstantValue(const ExpressionTreeNode& node);
ExpressionTreeNode rootNode; ExpressionTreeNode rootNode;
}; };
std::ostream& operator<<(std::ostream& out, const ExpressionTreeNode& node);
std::ostream& operator<<(std::ostream& out, const ParsedExpression& exp);
} // namespace Lepton } // namespace Lepton
#endif /*LEPTON_PARSED_EXPRESSION_H_*/ #endif /*LEPTON_PARSED_EXPRESSION_H_*/
/* -------------------------------------------------------------------------- *
* Lepton *
* -------------------------------------------------------------------------- *
* This is part of the Lepton expression parser originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "Operation.h"
#include "ExpressionTreeNode.h"
using namespace Lepton;
using namespace std;
ExpressionTreeNode Operation::Constant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Variable::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (variable == name)
return ExpressionTreeNode(new Operation::Constant(1.0));
return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Custom::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Add::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]);
}
ExpressionTreeNode Operation::Subtract::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]);
}
ExpressionTreeNode Operation::Multiply::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Add(),
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]),
ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]));
}
ExpressionTreeNode Operation::Divide::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Divide(),
ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]),
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])),
ExpressionTreeNode(new Operation::Square(), children[1]));
}
ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Add(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
children[1],
ExpressionTreeNode(new Operation::Power(),
children[0], ExpressionTreeNode(new Operation::Decrement(), children[1]))),
childDerivs[0]),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Log(), children[0]),
ExpressionTreeNode(new Operation::Power(), children[0], children[1])),
childDerivs[1]));
}
ExpressionTreeNode Operation::Negate::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]);
}
ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(0.5)),
ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::Sqrt(), children[0]))),
childDerivs[0]);
}
ExpressionTreeNode Operation::Exp::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Exp(), children[0]),
childDerivs[0]);
}
ExpressionTreeNode Operation::Log::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(), children[0]),
childDerivs[0]);
}
ExpressionTreeNode Operation::Sin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Cos(), children[0]),
childDerivs[0]);
}
ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Sin(), children[0])),
childDerivs[0]);
}
ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Sec(), children[0]),
ExpressionTreeNode(new Operation::Tan(), children[0])),
childDerivs[0]);
}
ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Csc(), children[0]),
ExpressionTreeNode(new Operation::Cot(), children[0]))),
childDerivs[0]);
}
ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Square(),
ExpressionTreeNode(new Operation::Sec(), children[0])),
childDerivs[0]);
}
ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Square(),
ExpressionTreeNode(new Operation::Csc(), children[0]))),
childDerivs[0]);
}
ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::Sqrt(),
ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Constant(1.0)),
ExpressionTreeNode(new Operation::Square(), children[0])))),
childDerivs[0]);
}
ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::Sqrt(),
ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Constant(1.0)),
ExpressionTreeNode(new Operation::Square(), children[0]))))),
childDerivs[0]);
}
ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::Increment(),
ExpressionTreeNode(new Operation::Square(), children[0]))),
childDerivs[0]);
}
ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(2.0)),
children[0]),
childDerivs[0]);
}
ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(3.0)),
ExpressionTreeNode(new Operation::Square(), children[0])),
childDerivs[0]);
}
ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::Square(), children[0]))),
childDerivs[0]);
}
ExpressionTreeNode Operation::Increment::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return childDerivs[0];
}
ExpressionTreeNode Operation::Decrement::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return childDerivs[0];
}
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "ParsedExpression.h" #include "ParsedExpression.h"
#include "Operation.h" #include "Operation.h"
#include <limits>
#include <vector> #include <vector>
using namespace Lepton; using namespace Lepton;
...@@ -59,16 +60,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<stri ...@@ -59,16 +60,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<stri
} }
ParsedExpression ParsedExpression::optimize() const { ParsedExpression ParsedExpression::optimize() const {
ParsedExpression result = precalculateConstantSubexpressions(getRootNode()); ExpressionTreeNode result = precalculateConstantSubexpressions(getRootNode());
result = substituteSimplerExpression(result.getRootNode()); result = substituteSimplerExpression(result);
return result; return result;
} }
ParsedExpression ParsedExpression::optimize(const map<string, double>& variables) const { ParsedExpression ParsedExpression::optimize(const map<string, double>& variables) const {
ParsedExpression result = preevaluateVariables(getRootNode(), variables); ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables);
result = precalculateConstantSubexpressions(result.getRootNode()); result = precalculateConstantSubexpressions(result);
result = substituteSimplerExpression(result.getRootNode()); result = substituteSimplerExpression(result);
return result; return ParsedExpression(result);
} }
ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNode& node, const map<string, double>& variables) { ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNode& node, const map<string, double>& variables) {
...@@ -103,27 +104,102 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -103,27 +104,102 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
for (int i = 0; i < children.size(); i++) for (int i = 0; i < children.size(); i++)
children[i] = substituteSimplerExpression(node.getChildren()[i]); children[i] = substituteSimplerExpression(node.getChildren()[i]);
switch (node.getOperation().getId()) { switch (node.getOperation().getId()) {
case Operation::ADD:
double first = getConstantValue(children[0]);
double second = getConstantValue(children[1]);
if (first == 0.0)
return children[1];
if (first == 1.0)
return ExpressionTreeNode(new Operation::Increment(), children[1]);
if (second == 0.0)
return children[0];
if (second == 1.0)
return ExpressionTreeNode(new Operation::Increment(), children[0]);
break;
case Operation::SUBTRACT:
first = getConstantValue(children[0]);
if (first == 0.0)
return ExpressionTreeNode(new Operation::Negate(), children[1]);
second = getConstantValue(children[1]);
if (second == 0.0)
return children[0];
if (second == 1.0)
return ExpressionTreeNode(new Operation::Decrement(), children[0]);
break;
case Operation::MULTIPLY:
first = getConstantValue(children[0]);
second = getConstantValue(children[1]);
if (first == 0.0 || second == 0.0)
return ExpressionTreeNode(new Operation::Constant(0.0));
if (first == 1.0)
return children[1];
if (second == 1.0)
return children[0];
break;
case Operation::DIVIDE: case Operation::DIVIDE:
if (children[0].getOperation().getId() == Operation::CONSTANT) { double numerator = getConstantValue(children[0]);
if (dynamic_cast<const Operation::Constant&>(children[0].getOperation()).getValue() == 1.0) if (numerator = 0.0)
return ExpressionTreeNode(new Operation::Reciprocal(), children[1]); return ExpressionTreeNode(new Operation::Constant(0.0));
} if (numerator == 1.0)
return ExpressionTreeNode(new Operation::Reciprocal(), children[1]);
double denominator = getConstantValue(children[1]);
if (denominator == 1.0)
return children[0];
break; break;
case Operation::POWER: case Operation::POWER:
if (children[1].getOperation().getId() == Operation::CONSTANT) { double base = getConstantValue(children[0]);
double exponent = dynamic_cast<const Operation::Constant&>(children[1].getOperation()).getValue(); if (base == 0.0)
if (exponent == 1.0) return ExpressionTreeNode(new Operation::Constant(0.0));
return children[0]; if (base == 1.0)
if (exponent == -1.0) return ExpressionTreeNode(new Operation::Constant(1.0));
return ExpressionTreeNode(new Operation::Reciprocal(), children[0]); double exponent = getConstantValue(children[1]);
if (exponent == 2.0) if (exponent == 1.0)
return ExpressionTreeNode(new Operation::Square(), children[0]); return children[0];
if (exponent == 3.0) if (exponent == -1.0)
return ExpressionTreeNode(new Operation::Cube(), children[0]); return ExpressionTreeNode(new Operation::Reciprocal(), children[0]);
if (exponent == 0.5) if (exponent == 2.0)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]); return ExpressionTreeNode(new Operation::Square(), children[0]);
} if (exponent == 3.0)
return ExpressionTreeNode(new Operation::Cube(), children[0]);
if (exponent == 0.5)
return ExpressionTreeNode(new Operation::Sqrt(), children[0]);
break; break;
} }
return ExpressionTreeNode(node.getOperation().clone(), children); return ExpressionTreeNode(node.getOperation().clone(), children);
} }
ParsedExpression ParsedExpression::differentiate(const std::string& variable) const {
return differentiate(getRootNode(), variable);
}
ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const std::string& variable) {
vector<ExpressionTreeNode> childDerivs(node.getChildren().size());
for (int i = 0; i < childDerivs.size(); i++)
childDerivs[i] = differentiate(node.getChildren()[i], variable);
return node.getOperation().differentiate(node.getChildren(),childDerivs, variable);
}
double ParsedExpression::getConstantValue(const ExpressionTreeNode& node) {
if (node.getOperation().getId() == Operation::CONSTANT)
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue();
return numeric_limits<double>::quiet_NaN();
}
ostream& Lepton::operator<<(ostream& out, const ExpressionTreeNode& node) {
out << node.getOperation().getName();
if (node.getChildren().size() > 0) {
out << "(";
for (int i = 0; i < node.getChildren().size(); i++) {
if (i > 0)
out << ", ";
out << node.getChildren()[i];
}
out << ")";
}
return out;
}
ostream& Lepton::operator<<(ostream& out, const ParsedExpression& exp) {
out << exp.getRootNode();
return out;
}
...@@ -239,6 +239,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, int arguments) ...@@ -239,6 +239,8 @@ Operation* Parser::getFunctionOperation(const std::string& name, int arguments)
opMap["square"] = Operation::SQUARE; opMap["square"] = Operation::SQUARE;
opMap["cube"] = Operation::CUBE; opMap["cube"] = Operation::CUBE;
opMap["recip"] = Operation::RECIPROCAL; opMap["recip"] = Operation::RECIPROCAL;
opMap["increment"] = Operation::INCREMENT;
opMap["decrement"] = Operation::DECREMENT;
} }
string trimmed = name.substr(0, name.size()-1); string trimmed = name.substr(0, name.size()-1);
map<string, Operation::Id>::const_iterator iter = opMap.find(trimmed); map<string, Operation::Id>::const_iterator iter = opMap.find(trimmed);
...@@ -275,6 +277,10 @@ Operation* Parser::getFunctionOperation(const std::string& name, int arguments) ...@@ -275,6 +277,10 @@ Operation* Parser::getFunctionOperation(const std::string& name, int arguments)
return new Operation::Cube(); return new Operation::Cube();
case Operation::RECIPROCAL: case Operation::RECIPROCAL:
return new Operation::Reciprocal(); return new Operation::Reciprocal();
case Operation::INCREMENT:
return new Operation::Increment();
case Operation::DECREMENT:
return new Operation::Decrement();
default: default:
throw Exception("Parse error: unknown function"); throw Exception("Parse error: unknown function");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment