Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
d0ce27f1
"openmmapi/vscode:/vscode.git/clone" did not exist on "729c09907584fde040dc58c860181d25b133b68d"
Commit
d0ce27f1
authored
Nov 12, 2009
by
Peter Eastman
Browse files
Optimization to evaluating tabulated functions
parent
e1a7bc3d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
9 deletions
+65
-9
platforms/opencl/src/OpenCLExpressionUtilities.cpp
platforms/opencl/src/OpenCLExpressionUtilities.cpp
+52
-8
platforms/opencl/src/OpenCLExpressionUtilities.h
platforms/opencl/src/OpenCLExpressionUtilities.h
+13
-1
No files found.
platforms/opencl/src/OpenCLExpressionUtilities.cpp
View file @
d0ce27f1
...
@@ -48,22 +48,27 @@ static string intToString(int value) {
...
@@ -48,22 +48,27 @@ static string intToString(int value) {
string
OpenCLExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
string
OpenCLExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
)
{
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
)
{
stringstream
out
;
stringstream
out
;
vector
<
ParsedExpression
>
allExpressions
;
for
(
map
<
string
,
ParsedExpression
>::
const_iterator
iter
=
expressions
.
begin
();
iter
!=
expressions
.
end
();
++
iter
)
allExpressions
.
push_back
(
iter
->
second
);
vector
<
pair
<
ExpressionTreeNode
,
string
>
>
temps
;
vector
<
pair
<
ExpressionTreeNode
,
string
>
>
temps
;
for
(
map
<
string
,
ParsedExpression
>::
const_iterator
iter
=
expressions
.
begin
();
iter
!=
expressions
.
end
();
++
iter
)
{
for
(
map
<
string
,
ParsedExpression
>::
const_iterator
iter
=
expressions
.
begin
();
iter
!=
expressions
.
end
();
++
iter
)
{
processExpression
(
out
,
iter
->
second
.
getRootNode
(),
temps
,
variables
,
functions
,
prefix
,
functionParams
);
processExpression
(
out
,
iter
->
second
.
getRootNode
(),
temps
,
variables
,
functions
,
prefix
,
functionParams
,
allExpressions
);
out
<<
iter
->
first
<<
getTempName
(
iter
->
second
.
getRootNode
(),
temps
)
<<
";
\n
"
;
out
<<
iter
->
first
<<
getTempName
(
iter
->
second
.
getRootNode
(),
temps
)
<<
";
\n
"
;
}
}
return
out
.
str
();
return
out
.
str
();
}
}
void
OpenCLExpressionUtilities
::
processExpression
(
stringstream
&
out
,
const
ExpressionTreeNode
&
node
,
vector
<
pair
<
ExpressionTreeNode
,
string
>
>&
temps
,
void
OpenCLExpressionUtilities
::
processExpression
(
stringstream
&
out
,
const
ExpressionTreeNode
&
node
,
vector
<
pair
<
ExpressionTreeNode
,
string
>
>&
temps
,
const
map
<
string
,
string
>&
variables
,
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
)
{
const
map
<
string
,
string
>&
variables
,
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
,
const
vector
<
ParsedExpression
>&
allExpressions
)
{
for
(
int
i
=
0
;
i
<
(
int
)
temps
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
temps
.
size
();
i
++
)
if
(
temps
[
i
].
first
==
node
)
if
(
temps
[
i
].
first
==
node
)
return
;
return
;
for
(
int
i
=
0
;
i
<
(
int
)
node
.
getChildren
().
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
node
.
getChildren
().
size
();
i
++
)
processExpression
(
out
,
node
.
getChildren
()[
i
],
temps
,
variables
,
functions
,
prefix
,
functionParams
);
processExpression
(
out
,
node
.
getChildren
()[
i
],
temps
,
variables
,
functions
,
prefix
,
functionParams
,
allExpressions
);
string
name
=
prefix
+
intToString
(
temps
.
size
());
string
name
=
prefix
+
intToString
(
temps
.
size
());
bool
hasRecordedNode
=
false
;
out
<<
"float "
<<
name
<<
" = "
;
out
<<
"float "
<<
name
<<
" = "
;
switch
(
node
.
getOperation
().
getId
())
{
switch
(
node
.
getOperation
().
getId
())
{
...
@@ -85,7 +90,32 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -85,7 +90,32 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
;
;
if
(
i
==
functions
.
size
())
if
(
i
==
functions
.
size
())
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
bool
isDeriv
=
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
node
.
getOperation
())
->
getDerivOrder
()[
0
]
==
1
);
out
<<
"0.0f;
\n
"
;
out
<<
"0.0f;
\n
"
;
temps
.
push_back
(
make_pair
(
node
,
name
));
hasRecordedNode
=
true
;
// If both the value and derivative of the function are needed, it's faster to calculate them both
// at once, so check to see if both are needed.
const
ExpressionTreeNode
*
valueNode
=
NULL
;
const
ExpressionTreeNode
*
derivNode
=
NULL
;
for
(
int
j
=
0
;
j
<
(
int
)
allExpressions
.
size
();
j
++
)
findRelatedTabulatedFunctions
(
node
,
allExpressions
[
j
].
getRootNode
(),
valueNode
,
derivNode
);
string
valueName
=
name
;
string
derivName
=
name
;
if
(
valueNode
!=
NULL
&&
derivNode
!=
NULL
)
{
string
name2
=
prefix
+
intToString
(
temps
.
size
());
out
<<
"float "
<<
name2
<<
" = 0.0f;
\n
"
;
if
(
isDeriv
)
{
valueName
=
name2
;
temps
.
push_back
(
make_pair
(
*
valueNode
,
name2
));
}
else
{
derivName
=
name2
;
temps
.
push_back
(
make_pair
(
*
derivNode
,
name2
));
}
}
out
<<
"{
\n
"
;
out
<<
"{
\n
"
;
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"float x = "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
";
\n
"
;
out
<<
"float x = "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
";
\n
"
;
...
@@ -93,10 +123,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -93,10 +123,10 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
out
<<
"int index = (int) (floor((x-params.x)*params.z));
\n
"
;
out
<<
"int index = (int) (floor((x-params.x)*params.z));
\n
"
;
out
<<
"float4 coeff = "
<<
functions
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"float4 coeff = "
<<
functions
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"x = (x-params.x)*params.z-index;
\n
"
;
out
<<
"x = (x-params.x)*params.z-index;
\n
"
;
if
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
node
.
getOperation
())
->
getDerivOrder
()[
0
]
==
0
)
if
(
valueNode
!=
NULL
)
out
<<
n
ame
<<
" = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));
\n
"
;
out
<<
valueN
ame
<<
" = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));
\n
"
;
else
if
(
derivNode
!=
NULL
)
out
<<
n
ame
<<
" = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;
\n
"
;
out
<<
derivN
ame
<<
" = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;
\n
"
;
out
<<
"}
\n
"
;
out
<<
"}
\n
"
;
out
<<
"}"
;
out
<<
"}"
;
break
;
break
;
...
@@ -218,7 +248,8 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
...
@@ -218,7 +248,8 @@ void OpenCLExpressionUtilities::processExpression(stringstream& out, const Expre
throw
OpenMMException
(
"Internal error: Unknown operation in user-defined expression: "
+
node
.
getOperation
().
getName
());
throw
OpenMMException
(
"Internal error: Unknown operation in user-defined expression: "
+
node
.
getOperation
().
getName
());
}
}
out
<<
";
\n
"
;
out
<<
";
\n
"
;
temps
.
push_back
(
make_pair
(
node
,
name
));
if
(
!
hasRecordedNode
)
temps
.
push_back
(
make_pair
(
node
,
name
));
}
}
string
OpenCLExpressionUtilities
::
getTempName
(
const
ExpressionTreeNode
&
node
,
const
vector
<
pair
<
ExpressionTreeNode
,
string
>
>&
temps
)
{
string
OpenCLExpressionUtilities
::
getTempName
(
const
ExpressionTreeNode
&
node
,
const
vector
<
pair
<
ExpressionTreeNode
,
string
>
>&
temps
)
{
...
@@ -229,3 +260,16 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
...
@@ -229,3 +260,16 @@ string OpenCLExpressionUtilities::getTempName(const ExpressionTreeNode& node, co
out
<<
"Internal error: No temporary variable for expression node: "
<<
node
;
out
<<
"Internal error: No temporary variable for expression node: "
<<
node
;
throw
OpenMMException
(
out
.
str
());
throw
OpenMMException
(
out
.
str
());
}
}
void
OpenCLExpressionUtilities
::
findRelatedTabulatedFunctions
(
const
ExpressionTreeNode
&
node
,
const
ExpressionTreeNode
&
searchNode
,
const
ExpressionTreeNode
*&
valueNode
,
const
ExpressionTreeNode
*&
derivNode
)
{
if
(
searchNode
.
getOperation
().
getId
()
==
Operation
::
CUSTOM
&&
node
.
getChildren
()[
0
]
==
searchNode
.
getChildren
()[
0
])
{
if
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
searchNode
.
getOperation
())
->
getDerivOrder
()[
0
]
==
0
)
valueNode
=
&
searchNode
;
else
derivNode
=
&
searchNode
;
}
else
for
(
int
i
=
0
;
i
<
(
int
)
searchNode
.
getChildren
().
size
();
i
++
)
findRelatedTabulatedFunctions
(
node
,
searchNode
.
getChildren
()[
i
],
valueNode
,
derivNode
);
}
platforms/opencl/src/OpenCLExpressionUtilities.h
View file @
d0ce27f1
...
@@ -43,13 +43,25 @@ namespace OpenMM {
...
@@ -43,13 +43,25 @@ namespace OpenMM {
class
OpenCLExpressionUtilities
{
class
OpenCLExpressionUtilities
{
public:
public:
/**
* Generate the source code for calculating a set of expressions.
*
* @param expressions the expressions to generate code for (keys are the variables to store the output values in)
* @param variables defines the source code to generate for each variable that may appear in the expressions
* @param functions defines the variable name for each tabulated function that may appear in the expressions
* @param prefix a prefix to put in front of temporary variables
* @param functionParams the variable name containing the parameters for each tabulated function
*/
static
std
::
string
createExpressions
(
const
std
::
map
<
std
::
string
,
Lepton
::
ParsedExpression
>&
expressions
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
static
std
::
string
createExpressions
(
const
std
::
map
<
std
::
string
,
Lepton
::
ParsedExpression
>&
expressions
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
);
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
);
private:
private:
static
void
processExpression
(
std
::
stringstream
&
out
,
const
Lepton
::
ExpressionTreeNode
&
node
,
static
void
processExpression
(
std
::
stringstream
&
out
,
const
Lepton
::
ExpressionTreeNode
&
node
,
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
);
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
,
const
std
::
vector
<
Lepton
::
ParsedExpression
>&
allExpressions
);
static
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
static
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
static
void
findRelatedTabulatedFunctions
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
Lepton
::
ExpressionTreeNode
&
searchNode
,
const
Lepton
::
ExpressionTreeNode
*&
valueNode
,
const
Lepton
::
ExpressionTreeNode
*&
derivNode
);
};
};
}
// namespace OpenMM
}
// namespace OpenMM
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment