"deployment/vscode:/vscode.git/clone" did not exist on "53ed465a56c96878f4ed8f7264391a30a23eab1e"
Unverified Commit ac754c45 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Optimizations to Lepton (#3044)

* Optimizations to Lepton

* More optimizations to Lepton
parent a85c2428
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009 Stanford University and the Authors. * * Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
namespace Lepton { namespace Lepton {
class Operation; class Operation;
class ParsedExpression;
/** /**
* This class represents a node in the abstract syntax tree representation of an expression. * This class represents a node in the abstract syntax tree representation of an expression.
...@@ -82,11 +83,13 @@ public: ...@@ -82,11 +83,13 @@ public:
*/ */
ExpressionTreeNode(Operation* operation); ExpressionTreeNode(Operation* operation);
ExpressionTreeNode(const ExpressionTreeNode& node); ExpressionTreeNode(const ExpressionTreeNode& node);
ExpressionTreeNode(ExpressionTreeNode&& node);
ExpressionTreeNode(); ExpressionTreeNode();
~ExpressionTreeNode(); ~ExpressionTreeNode();
bool operator==(const ExpressionTreeNode& node) const; bool operator==(const ExpressionTreeNode& node) const;
bool operator!=(const ExpressionTreeNode& node) const; bool operator!=(const ExpressionTreeNode& node) const;
ExpressionTreeNode& operator=(const ExpressionTreeNode& node); ExpressionTreeNode& operator=(const ExpressionTreeNode& node);
ExpressionTreeNode& operator=(ExpressionTreeNode&& node);
/** /**
* Get the Operation performed by this node. * Get the Operation performed by this node.
*/ */
...@@ -96,8 +99,11 @@ public: ...@@ -96,8 +99,11 @@ public:
*/ */
const std::vector<ExpressionTreeNode>& getChildren() const; const std::vector<ExpressionTreeNode>& getChildren() const;
private: private:
friend class ParsedExpression;
void assignTags(std::vector<const ExpressionTreeNode*>& examples) const;
Operation* operation; Operation* operation;
std::vector<ExpressionTreeNode> children; std::vector<ExpressionTreeNode> children;
mutable int tag;
}; };
} // namespace Lepton } // namespace Lepton
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009=2013 Stanford University and the Authors. * * Portions copyright (c) 2009=2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -113,9 +113,9 @@ public: ...@@ -113,9 +113,9 @@ public:
private: private:
static double evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables); static double evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map<std::string, double>& variables); static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node); static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node, std::map<int, ExpressionTreeNode>& nodeCache);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node); static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node, std::map<int, ExpressionTreeNode>& nodeCache);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable); static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable, std::map<int, ExpressionTreeNode>& nodeCache);
static bool isConstant(const ExpressionTreeNode& node); static bool isConstant(const ExpressionTreeNode& node);
static double getConstantValue(const ExpressionTreeNode& node); static double getConstantValue(const ExpressionTreeNode& node);
static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements); static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. * * Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "lepton/ExpressionTreeNode.h" #include "lepton/ExpressionTreeNode.h"
#include "lepton/Exception.h" #include "lepton/Exception.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
#include <utility>
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
...@@ -62,6 +63,11 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati ...@@ -62,6 +63,11 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati
ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(node.operation == NULL ? NULL : node.operation->clone()), children(node.getChildren()) { ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(node.operation == NULL ? NULL : node.operation->clone()), children(node.getChildren()) {
} }
ExpressionTreeNode::ExpressionTreeNode(ExpressionTreeNode&& node) : operation(node.operation), children(move(node.children)) {
node.operation = NULL;
node.children.clear();
}
ExpressionTreeNode::ExpressionTreeNode() : operation(NULL) { ExpressionTreeNode::ExpressionTreeNode() : operation(NULL) {
} }
...@@ -98,6 +104,16 @@ ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node ...@@ -98,6 +104,16 @@ ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node
return *this; return *this;
} }
ExpressionTreeNode& ExpressionTreeNode::operator=(ExpressionTreeNode&& node) {
if (operation != NULL)
delete operation;
operation = node.operation;
children = move(node.children);
node.operation = NULL;
node.children.clear();
return *this;
}
const Operation& ExpressionTreeNode::getOperation() const { const Operation& ExpressionTreeNode::getOperation() const {
return *operation; return *operation;
} }
...@@ -105,3 +121,33 @@ const Operation& ExpressionTreeNode::getOperation() const { ...@@ -105,3 +121,33 @@ const Operation& ExpressionTreeNode::getOperation() const {
const vector<ExpressionTreeNode>& ExpressionTreeNode::getChildren() const { const vector<ExpressionTreeNode>& ExpressionTreeNode::getChildren() const {
return children; return children;
} }
void ExpressionTreeNode::assignTags(vector<const ExpressionTreeNode*>& examples) const {
// Assign tag values to all nodes in a tree, such that two nodes have the same
// tag if and only if they (and all their children) are equal. This is used to
// optimize other operations.
int numTags = examples.size();
for (const ExpressionTreeNode& child : getChildren())
child.assignTags(examples);
if (numTags == examples.size()) {
// All the children matched existing tags, so possibly this node does too.
for (int i = 0; i < examples.size(); i++) {
const ExpressionTreeNode& example = *examples[i];
bool matches = (getChildren().size() == example.getChildren().size() && getOperation() == example.getOperation());
for (int j = 0; matches && j < getChildren().size(); j++)
if (getChildren()[j].tag != example.getChildren()[j].tag)
matches = false;
if (matches) {
tag = i;
return;
}
}
}
// This node does not match any previous node, so assign a new tag.
tag = examples.size();
examples.push_back(this);
}
...@@ -37,6 +37,12 @@ ...@@ -37,6 +37,12 @@
using namespace Lepton; using namespace Lepton;
using namespace std; using namespace std;
static bool isZero(const ExpressionTreeNode& node) {
if (node.getOperation().getId() != Operation::CONSTANT)
return false;
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue() == 0.0;
}
double Operation::Erf::evaluate(double* args, const map<string, double>& variables) const { double Operation::Erf::evaluate(double* args, const map<string, double>& variables) const {
return erf(args[0]); return erf(args[0]);
} }
...@@ -58,35 +64,71 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector<Expressi ...@@ -58,35 +64,71 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector<Expressi
ExpressionTreeNode Operation::Custom::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Custom::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (function->getNumArguments() == 0) if (function->getNumArguments() == 0)
return ExpressionTreeNode(new Operation::Constant(0.0)); return ExpressionTreeNode(new Operation::Constant(0.0));
ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]); ExpressionTreeNode result;
for (int i = 1; i < getNumArguments(); i++) { bool foundTerm = false;
for (int i = 0; i < getNumArguments(); i++) {
if (!isZero(childDerivs[i])) {
if (foundTerm)
result = ExpressionTreeNode(new Operation::Add(), result = ExpressionTreeNode(new Operation::Add(),
result, result,
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i])); ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]));
else {
result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]);
foundTerm = true;
}
} }
}
if (foundTerm)
return result; return result;
return ExpressionTreeNode(new Operation::Constant(0.0));
} }
ExpressionTreeNode Operation::Add::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Add::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return childDerivs[1];
if (isZero(childDerivs[1]))
return childDerivs[0];
return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]); return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]);
} }
ExpressionTreeNode Operation::Subtract::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Subtract::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0])) {
if (isZero(childDerivs[1]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Negate(), childDerivs[1]);
}
if (isZero(childDerivs[1]))
return childDerivs[0];
return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]); return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]);
} }
ExpressionTreeNode Operation::Multiply::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Multiply::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0])) {
if (isZero(childDerivs[1]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]);
}
if (isZero(childDerivs[1]))
return ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]);
return ExpressionTreeNode(new Operation::Add(), return ExpressionTreeNode(new Operation::Add(),
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]),
ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0])); ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]));
} }
ExpressionTreeNode Operation::Divide::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Divide::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Divide(), ExpressionTreeNode subexp;
ExpressionTreeNode(new Operation::Subtract(), if (isZero(childDerivs[0])) {
if (isZero(childDerivs[1]))
return ExpressionTreeNode(new Operation::Constant(0.0));
subexp = ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]));
}
else if (isZero(childDerivs[1]))
subexp = ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]);
else
subexp = ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]), ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]),
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]));
ExpressionTreeNode(new Operation::Square(), children[1])); return ExpressionTreeNode(new Operation::Divide(), subexp, ExpressionTreeNode(new Operation::Square(), children[1]));
} }
ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
...@@ -105,10 +147,14 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionT ...@@ -105,10 +147,14 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionT
} }
ExpressionTreeNode Operation::Negate::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Negate::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]); return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]);
} }
ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(0.5), ExpressionTreeNode(new Operation::MultiplyConstant(0.5),
ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Reciprocal(),
...@@ -117,24 +163,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTr ...@@ -117,24 +163,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTr
} }
ExpressionTreeNode Operation::Exp::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Exp::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Exp(), children[0]), ExpressionTreeNode(new Operation::Exp(), children[0]),
childDerivs[0]); childDerivs[0]);
} }
ExpressionTreeNode Operation::Log::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Log::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(), children[0]), ExpressionTreeNode(new Operation::Reciprocal(), children[0]),
childDerivs[0]); childDerivs[0]);
} }
ExpressionTreeNode Operation::Sin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Sin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Cos(), children[0]), ExpressionTreeNode(new Operation::Cos(), children[0]),
childDerivs[0]); childDerivs[0]);
} }
ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Sin(), children[0])), ExpressionTreeNode(new Operation::Sin(), children[0])),
...@@ -142,6 +196,8 @@ ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTre ...@@ -142,6 +196,8 @@ ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTre
} }
ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Sec(), children[0]), ExpressionTreeNode(new Operation::Sec(), children[0]),
...@@ -150,6 +206,8 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTre ...@@ -150,6 +206,8 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTre
} }
ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(),
...@@ -159,6 +217,8 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTre ...@@ -159,6 +217,8 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTre
} }
ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Square(), ExpressionTreeNode(new Operation::Square(),
ExpressionTreeNode(new Operation::Sec(), children[0])), ExpressionTreeNode(new Operation::Sec(), children[0])),
...@@ -166,6 +226,8 @@ ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTre ...@@ -166,6 +226,8 @@ ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTre
} }
ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Square(), ExpressionTreeNode(new Operation::Square(),
...@@ -174,6 +236,8 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTre ...@@ -174,6 +236,8 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTre
} }
ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::Sqrt(), ExpressionTreeNode(new Operation::Sqrt(),
...@@ -184,6 +248,8 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTr ...@@ -184,6 +248,8 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTr
} }
ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Reciprocal(),
...@@ -195,6 +261,8 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTr ...@@ -195,6 +261,8 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTr
} }
ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::AddConstant(1.0), ExpressionTreeNode(new Operation::AddConstant(1.0),
...@@ -213,6 +281,8 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector<ExpressionT ...@@ -213,6 +281,8 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector<ExpressionT
} }
ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Cosh(), ExpressionTreeNode(new Operation::Cosh(),
children[0]), children[0]),
...@@ -220,6 +290,8 @@ ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTr ...@@ -220,6 +290,8 @@ ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTr
} }
ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Sinh(), ExpressionTreeNode(new Operation::Sinh(),
children[0]), children[0]),
...@@ -227,6 +299,8 @@ ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTr ...@@ -227,6 +299,8 @@ ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTr
} }
ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Subtract(), ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Constant(1.0)), ExpressionTreeNode(new Operation::Constant(1.0)),
...@@ -236,6 +310,8 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr ...@@ -236,6 +310,8 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr
} }
ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))), ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))),
...@@ -246,6 +322,8 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTre ...@@ -246,6 +322,8 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTre
} }
ExpressionTreeNode Operation::Erfc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Erfc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))), ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))),
...@@ -264,6 +342,8 @@ ExpressionTreeNode Operation::Delta::differentiate(const std::vector<ExpressionT ...@@ -264,6 +342,8 @@ ExpressionTreeNode Operation::Delta::differentiate(const std::vector<ExpressionT
} }
ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(2.0), ExpressionTreeNode(new Operation::MultiplyConstant(2.0),
children[0]), children[0]),
...@@ -271,6 +351,8 @@ ExpressionTreeNode Operation::Square::differentiate(const std::vector<Expression ...@@ -271,6 +351,8 @@ ExpressionTreeNode Operation::Square::differentiate(const std::vector<Expression
} }
ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(3.0), ExpressionTreeNode(new Operation::MultiplyConstant(3.0),
ExpressionTreeNode(new Operation::Square(), children[0])), ExpressionTreeNode(new Operation::Square(), children[0])),
...@@ -278,6 +360,8 @@ ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTr ...@@ -278,6 +360,8 @@ ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTr
} }
ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Reciprocal(), ExpressionTreeNode(new Operation::Reciprocal(),
...@@ -290,11 +374,15 @@ ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector<Expre ...@@ -290,11 +374,15 @@ ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector<Expre
} }
ExpressionTreeNode Operation::MultiplyConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::MultiplyConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::MultiplyConstant(value), return ExpressionTreeNode(new Operation::MultiplyConstant(value),
childDerivs[0]); childDerivs[0]);
} }
ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(value), ExpressionTreeNode(new Operation::MultiplyConstant(value),
ExpressionTreeNode(new Operation::PowerConstant(value-1), ExpressionTreeNode(new Operation::PowerConstant(value-1),
...@@ -321,6 +409,8 @@ ExpressionTreeNode Operation::Max::differentiate(const std::vector<ExpressionTre ...@@ -321,6 +409,8 @@ ExpressionTreeNode Operation::Max::differentiate(const std::vector<ExpressionTre
} }
ExpressionTreeNode Operation::Abs::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const { ExpressionTreeNode Operation::Abs::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
ExpressionTreeNode step(new Operation::Step(), children[0]); ExpressionTreeNode step(new Operation::Step(), children[0]);
return ExpressionTreeNode(new Operation::Multiply(), return ExpressionTreeNode(new Operation::Multiply(),
childDerivs[0], childDerivs[0],
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009 Stanford University and the Authors. * * Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -68,9 +68,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<stri ...@@ -68,9 +68,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<stri
} }
ParsedExpression ParsedExpression::optimize() const { ParsedExpression ParsedExpression::optimize() const {
ExpressionTreeNode result = precalculateConstantSubexpressions(getRootNode()); ExpressionTreeNode result = getRootNode();
vector<const ExpressionTreeNode*> examples;
result.assignTags(examples);
map<int, ExpressionTreeNode> nodeCache;
result = precalculateConstantSubexpressions(result, nodeCache);
while (true) { while (true) {
ExpressionTreeNode simplified = substituteSimplerExpression(result); examples.clear();
result.assignTags(examples);
nodeCache.clear();
ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache);
if (simplified == result) if (simplified == result)
break; break;
result = simplified; result = simplified;
...@@ -80,9 +87,15 @@ ParsedExpression ParsedExpression::optimize() const { ...@@ -80,9 +87,15 @@ ParsedExpression ParsedExpression::optimize() const {
ParsedExpression ParsedExpression::optimize(const map<string, double>& variables) const { ParsedExpression ParsedExpression::optimize(const map<string, double>& variables) const {
ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables); ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables);
result = precalculateConstantSubexpressions(result); vector<const ExpressionTreeNode*> examples;
result.assignTags(examples);
map<int, ExpressionTreeNode> nodeCache;
result = precalculateConstantSubexpressions(result, nodeCache);
while (true) { while (true) {
ExpressionTreeNode simplified = substituteSimplerExpression(result); examples.clear();
result.assignTags(examples);
nodeCache.clear();
ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache);
if (simplified == result) if (simplified == result)
break; break;
result = simplified; result = simplified;
...@@ -104,23 +117,40 @@ ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNo ...@@ -104,23 +117,40 @@ ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNo
return ExpressionTreeNode(node.getOperation().clone(), children); return ExpressionTreeNode(node.getOperation().clone(), children);
} }
ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) { ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node, map<int, ExpressionTreeNode>& nodeCache) {
auto cached = nodeCache.find(node.tag);
if (cached != nodeCache.end())
return cached->second;
vector<ExpressionTreeNode> children(node.getChildren().size()); vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < (int) children.size(); i++) for (int i = 0; i < (int) children.size(); i++)
children[i] = precalculateConstantSubexpressions(node.getChildren()[i]); children[i] = precalculateConstantSubexpressions(node.getChildren()[i], nodeCache);
ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children); ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children);
if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) {
nodeCache[node.tag] = result;
return result; return result;
}
for (int i = 0; i < (int) children.size(); i++) for (int i = 0; i < (int) children.size(); i++)
if (children[i].getOperation().getId() != Operation::CONSTANT) if (children[i].getOperation().getId() != Operation::CONSTANT) {
nodeCache[node.tag] = result;
return result;
}
result = ExpressionTreeNode(new Operation::Constant(evaluate(result, map<string, double>())));
nodeCache[node.tag] = result;
return result; return result;
return ExpressionTreeNode(new Operation::Constant(evaluate(result, map<string, double>())));
} }
ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) { ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node, map<int, ExpressionTreeNode>& nodeCache) {
vector<ExpressionTreeNode> children(node.getChildren().size()); vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < (int) children.size(); i++) for (int i = 0; i < (int) children.size(); i++) {
children[i] = substituteSimplerExpression(node.getChildren()[i]); const ExpressionTreeNode& child = node.getChildren()[i];
auto cached = nodeCache.find(child.tag);
if (cached == nodeCache.end()) {
children[i] = substituteSimplerExpression(child, nodeCache);
nodeCache[child.tag] = children[i];
}
else
children[i] = cached->second;
}
// Collect some info on constant expressions in children // Collect some info on constant expressions in children
bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant? bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant?
...@@ -306,14 +336,22 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio ...@@ -306,14 +336,22 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
} }
ParsedExpression ParsedExpression::differentiate(const string& variable) const { ParsedExpression ParsedExpression::differentiate(const string& variable) const {
return differentiate(getRootNode(), variable); vector<const ExpressionTreeNode*> examples;
getRootNode().assignTags(examples);
map<int, ExpressionTreeNode> nodeCache;
return differentiate(getRootNode(), variable, nodeCache);
} }
ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable) { ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable, map<int, ExpressionTreeNode>& nodeCache) {
auto cached = nodeCache.find(node.tag);
if (cached != nodeCache.end())
return cached->second;
vector<ExpressionTreeNode> childDerivs(node.getChildren().size()); vector<ExpressionTreeNode> childDerivs(node.getChildren().size());
for (int i = 0; i < (int) childDerivs.size(); i++) for (int i = 0; i < (int) childDerivs.size(); i++)
childDerivs[i] = differentiate(node.getChildren()[i], variable); childDerivs[i] = differentiate(node.getChildren()[i], variable, nodeCache);
return node.getOperation().differentiate(node.getChildren(),childDerivs, variable); ExpressionTreeNode result = node.getOperation().differentiate(node.getChildren(), childDerivs, variable);
nodeCache[node.tag] = result;
return result;
} }
bool ParsedExpression::isConstant(const ExpressionTreeNode& node) { bool ParsedExpression::isConstant(const ExpressionTreeNode& node) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment