Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
ac754c45
Unverified
Commit
ac754c45
authored
Mar 02, 2021
by
Peter Eastman
Committed by
GitHub
Mar 02, 2021
Browse files
Optimizations to Lepton (#3044)
* Optimizations to Lepton * More optimizations to Lepton
parent
a85c2428
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
213 additions
and
33 deletions
+213
-33
libraries/lepton/include/lepton/ExpressionTreeNode.h
libraries/lepton/include/lepton/ExpressionTreeNode.h
+7
-1
libraries/lepton/include/lepton/ParsedExpression.h
libraries/lepton/include/lepton/ParsedExpression.h
+4
-4
libraries/lepton/src/ExpressionTreeNode.cpp
libraries/lepton/src/ExpressionTreeNode.cpp
+47
-1
libraries/lepton/src/Operation.cpp
libraries/lepton/src/Operation.cpp
+100
-10
libraries/lepton/src/ParsedExpression.cpp
libraries/lepton/src/ParsedExpression.cpp
+55
-17
No files found.
libraries/lepton/include/lepton/ExpressionTreeNode.h
View file @
ac754c45
...
...
@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors.
*
* Portions copyright (c) 2009
-2021
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
...
...
@@ -39,6 +39,7 @@
namespace
Lepton
{
class
Operation
;
class
ParsedExpression
;
/**
* This class represents a node in the abstract syntax tree representation of an expression.
...
...
@@ -82,11 +83,13 @@ public:
*/
ExpressionTreeNode
(
Operation
*
operation
);
ExpressionTreeNode
(
const
ExpressionTreeNode
&
node
);
ExpressionTreeNode
(
ExpressionTreeNode
&&
node
);
ExpressionTreeNode
();
~
ExpressionTreeNode
();
bool
operator
==
(
const
ExpressionTreeNode
&
node
)
const
;
bool
operator
!=
(
const
ExpressionTreeNode
&
node
)
const
;
ExpressionTreeNode
&
operator
=
(
const
ExpressionTreeNode
&
node
);
ExpressionTreeNode
&
operator
=
(
ExpressionTreeNode
&&
node
);
/**
* Get the Operation performed by this node.
*/
...
...
@@ -96,8 +99,11 @@ public:
*/
const
std
::
vector
<
ExpressionTreeNode
>&
getChildren
()
const
;
private:
friend
class
ParsedExpression
;
void
assignTags
(
std
::
vector
<
const
ExpressionTreeNode
*>&
examples
)
const
;
Operation
*
operation
;
std
::
vector
<
ExpressionTreeNode
>
children
;
mutable
int
tag
;
};
}
// namespace Lepton
...
...
libraries/lepton/include/lepton/ParsedExpression.h
View file @
ac754c45
...
...
@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009=201
3
Stanford University and the Authors. *
* Portions copyright (c) 2009=20
2
1 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
...
...
@@ -113,9 +113,9 @@ public:
private:
static
double
evaluate
(
const
ExpressionTreeNode
&
node
,
const
std
::
map
<
std
::
string
,
double
>&
variables
);
static
ExpressionTreeNode
preevaluateVariables
(
const
ExpressionTreeNode
&
node
,
const
std
::
map
<
std
::
string
,
double
>&
variables
);
static
ExpressionTreeNode
precalculateConstantSubexpressions
(
const
ExpressionTreeNode
&
node
);
static
ExpressionTreeNode
substituteSimplerExpression
(
const
ExpressionTreeNode
&
node
);
static
ExpressionTreeNode
differentiate
(
const
ExpressionTreeNode
&
node
,
const
std
::
string
&
variable
);
static
ExpressionTreeNode
precalculateConstantSubexpressions
(
const
ExpressionTreeNode
&
node
,
std
::
map
<
int
,
ExpressionTreeNode
>&
nodeCache
);
static
ExpressionTreeNode
substituteSimplerExpression
(
const
ExpressionTreeNode
&
node
,
std
::
map
<
int
,
ExpressionTreeNode
>&
nodeCache
);
static
ExpressionTreeNode
differentiate
(
const
ExpressionTreeNode
&
node
,
const
std
::
string
&
variable
,
std
::
map
<
int
,
ExpressionTreeNode
>&
nodeCache
);
static
bool
isConstant
(
const
ExpressionTreeNode
&
node
);
static
double
getConstantValue
(
const
ExpressionTreeNode
&
node
);
static
ExpressionTreeNode
renameNodeVariables
(
const
ExpressionTreeNode
&
node
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
replacements
);
...
...
libraries/lepton/src/ExpressionTreeNode.cpp
View file @
ac754c45
...
...
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-201
5
Stanford University and the Authors. *
* Portions copyright (c) 2009-20
2
1 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
...
...
@@ -32,6 +32,7 @@
#include "lepton/ExpressionTreeNode.h"
#include "lepton/Exception.h"
#include "lepton/Operation.h"
#include <utility>
using
namespace
Lepton
;
using
namespace
std
;
...
...
@@ -62,6 +63,11 @@ ExpressionTreeNode::ExpressionTreeNode(Operation* operation) : operation(operati
ExpressionTreeNode
::
ExpressionTreeNode
(
const
ExpressionTreeNode
&
node
)
:
operation
(
node
.
operation
==
NULL
?
NULL
:
node
.
operation
->
clone
()),
children
(
node
.
getChildren
())
{
}
ExpressionTreeNode
::
ExpressionTreeNode
(
ExpressionTreeNode
&&
node
)
:
operation
(
node
.
operation
),
children
(
move
(
node
.
children
))
{
node
.
operation
=
NULL
;
node
.
children
.
clear
();
}
ExpressionTreeNode
::
ExpressionTreeNode
()
:
operation
(
NULL
)
{
}
...
...
@@ -98,6 +104,16 @@ ExpressionTreeNode& ExpressionTreeNode::operator=(const ExpressionTreeNode& node
return
*
this
;
}
ExpressionTreeNode
&
ExpressionTreeNode
::
operator
=
(
ExpressionTreeNode
&&
node
)
{
if
(
operation
!=
NULL
)
delete
operation
;
operation
=
node
.
operation
;
children
=
move
(
node
.
children
);
node
.
operation
=
NULL
;
node
.
children
.
clear
();
return
*
this
;
}
const
Operation
&
ExpressionTreeNode
::
getOperation
()
const
{
return
*
operation
;
}
...
...
@@ -105,3 +121,33 @@ const Operation& ExpressionTreeNode::getOperation() const {
const
vector
<
ExpressionTreeNode
>&
ExpressionTreeNode
::
getChildren
()
const
{
return
children
;
}
void
ExpressionTreeNode
::
assignTags
(
vector
<
const
ExpressionTreeNode
*>&
examples
)
const
{
// Assign tag values to all nodes in a tree, such that two nodes have the same
// tag if and only if they (and all their children) are equal. This is used to
// optimize other operations.
int
numTags
=
examples
.
size
();
for
(
const
ExpressionTreeNode
&
child
:
getChildren
())
child
.
assignTags
(
examples
);
if
(
numTags
==
examples
.
size
())
{
// All the children matched existing tags, so possibly this node does too.
for
(
int
i
=
0
;
i
<
examples
.
size
();
i
++
)
{
const
ExpressionTreeNode
&
example
=
*
examples
[
i
];
bool
matches
=
(
getChildren
().
size
()
==
example
.
getChildren
().
size
()
&&
getOperation
()
==
example
.
getOperation
());
for
(
int
j
=
0
;
matches
&&
j
<
getChildren
().
size
();
j
++
)
if
(
getChildren
()[
j
].
tag
!=
example
.
getChildren
()[
j
].
tag
)
matches
=
false
;
if
(
matches
)
{
tag
=
i
;
return
;
}
}
}
// This node does not match any previous node, so assign a new tag.
tag
=
examples
.
size
();
examples
.
push_back
(
this
);
}
libraries/lepton/src/Operation.cpp
View file @
ac754c45
...
...
@@ -37,6 +37,12 @@
using
namespace
Lepton
;
using
namespace
std
;
static
bool
isZero
(
const
ExpressionTreeNode
&
node
)
{
if
(
node
.
getOperation
().
getId
()
!=
Operation
::
CONSTANT
)
return
false
;
return
dynamic_cast
<
const
Operation
::
Constant
&>
(
node
.
getOperation
()).
getValue
()
==
0.0
;
}
double
Operation
::
Erf
::
evaluate
(
double
*
args
,
const
map
<
string
,
double
>&
variables
)
const
{
return
erf
(
args
[
0
]);
}
...
...
@@ -58,35 +64,71 @@ ExpressionTreeNode Operation::Variable::differentiate(const std::vector<Expressi
ExpressionTreeNode
Operation
::
Custom
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
function
->
getNumArguments
()
==
0
)
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
ExpressionTreeNode
result
=
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Custom
(
*
this
,
0
),
children
),
childDerivs
[
0
]);
for
(
int
i
=
1
;
i
<
getNumArguments
();
i
++
)
{
ExpressionTreeNode
result
;
bool
foundTerm
=
false
;
for
(
int
i
=
0
;
i
<
getNumArguments
();
i
++
)
{
if
(
!
isZero
(
childDerivs
[
i
]))
{
if
(
foundTerm
)
result
=
ExpressionTreeNode
(
new
Operation
::
Add
(),
result
,
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Custom
(
*
this
,
i
),
children
),
childDerivs
[
i
]));
else
{
result
=
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Custom
(
*
this
,
i
),
children
),
childDerivs
[
i
]);
foundTerm
=
true
;
}
}
}
if
(
foundTerm
)
return
result
;
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
}
ExpressionTreeNode
Operation
::
Add
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
childDerivs
[
1
];
if
(
isZero
(
childDerivs
[
1
]))
return
childDerivs
[
0
];
return
ExpressionTreeNode
(
new
Operation
::
Add
(),
childDerivs
[
0
],
childDerivs
[
1
]);
}
ExpressionTreeNode
Operation
::
Subtract
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
{
if
(
isZero
(
childDerivs
[
1
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Negate
(),
childDerivs
[
1
]);
}
if
(
isZero
(
childDerivs
[
1
]))
return
childDerivs
[
0
];
return
ExpressionTreeNode
(
new
Operation
::
Subtract
(),
childDerivs
[
0
],
childDerivs
[
1
]);
}
ExpressionTreeNode
Operation
::
Multiply
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
{
if
(
isZero
(
childDerivs
[
1
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
children
[
0
],
childDerivs
[
1
]);
}
if
(
isZero
(
childDerivs
[
1
]))
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
children
[
1
],
childDerivs
[
0
]);
return
ExpressionTreeNode
(
new
Operation
::
Add
(),
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
children
[
0
],
childDerivs
[
1
]),
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
children
[
1
],
childDerivs
[
0
]));
}
ExpressionTreeNode
Operation
::
Divide
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
return
ExpressionTreeNode
(
new
Operation
::
Divide
(),
ExpressionTreeNode
(
new
Operation
::
Subtract
(),
ExpressionTreeNode
subexp
;
if
(
isZero
(
childDerivs
[
0
]))
{
if
(
isZero
(
childDerivs
[
1
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
subexp
=
ExpressionTreeNode
(
new
Operation
::
Negate
(),
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
children
[
0
],
childDerivs
[
1
]));
}
else
if
(
isZero
(
childDerivs
[
1
]))
subexp
=
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
children
[
1
],
childDerivs
[
0
]);
else
subexp
=
ExpressionTreeNode
(
new
Operation
::
Subtract
(),
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
children
[
1
],
childDerivs
[
0
]),
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
children
[
0
],
childDerivs
[
1
]))
,
ExpressionTreeNode
(
new
Operation
::
Square
(),
children
[
1
]));
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
children
[
0
],
childDerivs
[
1
]))
;
return
ExpressionTreeNode
(
new
Operation
::
Divide
(),
subexp
,
ExpressionTreeNode
(
new
Operation
::
Square
(),
children
[
1
]));
}
ExpressionTreeNode
Operation
::
Power
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
...
...
@@ -105,10 +147,14 @@ ExpressionTreeNode Operation::Power::differentiate(const std::vector<ExpressionT
}
ExpressionTreeNode
Operation
::
Negate
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Negate
(),
childDerivs
[
0
]);
}
ExpressionTreeNode
Operation
::
Sqrt
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
MultiplyConstant
(
0.5
),
ExpressionTreeNode
(
new
Operation
::
Reciprocal
(),
...
...
@@ -117,24 +163,32 @@ ExpressionTreeNode Operation::Sqrt::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode
Operation
::
Exp
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Exp
(),
children
[
0
]),
childDerivs
[
0
]);
}
ExpressionTreeNode
Operation
::
Log
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Reciprocal
(),
children
[
0
]),
childDerivs
[
0
]);
}
ExpressionTreeNode
Operation
::
Sin
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Cos
(),
children
[
0
]),
childDerivs
[
0
]);
}
ExpressionTreeNode
Operation
::
Cos
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Negate
(),
ExpressionTreeNode
(
new
Operation
::
Sin
(),
children
[
0
])),
...
...
@@ -142,6 +196,8 @@ ExpressionTreeNode Operation::Cos::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode
Operation
::
Sec
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Sec
(),
children
[
0
]),
...
...
@@ -150,6 +206,8 @@ ExpressionTreeNode Operation::Sec::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode
Operation
::
Csc
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Negate
(),
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
...
...
@@ -159,6 +217,8 @@ ExpressionTreeNode Operation::Csc::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode
Operation
::
Tan
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Square
(),
ExpressionTreeNode
(
new
Operation
::
Sec
(),
children
[
0
])),
...
...
@@ -166,6 +226,8 @@ ExpressionTreeNode Operation::Tan::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode
Operation
::
Cot
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Negate
(),
ExpressionTreeNode
(
new
Operation
::
Square
(),
...
...
@@ -174,6 +236,8 @@ ExpressionTreeNode Operation::Cot::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode
Operation
::
Asin
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Reciprocal
(),
ExpressionTreeNode
(
new
Operation
::
Sqrt
(),
...
...
@@ -184,6 +248,8 @@ ExpressionTreeNode Operation::Asin::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode
Operation
::
Acos
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Negate
(),
ExpressionTreeNode
(
new
Operation
::
Reciprocal
(),
...
...
@@ -195,6 +261,8 @@ ExpressionTreeNode Operation::Acos::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode
Operation
::
Atan
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Reciprocal
(),
ExpressionTreeNode
(
new
Operation
::
AddConstant
(
1.0
),
...
...
@@ -213,6 +281,8 @@ ExpressionTreeNode Operation::Atan2::differentiate(const std::vector<ExpressionT
}
ExpressionTreeNode
Operation
::
Sinh
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Cosh
(),
children
[
0
]),
...
...
@@ -220,6 +290,8 @@ ExpressionTreeNode Operation::Sinh::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode
Operation
::
Cosh
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Sinh
(),
children
[
0
]),
...
...
@@ -227,6 +299,8 @@ ExpressionTreeNode Operation::Cosh::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode
Operation
::
Tanh
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Subtract
(),
ExpressionTreeNode
(
new
Operation
::
Constant
(
1.0
)),
...
...
@@ -236,6 +310,8 @@ ExpressionTreeNode Operation::Tanh::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode
Operation
::
Erf
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Constant
(
2.0
/
sqrt
(
M_PI
))),
...
...
@@ -246,6 +322,8 @@ ExpressionTreeNode Operation::Erf::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode
Operation
::
Erfc
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Constant
(
-
2.0
/
sqrt
(
M_PI
))),
...
...
@@ -264,6 +342,8 @@ ExpressionTreeNode Operation::Delta::differentiate(const std::vector<ExpressionT
}
ExpressionTreeNode
Operation
::
Square
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
MultiplyConstant
(
2.0
),
children
[
0
]),
...
...
@@ -271,6 +351,8 @@ ExpressionTreeNode Operation::Square::differentiate(const std::vector<Expression
}
ExpressionTreeNode
Operation
::
Cube
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
MultiplyConstant
(
3.0
),
ExpressionTreeNode
(
new
Operation
::
Square
(),
children
[
0
])),
...
...
@@ -278,6 +360,8 @@ ExpressionTreeNode Operation::Cube::differentiate(const std::vector<ExpressionTr
}
ExpressionTreeNode
Operation
::
Reciprocal
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
Negate
(),
ExpressionTreeNode
(
new
Operation
::
Reciprocal
(),
...
...
@@ -290,11 +374,15 @@ ExpressionTreeNode Operation::AddConstant::differentiate(const std::vector<Expre
}
ExpressionTreeNode
Operation
::
MultiplyConstant
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
MultiplyConstant
(
value
),
childDerivs
[
0
]);
}
ExpressionTreeNode
Operation
::
PowerConstant
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
ExpressionTreeNode
(
new
Operation
::
MultiplyConstant
(
value
),
ExpressionTreeNode
(
new
Operation
::
PowerConstant
(
value
-
1
),
...
...
@@ -321,6 +409,8 @@ ExpressionTreeNode Operation::Max::differentiate(const std::vector<ExpressionTre
}
ExpressionTreeNode
Operation
::
Abs
::
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
{
if
(
isZero
(
childDerivs
[
0
]))
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
0.0
));
ExpressionTreeNode
step
(
new
Operation
::
Step
(),
children
[
0
]);
return
ExpressionTreeNode
(
new
Operation
::
Multiply
(),
childDerivs
[
0
],
...
...
libraries/lepton/src/ParsedExpression.cpp
View file @
ac754c45
...
...
@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009 Stanford University and the Authors.
*
* Portions copyright (c) 2009
-2021
Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
...
...
@@ -68,9 +68,16 @@ double ParsedExpression::evaluate(const ExpressionTreeNode& node, const map<stri
}
ParsedExpression
ParsedExpression
::
optimize
()
const
{
ExpressionTreeNode
result
=
precalculateConstantSubexpressions
(
getRootNode
());
ExpressionTreeNode
result
=
getRootNode
();
vector
<
const
ExpressionTreeNode
*>
examples
;
result
.
assignTags
(
examples
);
map
<
int
,
ExpressionTreeNode
>
nodeCache
;
result
=
precalculateConstantSubexpressions
(
result
,
nodeCache
);
while
(
true
)
{
ExpressionTreeNode
simplified
=
substituteSimplerExpression
(
result
);
examples
.
clear
();
result
.
assignTags
(
examples
);
nodeCache
.
clear
();
ExpressionTreeNode
simplified
=
substituteSimplerExpression
(
result
,
nodeCache
);
if
(
simplified
==
result
)
break
;
result
=
simplified
;
...
...
@@ -80,9 +87,15 @@ ParsedExpression ParsedExpression::optimize() const {
ParsedExpression
ParsedExpression
::
optimize
(
const
map
<
string
,
double
>&
variables
)
const
{
ExpressionTreeNode
result
=
preevaluateVariables
(
getRootNode
(),
variables
);
result
=
precalculateConstantSubexpressions
(
result
);
vector
<
const
ExpressionTreeNode
*>
examples
;
result
.
assignTags
(
examples
);
map
<
int
,
ExpressionTreeNode
>
nodeCache
;
result
=
precalculateConstantSubexpressions
(
result
,
nodeCache
);
while
(
true
)
{
ExpressionTreeNode
simplified
=
substituteSimplerExpression
(
result
);
examples
.
clear
();
result
.
assignTags
(
examples
);
nodeCache
.
clear
();
ExpressionTreeNode
simplified
=
substituteSimplerExpression
(
result
,
nodeCache
);
if
(
simplified
==
result
)
break
;
result
=
simplified
;
...
...
@@ -104,23 +117,40 @@ ExpressionTreeNode ParsedExpression::preevaluateVariables(const ExpressionTreeNo
return
ExpressionTreeNode
(
node
.
getOperation
().
clone
(),
children
);
}
ExpressionTreeNode
ParsedExpression
::
precalculateConstantSubexpressions
(
const
ExpressionTreeNode
&
node
)
{
ExpressionTreeNode
ParsedExpression
::
precalculateConstantSubexpressions
(
const
ExpressionTreeNode
&
node
,
map
<
int
,
ExpressionTreeNode
>&
nodeCache
)
{
auto
cached
=
nodeCache
.
find
(
node
.
tag
);
if
(
cached
!=
nodeCache
.
end
())
return
cached
->
second
;
vector
<
ExpressionTreeNode
>
children
(
node
.
getChildren
().
size
());
for
(
int
i
=
0
;
i
<
(
int
)
children
.
size
();
i
++
)
children
[
i
]
=
precalculateConstantSubexpressions
(
node
.
getChildren
()[
i
]);
children
[
i
]
=
precalculateConstantSubexpressions
(
node
.
getChildren
()[
i
]
,
nodeCache
);
ExpressionTreeNode
result
=
ExpressionTreeNode
(
node
.
getOperation
().
clone
(),
children
);
if
(
node
.
getOperation
().
getId
()
==
Operation
::
VARIABLE
||
node
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
)
if
(
node
.
getOperation
().
getId
()
==
Operation
::
VARIABLE
||
node
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
)
{
nodeCache
[
node
.
tag
]
=
result
;
return
result
;
}
for
(
int
i
=
0
;
i
<
(
int
)
children
.
size
();
i
++
)
if
(
children
[
i
].
getOperation
().
getId
()
!=
Operation
::
CONSTANT
)
if
(
children
[
i
].
getOperation
().
getId
()
!=
Operation
::
CONSTANT
)
{
nodeCache
[
node
.
tag
]
=
result
;
return
result
;
}
result
=
ExpressionTreeNode
(
new
Operation
::
Constant
(
evaluate
(
result
,
map
<
string
,
double
>
())));
nodeCache
[
node
.
tag
]
=
result
;
return
result
;
return
ExpressionTreeNode
(
new
Operation
::
Constant
(
evaluate
(
result
,
map
<
string
,
double
>
())));
}
ExpressionTreeNode
ParsedExpression
::
substituteSimplerExpression
(
const
ExpressionTreeNode
&
node
)
{
ExpressionTreeNode
ParsedExpression
::
substituteSimplerExpression
(
const
ExpressionTreeNode
&
node
,
map
<
int
,
ExpressionTreeNode
>&
nodeCache
)
{
vector
<
ExpressionTreeNode
>
children
(
node
.
getChildren
().
size
());
for
(
int
i
=
0
;
i
<
(
int
)
children
.
size
();
i
++
)
children
[
i
]
=
substituteSimplerExpression
(
node
.
getChildren
()[
i
]);
for
(
int
i
=
0
;
i
<
(
int
)
children
.
size
();
i
++
)
{
const
ExpressionTreeNode
&
child
=
node
.
getChildren
()[
i
];
auto
cached
=
nodeCache
.
find
(
child
.
tag
);
if
(
cached
==
nodeCache
.
end
())
{
children
[
i
]
=
substituteSimplerExpression
(
child
,
nodeCache
);
nodeCache
[
child
.
tag
]
=
children
[
i
];
}
else
children
[
i
]
=
cached
->
second
;
}
// Collect some info on constant expressions in children
bool
first_const
=
children
.
size
()
>
0
&&
isConstant
(
children
[
0
]);
// is first child constant?
...
...
@@ -306,14 +336,22 @@ ExpressionTreeNode ParsedExpression::substituteSimplerExpression(const Expressio
}
ParsedExpression
ParsedExpression
::
differentiate
(
const
string
&
variable
)
const
{
return
differentiate
(
getRootNode
(),
variable
);
vector
<
const
ExpressionTreeNode
*>
examples
;
getRootNode
().
assignTags
(
examples
);
map
<
int
,
ExpressionTreeNode
>
nodeCache
;
return
differentiate
(
getRootNode
(),
variable
,
nodeCache
);
}
ExpressionTreeNode
ParsedExpression
::
differentiate
(
const
ExpressionTreeNode
&
node
,
const
string
&
variable
)
{
ExpressionTreeNode
ParsedExpression
::
differentiate
(
const
ExpressionTreeNode
&
node
,
const
string
&
variable
,
map
<
int
,
ExpressionTreeNode
>&
nodeCache
)
{
auto
cached
=
nodeCache
.
find
(
node
.
tag
);
if
(
cached
!=
nodeCache
.
end
())
return
cached
->
second
;
vector
<
ExpressionTreeNode
>
childDerivs
(
node
.
getChildren
().
size
());
for
(
int
i
=
0
;
i
<
(
int
)
childDerivs
.
size
();
i
++
)
childDerivs
[
i
]
=
differentiate
(
node
.
getChildren
()[
i
],
variable
);
return
node
.
getOperation
().
differentiate
(
node
.
getChildren
(),
childDerivs
,
variable
);
childDerivs
[
i
]
=
differentiate
(
node
.
getChildren
()[
i
],
variable
,
nodeCache
);
ExpressionTreeNode
result
=
node
.
getOperation
().
differentiate
(
node
.
getChildren
(),
childDerivs
,
variable
);
nodeCache
[
node
.
tag
]
=
result
;
return
result
;
}
bool
ParsedExpression
::
isConstant
(
const
ExpressionTreeNode
&
node
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment