Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
4a25dc79
Commit
4a25dc79
authored
Sep 16, 2014
by
peastman
Browse files
Added more optimized operations to JIT compilation
parent
a6fe70d2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
109 additions
and
21 deletions
+109
-21
libraries/lepton/include/lepton/CompiledExpression.h
libraries/lepton/include/lepton/CompiledExpression.h
+1
-0
libraries/lepton/src/CompiledExpression.cpp
libraries/lepton/src/CompiledExpression.cpp
+75
-12
tests/TestParser.cpp
tests/TestParser.cpp
+33
-9
No files found.
libraries/lepton/include/lepton/CompiledExpression.h
View file @
4a25dc79
...
@@ -80,6 +80,7 @@ private:
...
@@ -80,6 +80,7 @@ private:
CompiledExpression
(
const
ParsedExpression
&
expression
);
CompiledExpression
(
const
ParsedExpression
&
expression
);
void
compileExpression
(
const
ExpressionTreeNode
&
node
,
std
::
vector
<
std
::
pair
<
ExpressionTreeNode
,
int
>
>&
temps
);
void
compileExpression
(
const
ExpressionTreeNode
&
node
,
std
::
vector
<
std
::
pair
<
ExpressionTreeNode
,
int
>
>&
temps
);
void
generateJitCode
();
void
generateJitCode
();
void
generateSingleArgCall
(
asmjit
::
X86Compiler
&
c
,
asmjit
::
X86XmmVar
&
dest
,
asmjit
::
X86XmmVar
&
arg
,
double
(
*
function
)(
double
));
int
findTempIndex
(
const
ExpressionTreeNode
&
node
,
std
::
vector
<
std
::
pair
<
ExpressionTreeNode
,
int
>
>&
temps
);
int
findTempIndex
(
const
ExpressionTreeNode
&
node
,
std
::
vector
<
std
::
pair
<
ExpressionTreeNode
,
int
>
>&
temps
);
std
::
vector
<
std
::
vector
<
int
>
>
arguments
;
std
::
vector
<
std
::
vector
<
int
>
>
arguments
;
std
::
vector
<
int
>
target
;
std
::
vector
<
int
>
target
;
...
...
libraries/lepton/src/CompiledExpression.cpp
View file @
4a25dc79
...
@@ -195,6 +195,10 @@ void CompiledExpression::generateJitCode() {
...
@@ -195,6 +195,10 @@ void CompiledExpression::generateJitCode() {
value
=
dynamic_cast
<
Operation
::
MultiplyConstant
&>
(
op
).
getValue
();
value
=
dynamic_cast
<
Operation
::
MultiplyConstant
&>
(
op
).
getValue
();
else
if
(
op
.
getId
()
==
Operation
::
RECIPROCAL
)
else
if
(
op
.
getId
()
==
Operation
::
RECIPROCAL
)
value
=
1.0
;
value
=
1.0
;
else
if
(
op
.
getId
()
==
Operation
::
STEP
)
value
=
1.0
;
else
if
(
op
.
getId
()
==
Operation
::
DELTA
)
value
=
1.0
;
else
else
continue
;
continue
;
...
@@ -232,55 +236,106 @@ void CompiledExpression::generateJitCode() {
...
@@ -232,55 +236,106 @@ void CompiledExpression::generateJitCode() {
for
(
int
i
=
1
;
i
<
op
.
getNumArguments
();
i
++
)
for
(
int
i
=
1
;
i
<
op
.
getNumArguments
();
i
++
)
args
.
push_back
(
args
[
0
]
+
i
);
args
.
push_back
(
args
[
0
]
+
i
);
}
}
// Generate instructions to execute this operation.
switch
(
op
.
getId
())
{
switch
(
op
.
getId
())
{
case
Operation
::
CONSTANT
:
case
Operation
::
CONSTANT
:
c
.
movsd
(
workspaceVar
[
target
[
step
]],
constantVar
[
operationConstantIndex
[
step
]]);
c
.
movsd
(
workspaceVar
[
target
[
step
]],
constantVar
[
operationConstantIndex
[
step
]]);
break
;
break
;
case
Operation
::
ADD
:
case
Operation
::
ADD
:
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
addsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
1
]]);
c
.
addsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
1
]]);
break
;
break
;
case
Operation
::
SUBTRACT
:
case
Operation
::
SUBTRACT
:
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
subsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
1
]]);
c
.
subsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
1
]]);
break
;
break
;
case
Operation
::
MULTIPLY
:
case
Operation
::
MULTIPLY
:
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
mulsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
1
]]);
c
.
mulsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
1
]]);
break
;
break
;
case
Operation
::
DIVIDE
:
case
Operation
::
DIVIDE
:
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
divsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
1
]]);
c
.
divsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
1
]]);
break
;
break
;
case
Operation
::
NEGATE
:
case
Operation
::
NEGATE
:
c
.
xorps
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
target
[
step
]]);
c
.
xorps
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
target
[
step
]]);
c
.
subsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
subsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
break
;
break
;
case
Operation
::
SQRT
:
case
Operation
::
SQRT
:
c
.
sqrtsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
sqrtsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
break
;
case
Operation
::
EXP
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
exp
);
break
;
case
Operation
::
LOG
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
log
);
break
;
case
Operation
::
SIN
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
sin
);
break
;
case
Operation
::
COS
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
cos
);
break
;
case
Operation
::
TAN
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
tan
);
break
;
case
Operation
::
ASIN
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
asin
);
break
;
case
Operation
::
ACOS
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
acos
);
break
;
case
Operation
::
ATAN
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
atan
);
break
;
case
Operation
::
SINH
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
sinh
);
break
;
case
Operation
::
COSH
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
cosh
);
break
;
case
Operation
::
TANH
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
tanh
);
break
;
case
Operation
::
STEP
:
c
.
xorps
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
target
[
step
]]);
c
.
cmpsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
imm
(
18
));
// Comparison mode is _CMP_LE_OQ = 18
c
.
andps
(
workspaceVar
[
target
[
step
]],
constantVar
[
operationConstantIndex
[
step
]]);
break
;
case
Operation
::
DELTA
:
c
.
xorps
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
target
[
step
]]);
c
.
cmpsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
imm
(
16
));
// Comparison mode is _CMP_EQ_OS = 16
c
.
andps
(
workspaceVar
[
target
[
step
]],
constantVar
[
operationConstantIndex
[
step
]]);
break
;
break
;
case
Operation
::
SQUARE
:
case
Operation
::
SQUARE
:
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
mulsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
mulsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
break
;
break
;
case
Operation
::
CUBE
:
case
Operation
::
CUBE
:
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
mulsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
mulsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
mulsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
mulsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
break
;
break
;
case
Operation
::
RECIPROCAL
:
case
Operation
::
RECIPROCAL
:
c
.
movsd
(
workspaceVar
[
target
[
step
]],
constantVar
[
operationConstantIndex
[
step
]]);
c
.
movsd
(
workspaceVar
[
target
[
step
]],
constantVar
[
operationConstantIndex
[
step
]]);
c
.
divsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
divsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
break
;
break
;
case
Operation
::
ADD_CONSTANT
:
case
Operation
::
ADD_CONSTANT
:
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
addsd
(
workspaceVar
[
target
[
step
]],
constantVar
[
operationConstantIndex
[
step
]]);
c
.
addsd
(
workspaceVar
[
target
[
step
]],
constantVar
[
operationConstantIndex
[
step
]]);
break
;
break
;
case
Operation
::
MULTIPLY_CONSTANT
:
case
Operation
::
MULTIPLY_CONSTANT
:
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
movsd
(
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]]);
c
.
mulsd
(
workspaceVar
[
target
[
step
]],
constantVar
[
operationConstantIndex
[
step
]]);
c
.
mulsd
(
workspaceVar
[
target
[
step
]],
constantVar
[
operationConstantIndex
[
step
]]);
break
;
case
Operation
::
ABS
:
generateSingleArgCall
(
c
,
workspaceVar
[
target
[
step
]],
workspaceVar
[
args
[
0
]],
fabs
);
break
;
break
;
default:
default:
// Just invoke evaluateOperation().
for
(
int
i
=
0
;
i
<
(
int
)
args
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
args
.
size
();
i
++
)
c
.
movsd
(
x86
::
ptr
(
argsPointer
,
8
*
i
,
0
),
workspaceVar
[
args
[
i
]]);
c
.
movsd
(
x86
::
ptr
(
argsPointer
,
8
*
i
,
0
),
workspaceVar
[
args
[
i
]]);
X86GpVar
fn
(
c
,
kVarTypeIntPtr
);
X86GpVar
fn
(
c
,
kVarTypeIntPtr
);
...
@@ -295,3 +350,11 @@ void CompiledExpression::generateJitCode() {
...
@@ -295,3 +350,11 @@ void CompiledExpression::generateJitCode() {
c
.
endFunc
();
c
.
endFunc
();
jitCode
=
c
.
make
();
jitCode
=
c
.
make
();
}
}
void
CompiledExpression
::
generateSingleArgCall
(
X86Compiler
&
c
,
X86XmmVar
&
dest
,
X86XmmVar
&
arg
,
double
(
*
function
)(
double
))
{
X86GpVar
fn
(
c
,
kVarTypeIntPtr
);
c
.
mov
(
fn
,
imm_ptr
((
void
*
)
function
));
X86CallNode
*
call
=
c
.
call
(
fn
,
kFuncConvHost
,
FuncBuilder1
<
double
,
double
>
());
call
->
setArg
(
0
,
arg
);
call
->
setRet
(
0
,
dest
);
}
tests/TestParser.cpp
View file @
4a25dc79
...
@@ -127,6 +127,18 @@ void verifyInvalidExpression(const string& expression) {
...
@@ -127,6 +127,18 @@ void verifyInvalidExpression(const string& expression) {
throw
exception
();
throw
exception
();
}
}
/**
* Verify that two numbers have the same value.
*/
void
assertNumbersEqual
(
double
val1
,
double
val2
)
{
const
double
inf
=
numeric_limits
<
double
>::
infinity
();
if
(
val1
==
val1
||
val2
==
val2
)
// If both are NaN, that's fine.
if
(
val1
!=
inf
||
val2
!=
inf
)
// Both infinity is also fine.
if
(
val1
!=
-
inf
||
val2
!=
-
inf
)
// Same for -infinity.
ASSERT_EQUAL_TOL
(
val1
,
val2
,
1e-10
);
}
/**
/**
* Verify that two expressions give the same value.
* Verify that two expressions give the same value.
*/
*/
...
@@ -137,11 +149,22 @@ void verifySameValue(const ParsedExpression& exp1, const ParsedExpression& exp2,
...
@@ -137,11 +149,22 @@ void verifySameValue(const ParsedExpression& exp1, const ParsedExpression& exp2,
variables
[
"y"
]
=
y
;
variables
[
"y"
]
=
y
;
double
val1
=
exp1
.
evaluate
(
variables
);
double
val1
=
exp1
.
evaluate
(
variables
);
double
val2
=
exp2
.
evaluate
(
variables
);
double
val2
=
exp2
.
evaluate
(
variables
);
const
double
inf
=
numeric_limits
<
double
>::
infinity
();
assertNumbersEqual
(
val1
,
val2
);
if
(
val1
==
val1
||
val2
==
val2
)
// If both are NaN, that's fine.
if
(
val1
!=
inf
||
val2
!=
inf
)
// Both infinity is also fine.
// Now create CompiledExpressions from them and see if those also match.
if
(
val1
!=
-
inf
||
val2
!=
-
inf
)
// Same for -infinity.
ASSERT_EQUAL_TOL
(
val1
,
val2
,
1e-10
);
CompiledExpression
compiled1
=
exp1
.
createCompiledExpression
();
CompiledExpression
compiled2
=
exp2
.
createCompiledExpression
();
if
(
compiled1
.
getVariables
().
find
(
"x"
)
!=
compiled1
.
getVariables
().
end
())
compiled1
.
getVariableReference
(
"x"
)
=
x
;
if
(
compiled1
.
getVariables
().
find
(
"y"
)
!=
compiled1
.
getVariables
().
end
())
compiled1
.
getVariableReference
(
"y"
)
=
y
;
if
(
compiled2
.
getVariables
().
find
(
"x"
)
!=
compiled2
.
getVariables
().
end
())
compiled2
.
getVariableReference
(
"x"
)
=
x
;
if
(
compiled2
.
getVariables
().
find
(
"y"
)
!=
compiled2
.
getVariables
().
end
())
compiled2
.
getVariableReference
(
"y"
)
=
y
;
assertNumbersEqual
(
val1
,
compiled1
.
evaluate
());
assertNumbersEqual
(
val2
,
compiled2
.
evaluate
());
}
}
/**
/**
...
@@ -171,14 +194,14 @@ void testCustomFunction(const string& expression, const string& equivalent) {
...
@@ -171,14 +194,14 @@ void testCustomFunction(const string& expression, const string& equivalent) {
verifySameValue
(
exp1
,
exp2
,
2.0
,
3.0
);
verifySameValue
(
exp1
,
exp2
,
2.0
,
3.0
);
verifySameValue
(
exp1
,
exp2
,
-
2.0
,
3.0
);
verifySameValue
(
exp1
,
exp2
,
-
2.0
,
3.0
);
verifySameValue
(
exp1
,
exp2
,
2.0
,
-
3.0
);
verifySameValue
(
exp1
,
exp2
,
2.0
,
-
3.0
);
ParsedExpression
deriv1
=
exp1
.
differentiate
(
"x"
);
ParsedExpression
deriv1
=
exp1
.
differentiate
(
"x"
)
.
optimize
()
;
ParsedExpression
deriv2
=
exp2
.
differentiate
(
"x"
);
ParsedExpression
deriv2
=
exp2
.
differentiate
(
"x"
)
.
optimize
()
;
verifySameValue
(
deriv1
,
deriv2
,
1.0
,
2.0
);
verifySameValue
(
deriv1
,
deriv2
,
1.0
,
2.0
);
verifySameValue
(
deriv1
,
deriv2
,
2.0
,
3.0
);
verifySameValue
(
deriv1
,
deriv2
,
2.0
,
3.0
);
verifySameValue
(
deriv1
,
deriv2
,
-
2.0
,
3.0
);
verifySameValue
(
deriv1
,
deriv2
,
-
2.0
,
3.0
);
verifySameValue
(
deriv1
,
deriv2
,
2.0
,
-
3.0
);
verifySameValue
(
deriv1
,
deriv2
,
2.0
,
-
3.0
);
ParsedExpression
deriv3
=
deriv1
.
differentiate
(
"y"
);
ParsedExpression
deriv3
=
deriv1
.
differentiate
(
"y"
)
.
optimize
()
;
ParsedExpression
deriv4
=
deriv2
.
differentiate
(
"y"
);
ParsedExpression
deriv4
=
deriv2
.
differentiate
(
"y"
)
.
optimize
()
;
verifySameValue
(
deriv3
,
deriv4
,
1.0
,
2.0
);
verifySameValue
(
deriv3
,
deriv4
,
1.0
,
2.0
);
verifySameValue
(
deriv3
,
deriv4
,
2.0
,
3.0
);
verifySameValue
(
deriv3
,
deriv4
,
2.0
,
3.0
);
verifySameValue
(
deriv3
,
deriv4
,
-
2.0
,
3.0
);
verifySameValue
(
deriv3
,
deriv4
,
-
2.0
,
3.0
);
...
@@ -223,6 +246,7 @@ int main() {
...
@@ -223,6 +246,7 @@ int main() {
verifyEvaluation
(
"max(x, -1)"
,
2.0
,
3.0
,
2.0
);
verifyEvaluation
(
"max(x, -1)"
,
2.0
,
3.0
,
2.0
);
verifyEvaluation
(
"abs(x-y)"
,
2.0
,
3.0
,
1.0
);
verifyEvaluation
(
"abs(x-y)"
,
2.0
,
3.0
,
1.0
);
verifyEvaluation
(
"delta(x)+3*delta(y-1.5)"
,
2.0
,
1.5
,
3.0
);
verifyEvaluation
(
"delta(x)+3*delta(y-1.5)"
,
2.0
,
1.5
,
3.0
);
verifyEvaluation
(
"step(x-3)+y*step(x)"
,
2.0
,
3.0
,
3.0
);
verifyInvalidExpression
(
"1..2"
);
verifyInvalidExpression
(
"1..2"
);
verifyInvalidExpression
(
"1*(2+3"
);
verifyInvalidExpression
(
"1*(2+3"
);
verifyInvalidExpression
(
"5++4"
);
verifyInvalidExpression
(
"5++4"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment