Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
d0ce27f1
Commit
d0ce27f1
authored
Nov 12, 2009
by
Peter Eastman
Browse files
Optimization to evaluating tabulated functions
parent
e1a7bc3d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
9 deletions
+65
-9
platforms/opencl/src/OpenCLExpressionUtilities.cpp
platforms/opencl/src/OpenCLExpressionUtilities.cpp
+52
-8
platforms/opencl/src/OpenCLExpressionUtilities.h
platforms/opencl/src/OpenCLExpressionUtilities.h
+13
-1
No files found.
platforms/opencl/src/OpenCLExpressionUtilities.cpp
View file @
d0ce27f1
...
@@ -48,22 +48,27 @@ static string intToString(int value) {
...
@@ -48,22 +48,27 @@ static string intToString(int value) {
string
OpenCLExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
string
OpenCLExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
)
{
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
)
{
stringstream
out
;
stringstream
out
;
vector
<
ParsedExpression
>
allExpressions
;
for
(
map
<
string
,
ParsedExpression
>::
const_iterator
iter
=
expressions
.
begin
();
iter
!=
expressions
.
end
();
++
iter
)
allExpressions
.
push_back
(
iter
->
second
);
vector
<
pair
<
ExpressionTreeNode
,
string
>
>
temps
;
vector
<
pair
<
ExpressionTreeNode
,
string
>
>
temps
;
for
(
map
<
string
,
ParsedExpression
>::
const_iterator
iter
=
expressions
.
begin
();
iter
!=
expressions
.
end
();
++
iter
)
{
for
(
map
<
string
,
ParsedExpression
>::
const_iterator
iter
=
expressions
.
begin
();
iter
!=
expressions
.
end
();
++
iter
)
{
processExpression
(
out
,
iter
->
second
.
getRootNode
(),
temps
,
variables
,
functions
,
prefix
,
functionParams
);
processExpression
(
out
,
iter
->
second
.
getRootNode
(),
temps
,
variables
,
functions
,
prefix
,
functionParams
,
allExpressions
);
out
<<
iter
->
first
<<
getTempName
(
iter
->
second
.
getRootNode
(),
temps
)
<<
";
\n
"
;
out
<<
iter
->
first
<<
getTempName
(
iter
->
second
.
getRootNode
(),
temps
)
<<
";
\n
"
;
}
}
return
out
.
str
();
return
out
.
str
();
}
}
void
OpenCLExpressionUtilities
::
processExpression
(
stringstream
&
out
,
const
ExpressionTreeNode
&
node
,
vector
<
pair
<
ExpressionTreeNode
,
string
>
>&
temps
,
void
OpenCLExpressionUtilities
::
processExpression
(
stringstream
&
out
,
const
ExpressionTreeNode
&
node
,
vector
<
pair
<
ExpressionTreeNode
,
string
>
>&
temps
,
const
map
<
string
,
string
>&
variables
,
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
)
{
const
map
<
string
,
string
>&
variables
,
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
,
const
vector
<
ParsedExpression
>&
allExpressions
)
{
for
(
int
i
=
0
;
i
<
(
int
)
temps
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
temps
.
size
();
i
++
)
if
(
temps
[
i
].
first
==
node
)
if
(
temps
[
i
].
first
==
node
)
return
;
return
;
for
(
int
i
=
0
;
i
<
(
int
)
node
.
getChildren
().
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
node
.
getChildren
().
size
();
i
++
)
processExpression
(
out
,
node
.
getChildren
()[
i
],
temps
,
variables
,
functions
,
prefix
,
functionParams
);
processExpression
(
out
,
node
.
getChildren
()[
i
],
temps
,
variables
,
functions
,
prefix
,
functionParams
,
allExpressions
);
string
name
=
prefix
+
intToString
(
temps
.
size
());
string
name
=
prefix
+
intToString
(
temps
.
size
());
bool
hasRecordedNode
=
false
;
out
<<
"float "
<<
name
<<
" = "
;
out
<<
"float "
<<
name
<<
" = "
;
switch
(
node
.
getOperation
().
getId
())
{
switch
(
node
.
getOperation
().
getId
())
{
...
@@ -85,7 +90,32 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -85,7 +90,32 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
;
;
if
(
i
==
functions
.
size
())
if
(
i
==
functions
.
size
())
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
bool
isDeriv
=
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
node
.
getOperation
())
->
getDerivOrder
()[
0
]
==
1
);
out
<<
"0.0f;
\n
"
;
out
<<
"0.0f;
\n
"
;
temps
.
push_back
(
make_pair
(
node
,
name
));
hasRecordedNode
=
true
;
// If both the value and derivative of the function are needed, it's faster to calculate them both
// at once, so check to see if both are needed.
const
ExpressionTreeNode
*
valueNode
=
NULL
;
const
ExpressionTreeNode
*
derivNode
=
NULL
;
for
(
int
j
=
0
;
j
<
(
int
)
allExpressions
.
size
();
j
++
)
findRelatedTabulatedFunctions
(
node
,
allExpressions
[
j
].
getRootNode
(),
valueNode
,
derivNode
);
string
valueName
=
name
;
string
derivName
=
name
;
if
(
valueNode
!=
NULL
&&
derivNode
!=
NULL
)
{
string
name2
=
prefix
+
intToString
(
temps
.
size
());
out
<<
"float "
<<
name2
<<
" = 0.0f;
\n
"
;
if
(
isDeriv
)
{
valueName
=
name2
;
temps
.
push_back
(
make_pair
(
*
valueNode
,
name2
));
}
else
{
derivName
=
name2
;
temps
.
push_back
(
make_pair
(
*
derivNode
,
name2
));
}
}
out
<<
"{
\n
"
;
out
<<
"{
\n
"
;
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"float x = "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
";
\n
"
;
out
<<
"float x = "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
";
\n
"
;
...
@@ -93,10 +123,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -93,10 +123,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out
<<
"int index = (int) (floor((x-params.x)*params.z));
\n
"
;
out
<<
"int index = (int) (floor((x-params.x)*params.z));
\n
"
;
out
<<
"float4 coeff = "
<<
functions
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"float4 coeff = "
<<
functions
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"x = (x-params.x)*params.z-index;
\n
"
;
out
<<
"x = (x-params.x)*params.z-index;
\n
"
;
if
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
node
.
getOperation
())
->
getDerivOrder
()[
0
]
==
0
)
if
(
valueNode
!=
NULL
)
out
<<
n
ame
<<
" = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));
\n
"
;
out
<<
valueN
ame
<<
" = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));
\n
"
;
else
if
(
derivNode
!=
NULL
)
out
<<
n
ame
<<
" = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;
\n
"
;
out
<<
derivN
ame
<<
" = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;
\n
"
;
out
<<
"}
\n
"
;
out
<<
"}
\n
"
;
out
<<
"}"
;
out
<<
"}"
;
break
;
break
;
...
@@ -218,6 +248,7 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -218,6 +248,7 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
throw
OpenMMException
(
"Internal error: Unknown operation in user-defined expression: "
+
node
.
getOperation
().
getName
());
throw
OpenMMException
(
"Internal error: Unknown operation in user-defined expression: "
+
node
.
getOperation
().
getName
());
}
}
out
<<
";
\n
"
;
out
<<
";
\n
"
;
if
(
!
hasRecordedNode
)
temps
.
push_back
(
make_pair
(
node
,
name
));
temps
.
push_back
(
make_pair
(
node
,
name
));
}
}
...
@@ -229,3 +260,16 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
...
@@ -229,3 +260,16 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
out
<<
"Internal error: No temporary variable for expression node: "
<<
node
;
out
<<
"Internal error: No temporary variable for expression node: "
<<
node
;
throw
OpenMMException
(
out
.
str
());
throw
OpenMMException
(
out
.
str
());
}
}
void
OpenCLExpressionUtilities
::
findRelatedTabulatedFunctions
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
const
ExpressionTreeNode
*&
valueNode
,
const
ExpressionTreeNode
*&
derivNode
)
{
if
(
searchNode
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
&&
node
.
getChildren
()[
0
]
==
searchNode
.
getChildren
()[
0
])
{
if
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
searchNode
.
getOperation
())
->
getDerivOrder
()[
0
]
==
0
)
valueNode
=
&
searchNode
;
else
derivNode
=
&
searchNode
;
}
else
for
(
int
i
=
0
;
i
<
(
int
)
searchNode
.
getChildren
().
size
();
i
++
)
findRelatedTabulatedFunctions
(
node
,
searchNode
.
getChildren
()[
i
],
valueNode
,
derivNode
);
}
platforms/opencl/src/OpenCLExpressionUtilities.h
View file @
d0ce27f1
...
@@ -43,13 +43,25 @@ namespace OpenMM {
...
@@ -43,13 +43,25 @@ namespace OpenMM {
class
OpenCLExpressionUtilities
{
class
OpenCLExpressionUtilities
{
public:
public:
/**
* Generate the source code for calculating a set of expressions.
*
* @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable that may appear in the expressions
* @param functions defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
*/
static
std
::
string
createExpressions
(
const
std
::
map
<
std
::
string
,
Lepton
::
ParsedExpression
>&
expressions
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
static
std
::
string
createExpressions
(
const
std
::
map
<
std
::
string
,
Lepton
::
ParsedExpression
>&
expressions
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
);
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
);
private:
private:
static
void
processExpression
(
std
::
stringstream
&
out
,
const
Lepton
::
ExpressionTreeNode
&
node
,
static
void
processExpression
(
std
::
stringstream
&
out
,
const
Lepton
::
ExpressionTreeNode
&
node
,
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
);
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
,
const
std
::
vector
<
Lepton
::
ParsedExpression
>&
allExpressions
);
static
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
static
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
static
void
findRelatedTabulatedFunctions
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
const
Lepton
::
ExpressionTreeNode
*&
valueNode
,
const
Lepton
::
ExpressionTreeNode
*&
derivNode
);
};
};
}
// namespace OpenMM
}
// namespace OpenMM
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment