"platforms/vscode:/vscode.git/clone" did not exist on "172d41e57485467498bbfe5ee8340624cb20114b"
Unverified Commit ac754c45 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Optimizations to Lepton (#3044)

* Optimizations to Lepton

* More optimizations to Lepton
parent a85c2428
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -39,6 +39,7 @@
namespace Lepton {
class Operation;
class ParsedExpression;
/**
* This class represents a node in the abstract syntax tree representation of an expression.
......@@ -82,11 +83,13 @@ public:
*/
ExpressionTreeNode(Operation* operation);
ExpressionTreeNode(const ExpressionTreeNode& node);
ExpressionTreeNode(ExpressionTreeNode&& node);
ExpressionTreeNode();
~ExpressionTreeNode();
bool operator==(const ExpressionTreeNode& node) const;
bool operator!=(const ExpressionTreeNode& node) const;
ExpressionTreeNode& operator=(const ExpressionTreeNode& node);
ExpressionTreeNode& operator=(ExpressionTreeNode&& node);
/**
* Get the Operation performed by this node.
*/
......@@ -96,8 +99,11 @@ public:
*/
const std::vector<ExpressionTreeNode>& getChildren() const;
private:
friend class ParsedExpression;
void assignTags(std::vector<const ExpressionTreeNode*>& examples) const;
Operation* operation;
std::vector<ExpressionTreeNode> children;
mutable int tag;
};
} // namespace Lepton
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009=2013 Stanford University and the Authors. *
* Portions copyright (c) 2009=2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -113,9 +113,9 @@ public:
private:
static double evaluate(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode preevaluateVariables(const ExpressionTreeNode& node, const std::map<std::string, double>& variables);
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable);
static ExpressionTreeNode precalculateConstantSubexpressions(const ExpressionTreeNode& node, std::map<int, ExpressionTreeNode>& nodeCache);
static ExpressionTreeNode substituteSimplerExpression(const ExpressionTreeNode& node, std::map<int, ExpressionTreeNode>& nodeCache);
static ExpressionTreeNode differentiate(const ExpressionTreeNode& node, const std::string& variable, std::map<int, ExpressionTreeNode>& nodeCache);
static bool isConstant(const ExpressionTreeNode& node);
static double getConstantValue(const ExpressionTreeNode& node);
static ExpressionTreeNode renameNodeVariables(const ExpressionTreeNode& node, const std::map<std::string, std::string>& replacements);
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -32,6 +32,7 @@
#include "lepton/ExpressionTreeNode.h"
#include "lepton/Exception.h"
#include "lepton/Operation.h"
#include <utility>
using namespace Lepton;
using namespace std;
......@@ -62,6 +63,11 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati
ExpressionTreeNode::ExpressionTreeNode(const ExpressionTreeNode& node) : operation(node.operation == NULL ? NULL : node.operation->clone()), children(node.getChildren()) {
}
ExpressionTreeNode::ExpressionTreeNode(ExpressionTreeNode&& node) : operation(node.operation), children(move(node.children)) {
node.operation = NULL;
node.children.clear();
}
ExpressionTreeNode::ExpressionTreeNode() : operation(NULL) {
}
......@@ -98,6 +104,16 @@ ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node
return *this;
}
ExpressionTreeNode& ExpressionTreeNode::operator=(ExpressionTreeNode&& node) {
if (operation != NULL)
delete operation;
operation = node.operation;
children = move(node.children);
node.operation = NULL;
node.children.clear();
return *this;
}
const Operation& ExpressionTreeNode::getOperation() const {
return *operation;
}
......@@ -105,3 +121,33 @@ const Operation& ExpressionTreeNode::getOperation() const {
const vector<ExpressionTreeNode>& ExpressionTreeNode::getChildren() const {
return children;
}
void ExpressionTreeNode::assignTags(vector<const ExpressionTreeNode*>& examples) const {
// Assign tag values to all nodes in a tree, such that two nodes have the same
// tag if and only if they (and all their children) are equal. This is used to
// optimize other operations.
int numTags = examples.size();
for (const ExpressionTreeNode& child : getChildren())
child.assignTags(examples);
if (numTags == examples.size()) {
// All the children matched existing tags, so possibly this node does too.
for (int i = 0; i < examples.size(); i++) {
const ExpressionTreeNode& example = *examples[i];
bool matches = (getChildren().size() == example.getChildren().size() && getOperation() == example.getOperation());
for (int j = 0; matches && j < getChildren().size(); j++)
if (getChildren()[j].tag != example.getChildren()[j].tag)
matches = false;
if (matches) {
tag = i;
return;
}
}
}
// This node does not match any previous node, so assign a new tag.
tag = examples.size();
examples.push_back(this);
}
......@@ -37,6 +37,12 @@
using namespace Lepton;
using namespace std;
static bool isZero(const ExpressionTreeNode& node) {
if (node.getOperation().getId() != Operation::CONSTANT)
return false;
return dynamic_cast<const Operation::Constant&>(node.getOperation()).getValue() == 0.0;
}
double Operation::Erf::evaluate(double* args, const map<string, double>& variables) const {
return erf(args[0]);
}
......@@ -58,35 +64,71 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector<Expressi
ExpressionTreeNode Operation::Custom::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (function->getNumArguments() == 0)
return ExpressionTreeNode(new Operation::Constant(0.0));
ExpressionTreeNode result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, 0), children), childDerivs[0]);
for (int i = 1; i < getNumArguments(); i++) {
ExpressionTreeNode result;
bool foundTerm = false;
for (int i = 0; i < getNumArguments(); i++) {
if (!isZero(childDerivs[i])) {
if (foundTerm)
result = ExpressionTreeNode(new Operation::Add(),
result,
ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]));
else {
result = ExpressionTreeNode(new Operation::Multiply(), ExpressionTreeNode(new Operation::Custom(*this, i), children), childDerivs[i]);
foundTerm = true;
}
}
}
if (foundTerm)
return result;
return ExpressionTreeNode(new Operation::Constant(0.0));
}
ExpressionTreeNode Operation::Add::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return childDerivs[1];
if (isZero(childDerivs[1]))
return childDerivs[0];
return ExpressionTreeNode(new Operation::Add(), childDerivs[0], childDerivs[1]);
}
ExpressionTreeNode Operation::Subtract::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0])) {
if (isZero(childDerivs[1]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Negate(), childDerivs[1]);
}
if (isZero(childDerivs[1]))
return childDerivs[0];
return ExpressionTreeNode(new Operation::Subtract(), childDerivs[0], childDerivs[1]);
}
ExpressionTreeNode Operation::Multiply::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0])) {
if (isZero(childDerivs[1]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]);
}
if (isZero(childDerivs[1]))
return ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]);
return ExpressionTreeNode(new Operation::Add(),
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]),
ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]));
}
ExpressionTreeNode Operation::Divide::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
return ExpressionTreeNode(new Operation::Divide(),
ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode subexp;
if (isZero(childDerivs[0])) {
if (isZero(childDerivs[1]))
return ExpressionTreeNode(new Operation::Constant(0.0));
subexp = ExpressionTreeNode(new Operation::Negate(), ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]));
}
else if (isZero(childDerivs[1]))
subexp = ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]);
else
subexp = ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Multiply(), children[1], childDerivs[0]),
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1])),
ExpressionTreeNode(new Operation::Square(), children[1]));
ExpressionTreeNode(new Operation::Multiply(), children[0], childDerivs[1]));
return ExpressionTreeNode(new Operation::Divide(), subexp, ExpressionTreeNode(new Operation::Square(), children[1]));
}
ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
......@@ -105,10 +147,14 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionT
}
ExpressionTreeNode Operation::Negate::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Negate(), childDerivs[0]);
}
ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(0.5),
ExpressionTreeNode(new Operation::Reciprocal(),
......@@ -117,24 +163,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Exp::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Exp(), children[0]),
childDerivs[0]);
}
ExpressionTreeNode Operation::Log::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(), children[0]),
childDerivs[0]);
}
ExpressionTreeNode Operation::Sin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Cos(), children[0]),
childDerivs[0]);
}
ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Sin(), children[0])),
......@@ -142,6 +196,8 @@ ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Sec(), children[0]),
......@@ -150,6 +206,8 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Multiply(),
......@@ -159,6 +217,8 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Square(),
ExpressionTreeNode(new Operation::Sec(), children[0])),
......@@ -166,6 +226,8 @@ ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Square(),
......@@ -174,6 +236,8 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::Sqrt(),
......@@ -184,6 +248,8 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Reciprocal(),
......@@ -195,6 +261,8 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Atan::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Reciprocal(),
ExpressionTreeNode(new Operation::AddConstant(1.0),
......@@ -213,6 +281,8 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector<ExpressionT
}
ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Cosh(),
children[0]),
......@@ -220,6 +290,8 @@ ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Sinh(),
children[0]),
......@@ -227,6 +299,8 @@ ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Subtract(),
ExpressionTreeNode(new Operation::Constant(1.0)),
......@@ -236,6 +310,8 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(2.0/sqrt(M_PI))),
......@@ -246,6 +322,8 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Erfc::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Constant(-2.0/sqrt(M_PI))),
......@@ -264,6 +342,8 @@ ExpressionTreeNode Operation::Delta::differentiate(const std::vector<ExpressionT
}
ExpressionTreeNode Operation::Square::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(2.0),
children[0]),
......@@ -271,6 +351,8 @@ ExpressionTreeNode Operation::Square::differentiate(const std::vector<Expression
}
ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(3.0),
ExpressionTreeNode(new Operation::Square(), children[0])),
......@@ -278,6 +360,8 @@ ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode Operation::Reciprocal::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::Negate(),
ExpressionTreeNode(new Operation::Reciprocal(),
......@@ -290,11 +374,15 @@ ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector<Expre
}
ExpressionTreeNode Operation::MultiplyConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::MultiplyConstant(value),
childDerivs[0]);
}
ExpressionTreeNode Operation::PowerConstant::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
return ExpressionTreeNode(new Operation::Multiply(),
ExpressionTreeNode(new Operation::MultiplyConstant(value),
ExpressionTreeNode(new Operation::PowerConstant(value-1),
......@@ -321,6 +409,8 @@ ExpressionTreeNode Operation::Max::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode Operation::Abs::differentiate(const std::vector<ExpressionTreeNode>& children, const std::vector<ExpressionTreeNode>& childDerivs, const std::string& variable) const {
if (isZero(childDerivs[0]))
return ExpressionTreeNode(new Operation::Constant(0.0));
ExpressionTreeNode step(new Operation::Step(), children[0]);
return ExpressionTreeNode(new Operation::Multiply(),
childDerivs[0],
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors. *
* Portions copyright (c) 2009-2021 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -68,9 +68,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<stri
}
ParsedExpression ParsedExpression::optimize() const {
ExpressionTreeNode result = precalculateConstantSubexpressions(getRootNode());
ExpressionTreeNode result = getRootNode();
vector<const ExpressionTreeNode*> examples;
result.assignTags(examples);
map<int, ExpressionTreeNode> nodeCache;
result = precalculateConstantSubexpressions(result, nodeCache);
while (true) {
ExpressionTreeNode simplified = substituteSimplerExpression(result);
examples.clear();
result.assignTags(examples);
nodeCache.clear();
ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache);
if (simplified == result)
break;
result = simplified;
......@@ -80,9 +87,15 @@ ParsedExpression ParsedExpression::optimize() const {
ParsedExpression ParsedExpression::optimize(const map<string, double>& variables) const {
ExpressionTreeNode result = preevaluateVariables(getRootNode(), variables);
result = precalculateConstantSubexpressions(result);
vector<const ExpressionTreeNode*> examples;
result.assignTags(examples);
map<int, ExpressionTreeNode> nodeCache;
result = precalculateConstantSubexpressions(result, nodeCache);
while (true) {
ExpressionTreeNode simplified = substituteSimplerExpression(result);
examples.clear();
result.assignTags(examples);
nodeCache.clear();
ExpressionTreeNode simplified = substituteSimplerExpression(result, nodeCache);
if (simplified == result)
break;
result = simplified;
......@@ -104,23 +117,40 @@ ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNo
return ExpressionTreeNode(node.getOperation().clone(), children);
}
ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node) {
ExpressionTreeNode ParsedExpression::precalculateConstantSubexpressions(const ExpressionTreeNode& node, map<int, ExpressionTreeNode>& nodeCache) {
auto cached = nodeCache.find(node.tag);
if (cached != nodeCache.end())
return cached->second;
vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < (int) children.size(); i++)
children[i] = precalculateConstantSubexpressions(node.getChildren()[i]);
children[i] = precalculateConstantSubexpressions(node.getChildren()[i], nodeCache);
ExpressionTreeNode result = ExpressionTreeNode(node.getOperation().clone(), children);
if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM)
if (node.getOperation().getId() == Operation::VARIABLE || node.getOperation().getId() == Operation::CUSTOM) {
nodeCache[node.tag] = result;
return result;
}
for (int i = 0; i < (int) children.size(); i++)
if (children[i].getOperation().getId() != Operation::CONSTANT)
if (children[i].getOperation().getId() != Operation::CONSTANT) {
nodeCache[node.tag] = result;
return result;
}
result = ExpressionTreeNode(new Operation::Constant(evaluate(result, map<string, double>())));
nodeCache[node.tag] = result;
return result;
return ExpressionTreeNode(new Operation::Constant(evaluate(result, map<string, double>())));
}
ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node) {
ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const ExpressionTreeNode& node, map<int, ExpressionTreeNode>& nodeCache) {
vector<ExpressionTreeNode> children(node.getChildren().size());
for (int i = 0; i < (int) children.size(); i++)
children[i] = substituteSimplerExpression(node.getChildren()[i]);
for (int i = 0; i < (int) children.size(); i++) {
const ExpressionTreeNode& child = node.getChildren()[i];
auto cached = nodeCache.find(child.tag);
if (cached == nodeCache.end()) {
children[i] = substituteSimplerExpression(child, nodeCache);
nodeCache[child.tag] = children[i];
}
else
children[i] = cached->second;
}
// Collect some info on constant expressions in children
bool first_const = children.size() > 0 && isConstant(children[0]); // is first child constant?
......@@ -306,14 +336,22 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
}
ParsedExpression ParsedExpression::differentiate(const string& variable) const {
return differentiate(getRootNode(), variable);
vector<const ExpressionTreeNode*> examples;
getRootNode().assignTags(examples);
map<int, ExpressionTreeNode> nodeCache;
return differentiate(getRootNode(), variable, nodeCache);
}
ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable) {
ExpressionTreeNode ParsedExpression::differentiate(const ExpressionTreeNode& node, const string& variable, map<int, ExpressionTreeNode>& nodeCache) {
auto cached = nodeCache.find(node.tag);
if (cached != nodeCache.end())
return cached->second;
vector<ExpressionTreeNode> childDerivs(node.getChildren().size());
for (int i = 0; i < (int) childDerivs.size(); i++)
childDerivs[i] = differentiate(node.getChildren()[i], variable);
return node.getOperation().differentiate(node.getChildren(),childDerivs, variable);
childDerivs[i] = differentiate(node.getChildren()[i], variable, nodeCache);
ExpressionTreeNode result = node.getOperation().differentiate(node.getChildren(), childDerivs, variable);
nodeCache[node.tag] = result;
return result;
}
bool ParsedExpression::isConstant(const ExpressionTreeNode& node) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment