Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
72d59cbe
Commit
72d59cbe
authored
Nov 02, 2009
by
Peter Eastman
Browse files
Finished OpenCL implementation of CustomNonbondedForce. Also implemented a few optimizations.
parent
2127b8dd
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
366 additions
and
123 deletions
+366
-123
libraries/lepton/include/lepton/ExpressionTreeNode.h
libraries/lepton/include/lepton/ExpressionTreeNode.h
+2
-0
libraries/lepton/include/lepton/Operation.h
libraries/lepton/include/lepton/Operation.h
+30
-0
libraries/lepton/include/lepton/ParsedExpression.h
libraries/lepton/include/lepton/ParsedExpression.h
+5
-0
libraries/lepton/src/ExpressionTreeNode.cpp
libraries/lepton/src/ExpressionTreeNode.cpp
+13
-0
libraries/lepton/src/ParsedExpression.cpp
libraries/lepton/src/ParsedExpression.cpp
+5
-0
platforms/opencl/src/OpenCLExpressionUtilities.cpp
platforms/opencl/src/OpenCLExpressionUtilities.cpp
+113
-31
platforms/opencl/src/OpenCLExpressionUtilities.h
platforms/opencl/src/OpenCLExpressionUtilities.h
+8
-2
platforms/opencl/src/OpenCLKernels.cpp
platforms/opencl/src/OpenCLKernels.cpp
+163
-69
platforms/opencl/src/OpenCLKernels.h
platforms/opencl/src/OpenCLKernels.h
+19
-9
platforms/opencl/src/kernels/customNonbondedExceptions.cl
platforms/opencl/src/kernels/customNonbondedExceptions.cl
+6
-10
platforms/opencl/tests/TestOpenCLCustomNonbondedForce.cpp
platforms/opencl/tests/TestOpenCLCustomNonbondedForce.cpp
+2
-2
No files found.
libraries/lepton/include/lepton/ExpressionTreeNode.h
View file @
72d59cbe
...
...
@@ -84,6 +84,8 @@ public:
ExpressionTreeNode
(
const
ExpressionTreeNode
&
node
);
ExpressionTreeNode
();
~
ExpressionTreeNode
();
bool
operator
==
(
const
ExpressionTreeNode
&
node
)
const
;
bool
operator
!=
(
const
ExpressionTreeNode
&
node
)
const
;
ExpressionTreeNode
&
operator
=
(
const
ExpressionTreeNode
&
node
);
/**
* Get the Operation performed by this node.
...
...
libraries/lepton/include/lepton/Operation.h
View file @
72d59cbe
...
...
@@ -95,6 +95,12 @@ public:
* @param variable the variable with respect to which the derivate should be taken
*/
virtual
ExpressionTreeNode
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
=
0
;
virtual
bool
operator
!=
(
const
Operation
&
op
)
const
{
return
op
.
getId
()
!=
getId
();
}
virtual
bool
operator
==
(
const
Operation
&
op
)
const
{
return
!
(
*
this
!=
op
);
}
class
Constant
;
class
Variable
;
class
Custom
;
...
...
@@ -149,6 +155,10 @@ public:
double
getValue
()
const
{
return
value
;
}
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
Constant
*
o
=
dynamic_cast
<
const
Constant
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
value
!=
value
);
}
private:
double
value
;
};
...
...
@@ -176,6 +186,10 @@ public:
return
iter
->
second
;
}
ExpressionTreeNode
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
;
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
Variable
*
o
=
dynamic_cast
<
const
Variable
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
name
!=
name
);
}
private:
std
::
string
name
;
};
...
...
@@ -214,6 +228,10 @@ public:
const
std
::
vector
<
int
>&
getDerivOrder
()
const
{
return
derivOrder
;
}
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
Custom
*
o
=
dynamic_cast
<
const
Custom
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
name
!=
name
||
o
->
isDerivative
!=
isDerivative
||
o
->
derivOrder
!=
derivOrder
);
}
private:
std
::
string
name
;
CustomFunction
*
function
;
...
...
@@ -708,6 +726,10 @@ public:
double
getValue
()
const
{
return
value
;
}
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
AddConstant
*
o
=
dynamic_cast
<
const
AddConstant
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
value
!=
value
);
}
private:
double
value
;
};
...
...
@@ -737,6 +759,10 @@ public:
double
getValue
()
const
{
return
value
;
}
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
MultiplyConstant
*
o
=
dynamic_cast
<
const
MultiplyConstant
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
value
!=
value
);
}
private:
double
value
;
};
...
...
@@ -766,6 +792,10 @@ public:
double
getValue
()
const
{
return
value
;
}
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
PowerConstant
*
o
=
dynamic_cast
<
const
PowerConstant
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
value
!=
value
);
}
private:
double
value
;
};
...
...
libraries/lepton/include/lepton/ParsedExpression.h
View file @
72d59cbe
...
...
@@ -48,6 +48,11 @@ class ExpressionProgram;
class
LEPTON_EXPORT
ParsedExpression
{
public:
/**
* Create an uninitialized ParsedExpression. This exists so that ParsedExpressions can be put in STL containers.
* Doing anything with it will produce an exception.
*/
ParsedExpression
();
/**
* Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class
* to parse expression.
...
...
libraries/lepton/src/ExpressionTreeNode.cpp
View file @
72d59cbe
...
...
@@ -70,6 +70,19 @@ ExpressionTreeNode::~ExpressionTreeNode() {
delete
operation
;
}
bool
ExpressionTreeNode
::
operator
!=
(
const
ExpressionTreeNode
&
node
)
const
{
if
(
node
.
getOperation
()
!=
getOperation
())
return
true
;
for
(
int
i
=
0
;
i
<
(
int
)
getChildren
().
size
();
i
++
)
if
(
getChildren
()[
i
]
!=
node
.
getChildren
()[
i
])
return
true
;
return
false
;
}
bool
ExpressionTreeNode
::
operator
==
(
const
ExpressionTreeNode
&
node
)
const
{
return
!
(
*
this
!=
node
);
}
ExpressionTreeNode
&
ExpressionTreeNode
::
operator
=
(
const
ExpressionTreeNode
&
node
)
{
if
(
operation
!=
NULL
)
delete
operation
;
...
...
libraries/lepton/src/ParsedExpression.cpp
View file @
72d59cbe
...
...
@@ -38,10 +38,15 @@
using
namespace
Lepton
;
using
namespace
std
;
ParsedExpression
::
ParsedExpression
()
:
rootNode
(
ExpressionTreeNode
())
{
}
ParsedExpression
::
ParsedExpression
(
const
ExpressionTreeNode
&
rootNode
)
:
rootNode
(
rootNode
)
{
}
const
ExpressionTreeNode
&
ParsedExpression
::
getRootNode
()
const
{
if
(
&
rootNode
.
getOperation
()
==
NULL
)
throw
Exception
(
"Illegal call to an initialized ParsedExpression"
);
return
rootNode
;
}
...
...
platforms/opencl/src/OpenCLExpressionUtilities.cpp
View file @
72d59cbe
...
...
@@ -27,7 +27,6 @@
#include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h"
#include "lepton/Operation.h"
#include <sstream>
using
namespace
OpenMM
;
using
namespace
Lepton
;
...
...
@@ -46,69 +45,152 @@ static string intToString(int value) {
return
s
.
str
();
}
string
OpenCLExpressionUtilities
::
createExpression
(
const
ParsedExpression
&
expression
,
const
map
<
string
,
string
>&
variables
)
{
return
processExpression
(
expression
.
getRootNode
(),
variables
);
string
OpenCLExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
)
{
stringstream
out
;
vector
<
pair
<
ExpressionTreeNode
,
string
>
>
temps
;
for
(
map
<
string
,
ParsedExpression
>::
const_iterator
iter
=
expressions
.
begin
();
iter
!=
expressions
.
end
();
++
iter
)
{
processExpression
(
out
,
iter
->
second
.
getRootNode
(),
temps
,
variables
,
functions
,
prefix
,
functionParams
);
out
<<
iter
->
first
<<
getTempName
(
iter
->
second
.
getRootNode
(),
temps
)
<<
";
\n
"
;
}
return
out
.
str
();
}
string
OpenCLExpressionUtilities
::
processExpression
(
const
ExpressionTreeNode
&
node
,
const
map
<
string
,
string
>&
variables
)
{
void
OpenCLExpressionUtilities
::
processExpression
(
stringstream
&
out
,
const
ExpressionTreeNode
&
node
,
vector
<
pair
<
ExpressionTreeNode
,
string
>
>&
temps
,
const
map
<
string
,
string
>&
variables
,
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
)
{
for
(
int
i
=
0
;
i
<
(
int
)
temps
.
size
();
i
++
)
if
(
temps
[
i
].
first
==
node
)
return
;
for
(
int
i
=
0
;
i
<
(
int
)
node
.
getChildren
().
size
();
i
++
)
processExpression
(
out
,
node
.
getChildren
()[
i
],
temps
,
variables
,
functions
,
prefix
,
functionParams
);
string
name
=
prefix
+
intToString
(
temps
.
size
());
out
<<
"float "
<<
name
<<
" = "
;
switch
(
node
.
getOperation
().
getId
())
{
case
Operation
::
CONSTANT
:
return
doubleToString
(
dynamic_cast
<
const
Operation
::
Constant
*>
(
&
node
.
getOperation
())
->
getValue
());
out
<<
doubleToString
(
dynamic_cast
<
const
Operation
::
Constant
*>
(
&
node
.
getOperation
())
->
getValue
());
break
;
case
Operation
::
VARIABLE
:
{
map
<
string
,
string
>::
const_iterator
iter
=
variables
.
find
(
node
.
getOperation
().
getName
());
if
(
iter
==
variables
.
end
())
throw
OpenMMException
(
"Unknown variable in expression: "
+
node
.
getOperation
().
getName
());
return
iter
->
second
;
out
<<
iter
->
second
;
break
;
}
case
Operation
::
CUSTOM
:
{
int
i
;
for
(
i
=
0
;
i
<
(
int
)
functions
.
size
()
&&
functions
[
i
].
first
!=
node
.
getOperation
().
getName
();
i
++
)
;
if
(
i
==
functions
.
size
())
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
out
<<
"0.0f;
\n
"
;
out
<<
"{
\n
"
;
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"float x = "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
";
\n
"
;
out
<<
"if (x >= params.x && x <= params.y) {
\n
"
;
out
<<
"int index = (int) (floor((x-params.x)*params.z));
\n
"
;
out
<<
"float4 coeff = "
<<
functions
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"x = (x-params.x)*params.z-index;
\n
"
;
if
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
node
.
getOperation
())
->
getDerivOrder
()[
0
]
==
0
)
out
<<
name
<<
" = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));
\n
"
;
else
out
<<
name
<<
" = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;
\n
"
;
out
<<
"}
\n
"
;
out
<<
"}"
;
break
;
}
case
Operation
::
ADD
:
return
"("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")+("
+
processExpression
(
node
.
getChildren
()[
1
],
variables
)
+
")"
;
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"+"
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
);
break
;
case
Operation
::
SUBTRACT
:
return
"("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")-("
+
processExpression
(
node
.
getChildren
()[
1
],
variables
)
+
")"
;
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"-"
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
);
break
;
case
Operation
::
MULTIPLY
:
return
"("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")*("
+
processExpression
(
node
.
getChildren
()[
1
],
variables
)
+
")"
;
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"*"
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
);
break
;
case
Operation
::
DIVIDE
:
return
"("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")/("
+
processExpression
(
node
.
getChildren
()[
1
],
variables
)
+
")"
;
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"/"
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
);
break
;
case
Operation
::
POWER
:
return
"pow(("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
"), ("
+
processExpression
(
node
.
getChildren
()[
1
],
variables
)
+
"))"
;
out
<<
"pow("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
", "
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
)
<<
")"
;
break
;
case
Operation
::
NEGATE
:
return
"-("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"-"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
break
;
case
Operation
::
SQRT
:
return
"sqrt("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"sqrt("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
EXP
:
return
"exp("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"exp("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
LOG
:
return
"log("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"log("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
SIN
:
return
"sin("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"sin("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
COS
:
return
"cos("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"cos("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
SEC
:
return
"1.0f/cos("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"1.0f/cos("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
CSC
:
return
"1.0f/sin("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"1.0f/sin("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
TAN
:
return
"tan("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"tan("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
COT
:
return
"1.0f/tan("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"1.0f/tan("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
ASIN
:
return
"asin("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"asin("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
ACOS
:
return
"acos("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"acos("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
ATAN
:
return
"atan("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"atan("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
SQUARE
:
return
"pow(("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
"), 2.0f)"
;
{
string
arg
=
getTempName
(
node
.
getChildren
()[
0
],
temps
);
out
<<
arg
<<
"*"
<<
arg
;
break
;
}
case
Operation
::
CUBE
:
return
"pow(("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
"), 3.0f)"
;
{
string
arg
=
getTempName
(
node
.
getChildren
()[
0
],
temps
);
out
<<
arg
<<
"*"
<<
arg
<<
"*"
<<
arg
;
break
;
}
case
Operation
::
RECIPROCAL
:
return
"1.0f/("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"1.0f/"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
break
;
case
Operation
::
ADD_CONSTANT
:
return
doubleToString
(
dynamic_cast
<
const
Operation
::
AddConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
+
"+("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
doubleToString
(
dynamic_cast
<
const
Operation
::
AddConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
<<
"+"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
break
;
case
Operation
::
MULTIPLY_CONSTANT
:
return
doubleToString
(
dynamic_cast
<
const
Operation
::
MultiplyConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
+
"*("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
doubleToString
(
dynamic_cast
<
const
Operation
::
MultiplyConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
<<
"*"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
break
;
case
Operation
::
POWER_CONSTANT
:
return
"pow(("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
"), "
+
doubleToString
(
dynamic_cast
<
const
Operation
::
PowerConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
+
")"
;
}
out
<<
"pow("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
", "
<<
doubleToString
(
dynamic_cast
<
const
Operation
::
PowerConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
<<
")"
;
break
;
default:
throw
OpenMMException
(
"Internal error: Unknown operation in user-defined expression: "
+
node
.
getOperation
().
getName
());
}
out
<<
";
\n
"
;
temps
.
push_back
(
make_pair
(
node
,
name
));
}
string
OpenCLExpressionUtilities
::
getTempName
(
const
ExpressionTreeNode
&
node
,
const
vector
<
pair
<
ExpressionTreeNode
,
string
>
>&
temps
)
{
for
(
int
i
=
0
;
i
<
(
int
)
temps
.
size
();
i
++
)
if
(
temps
[
i
].
first
==
node
)
return
temps
[
i
].
second
;
stringstream
out
;
out
<<
"Internal error: No temporary variable for expression node: "
<<
node
;
throw
OpenMMException
(
out
.
str
());
}
platforms/opencl/src/OpenCLExpressionUtilities.h
View file @
72d59cbe
...
...
@@ -30,7 +30,9 @@
#include "lepton/ExpressionTreeNode.h"
#include "lepton/ParsedExpression.h"
#include <map>
#include <sstream>
#include <string>
#include <utility>
namespace
OpenMM
{
...
...
@@ -41,9 +43,13 @@ namespace OpenMM {
class
OpenCLExpressionUtilities
{
public:
static
std
::
string
createExpression
(
const
Lepton
::
ParsedExpression
&
expression
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
);
static
std
::
string
createExpressions
(
const
std
::
map
<
std
::
string
,
Lepton
::
ParsedExpression
>&
expressions
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
);
private:
static
std
::
string
processExpression
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
);
static
void
processExpression
(
std
::
stringstream
&
out
,
const
Lepton
::
ExpressionTreeNode
&
node
,
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
);
static
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
};
}
// namespace OpenMM
...
...
platforms/opencl/src/OpenCLKernels.cpp
View file @
72d59cbe
...
...
@@ -33,6 +33,7 @@
#include "OpenCLExpressionUtilities.h"
#include "OpenCLIntegrationUtilities.h"
#include "OpenCLNonbondedUtilities.h"
#include "lepton/CustomFunction.h"
#include "lepton/Parser.h"
#include "lepton/ParsedExpression.h"
#include <cmath>
...
...
@@ -231,6 +232,8 @@ void OpenCLCalcHarmonicBondForceKernel::initialize(const System& system, const H
}
void
OpenCLCalcHarmonicBondForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
!
hasInitializedKernel
)
{
hasInitializedKernel
=
true
;
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
1
,
numBonds
);
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
...
...
@@ -238,6 +241,7 @@ void OpenCLCalcHarmonicBondForceKernel::executeForces(ContextImpl& context) {
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
}
cl
.
executeKernel
(
kernel
,
numBonds
);
}
...
...
@@ -307,6 +311,8 @@ void OpenCLCalcHarmonicAngleForceKernel::initialize(const System& system, const
}
void
OpenCLCalcHarmonicAngleForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
!
hasInitializedKernel
)
{
hasInitializedKernel
=
true
;
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
1
,
numAngles
);
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
...
...
@@ -314,6 +320,7 @@ void OpenCLCalcHarmonicAngleForceKernel::executeForces(ContextImpl& context) {
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
}
cl
.
executeKernel
(
kernel
,
numAngles
);
}
...
...
@@ -384,6 +391,8 @@ void OpenCLCalcPeriodicTorsionForceKernel::initialize(const System& system, cons
}
void
OpenCLCalcPeriodicTorsionForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
!
hasInitializedKernel
)
{
hasInitializedKernel
=
true
;
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
1
,
numTorsions
);
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
...
...
@@ -391,6 +400,7 @@ void OpenCLCalcPeriodicTorsionForceKernel::executeForces(ContextImpl& context) {
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
}
cl
.
executeKernel
(
kernel
,
numTorsions
);
}
...
...
@@ -461,6 +471,8 @@ void OpenCLCalcRBTorsionForceKernel::initialize(const System& system, const RBTo
}
void
OpenCLCalcRBTorsionForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
!
hasInitializedKernel
)
{
hasInitializedKernel
=
true
;
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
1
,
numTorsions
);
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
...
...
@@ -468,6 +480,7 @@ void OpenCLCalcRBTorsionForceKernel::executeForces(ContextImpl& context) {
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
}
cl
.
executeKernel
(
kernel
,
numTorsions
);
}
...
...
@@ -639,6 +652,8 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
}
void
OpenCLCalcNonbondedForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
!
hasInitializedKernel
)
{
hasInitializedKernel
=
true
;
if
(
exceptionIndices
!=
NULL
)
{
int
numExceptions
=
exceptionIndices
->
getSize
();
exceptionsKernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
...
...
@@ -650,16 +665,20 @@ void OpenCLCalcNonbondedForceKernel::executeForces(ContextImpl& context) {
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
cl
.
getPosq
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
exceptionParams
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
exceptionIndices
->
getDeviceBuffer
());
cl
.
executeKernel
(
exceptionsKernel
,
numExceptions
);
}
if
(
cosSinSums
!=
NULL
)
{
ewaldSumsKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
cl
.
getEnergyBuffer
().
getDeviceBuffer
());
ewaldSumsKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
cl
.
getPosq
().
getDeviceBuffer
());
ewaldSumsKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cosSinSums
->
getDeviceBuffer
());
cl
.
executeKernel
(
ewaldSumsKernel
,
cosSinSums
->
getSize
());
ewaldForcesKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
ewaldForcesKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
cl
.
getPosq
().
getDeviceBuffer
());
ewaldForcesKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cosSinSums
->
getDeviceBuffer
());
}
}
if
(
exceptionIndices
!=
NULL
)
cl
.
executeKernel
(
exceptionsKernel
,
exceptionIndices
->
getSize
());
if
(
cosSinSums
!=
NULL
)
{
cl
.
executeKernel
(
ewaldSumsKernel
,
cosSinSums
->
getSize
());
cl
.
executeKernel
(
ewaldForcesKernel
,
cl
.
getNumAtoms
());
}
}
...
...
@@ -718,6 +737,10 @@ OpenCLCalcCustomNonbondedForceKernel::~OpenCLCalcCustomNonbondedForceKernel() {
delete
exceptionParams
;
if
(
exceptionIndices
!=
NULL
)
delete
exceptionIndices
;
if
(
tabulatedFunctionParams
!=
NULL
)
delete
tabulatedFunctionParams
;
for
(
int
i
=
0
;
i
<
(
int
)
tabulatedFunctions
.
size
();
i
++
)
delete
tabulatedFunctions
[
i
];
}
void
OpenCLCalcCustomNonbondedForceKernel
::
initialize
(
const
System
&
system
,
const
CustomNonbondedForce
&
force
)
{
...
...
@@ -746,9 +769,12 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
// Record parameters and exclusions.
int
numParticles
=
force
.
getNumParticles
();
string
extraArguments
;
params
=
new
OpenCLArray
<
mm_float4
>
(
cl
,
numParticles
,
"customNonbondedParameters"
);
if
(
force
.
getNumGlobalParameters
()
>
0
)
if
(
force
.
getNumGlobalParameters
()
>
0
)
{
globals
=
new
OpenCLArray
<
cl_float
>
(
cl
,
force
.
getNumGlobalParameters
(),
"customNonbondedGlobals"
,
false
,
CL_MEM_READ_ONLY
);
extraArguments
+=
", __constant float* globals"
;
}
vector
<
mm_float4
>
paramVec
(
numParticles
);
vector
<
vector
<
int
>
>
exclusionList
(
numParticles
);
for
(
int
i
=
0
;
i
<
numParticles
;
i
++
)
{
...
...
@@ -764,21 +790,80 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
paramVec
[
i
].
w
=
(
cl_float
)
parameters
[
3
];
exclusionList
[
i
].
push_back
(
i
);
}
for
(
int
i
=
0
;
i
<
(
int
)
exclusions
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
(
int
)
exclusions
.
size
();
i
++
)
{
exclusionList
[
exclusions
[
i
].
first
].
push_back
(
exclusions
[
i
].
second
);
exclusionList
[
exclusions
[
i
].
second
].
push_back
(
exclusions
[
i
].
first
);
}
params
->
upload
(
paramVec
);
// This class serves as a placeholder for custom functions in expressions.
class
FunctionPlaceholder
:
public
Lepton
::
CustomFunction
{
public:
int
getNumArguments
()
const
{
return
1
;
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
0.0
;
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
0.0
;
}
CustomFunction
*
clone
()
const
{
return
new
FunctionPlaceholder
();
}
};
// Record the tabulated functions.
FunctionPlaceholder
*
fp
=
new
FunctionPlaceholder
();
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
mm_float4
>
tabulatedFunctionParamsVec
(
force
.
getNumFunctions
());
for
(
int
i
=
0
;
i
<
force
.
getNumFunctions
();
i
++
)
{
string
name
;
vector
<
double
>
values
;
double
min
,
max
;
bool
interpolating
;
force
.
getFunctionParameters
(
i
,
name
,
values
,
min
,
max
,
interpolating
);
// gpuSetTabulatedFunction(gpu, i, name, values, min, max, interpolating);
string
arrayName
=
prefix
+
"table"
+
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
fp
;
tabulatedFunctionParamsVec
[
i
]
=
(
mm_float4
)
{(
float
)
min
,
(
float
)
max
,
(
float
)
((
values
.
size
()
-
1
)
/
(
max
-
min
)),
0.0
f
};
// First create a padded set of function values.
vector
<
double
>
padded
(
values
.
size
()
+
2
);
padded
[
0
]
=
2
*
values
[
0
]
-
values
[
1
];
for
(
int
i
=
0
;
i
<
(
int
)
values
.
size
();
i
++
)
padded
[
i
+
1
]
=
values
[
i
];
padded
[
padded
.
size
()
-
1
]
=
2
*
values
[
values
.
size
()
-
1
]
-
values
[
values
.
size
()
-
2
];
// Now compute the spline coefficients.
vector
<
mm_float4
>
f
(
values
.
size
()
-
1
);
for
(
int
i
=
0
;
i
<
(
int
)
values
.
size
()
-
1
;
i
++
)
{
if
(
interpolating
)
f
[
i
]
=
(
mm_float4
)
{(
cl_float
)
padded
[
i
+
1
],
(
cl_float
)
(
0.5
*
(
-
padded
[
i
]
+
padded
[
i
+
2
])),
(
cl_float
)
(
0.5
*
(
2.0
*
padded
[
i
]
-
5.0
*
padded
[
i
+
1
]
+
4.0
*
padded
[
i
+
2
]
-
padded
[
i
+
3
])),
(
cl_float
)
(
0.5
*
(
-
padded
[
i
]
+
3.0
*
padded
[
i
+
1
]
-
3.0
*
padded
[
i
+
2
]
+
padded
[
i
+
3
]))};
else
f
[
i
]
=
(
mm_float4
)
{(
cl_float
)
((
padded
[
i
]
+
4.0
*
padded
[
i
+
1
]
+
padded
[
i
+
2
])
/
6.0
),
(
cl_float
)
((
-
3.0
*
padded
[
i
]
+
3.0
*
padded
[
i
+
2
])
/
6.0
),
(
cl_float
)
((
3.0
*
padded
[
i
]
-
6.0
*
padded
[
i
+
1
]
+
3.0
*
padded
[
i
+
2
])
/
6.0
),
(
cl_float
)
((
-
padded
[
i
]
+
3.0
*
padded
[
i
+
1
]
-
3.0
*
padded
[
i
+
2
]
+
padded
[
i
+
3
])
/
6.0
)};
}
tabulatedFunctions
.
push_back
(
new
OpenCLArray
<
mm_float4
>
(
cl
,
values
.
size
()
-
1
,
"TabulatedFunction"
));
tabulatedFunctions
[
tabulatedFunctions
.
size
()
-
1
]
->
upload
(
f
);
cl
.
getNonbondedUtilities
().
addArgument
(
OpenCLNonbondedUtilities
::
ParameterInfo
(
arrayName
,
"float4"
,
sizeof
(
cl_float4
),
tabulatedFunctions
[
tabulatedFunctions
.
size
()
-
1
]
->
getDeviceBuffer
()));
extraArguments
+=
", __constant float4* "
+
arrayName
;
}
if
(
force
.
getNumFunctions
()
>
0
)
{
tabulatedFunctionParams
=
new
OpenCLArray
<
mm_float4
>
(
cl
,
tabulatedFunctionParamsVec
.
size
(),
"tabulatedFunctionParameters"
,
false
,
CL_MEM_READ_ONLY
);
tabulatedFunctionParams
->
upload
(
tabulatedFunctionParamsVec
);
cl
.
getNonbondedUtilities
().
addArgument
(
OpenCLNonbondedUtilities
::
ParameterInfo
(
prefix
+
"functionParams"
,
"float4"
,
sizeof
(
cl_float4
),
tabulatedFunctionParams
->
getDeviceBuffer
()));
extraArguments
+=
", __constant float4* "
+
prefix
+
"functionParams"
;
}
// Record information for the expressions.
...
...
@@ -799,8 +884,11 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
globals
->
upload
(
globalParamValues
);
bool
useCutoff
=
(
force
.
getNonbondedMethod
()
!=
CustomNonbondedForce
::
NoCutoff
);
bool
usePeriodic
=
(
force
.
getNonbondedMethod
()
!=
CustomNonbondedForce
::
NoCutoff
&&
force
.
getNonbondedMethod
()
!=
CustomNonbondedForce
::
CutoffNonPeriodic
);
Lepton
::
ParsedExpression
energyExpression
=
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
()).
optimize
();
Lepton
::
ParsedExpression
energyExpression
=
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
()
,
functions
).
optimize
();
Lepton
::
ParsedExpression
forceExpression
=
energyExpression
.
differentiate
(
"r"
).
optimize
();
map
<
string
,
Lepton
::
ParsedExpression
>
forceExpressions
;
forceExpressions
[
"tempEnergy += "
]
=
energyExpression
;
forceExpressions
[
"tempForce -= "
]
=
forceExpression
;
// Create the kernels.
...
...
@@ -824,13 +912,13 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
forceVariables
[
name
]
=
prefix
+
value
;
exceptionVariables
[
name
]
=
value
;
}
string
stream
compute
;
map
<
string
,
Lepton
::
ParsedExpression
>
paramExpressions
;
for
(
int
i
=
0
;
i
<
force
.
getNumParameters
();
i
++
)
{
Lepton
::
ParsedExpression
expression
=
Lepton
::
Parser
::
parse
(
force
.
getParameterCombiningRule
(
i
)).
optimize
();
compute
<<
"float "
<<
prefix
<<
force
.
getParameterName
(
i
)
<<
" = "
<<
OpenCLExpressionUtilities
::
createExpression
(
expression
,
paramVariables
)
<<
";
\n
"
;
paramExpressions
[
"float "
+
prefix
+
force
.
getParameterName
(
i
)
+
" = "
]
=
Lepton
::
Parser
::
parse
(
force
.
getParameterCombiningRule
(
i
)).
optimize
();
}
compute
<<
"tempEnergy += "
<<
OpenCLExpressionUtilities
::
createExpression
(
energyExpression
,
forceVariables
)
<<
";
\n
"
;
compute
<<
"tempForce -= "
<<
OpenCLExpressionUtilities
::
createExpression
(
forceExpression
,
forceVariables
)
<<
";
\n
"
;
stringstream
compute
;
compute
<<
OpenCLExpressionUtilities
::
createExpressions
(
paramExpressions
,
paramVariables
,
functionDefinitions
,
prefix
+
"param_temp"
,
prefix
+
"functionParams"
);
compute
<<
OpenCLExpressionUtilities
::
createExpressions
(
forceExpressions
,
forceVariables
,
functionDefinitions
,
prefix
+
"force_temp"
,
prefix
+
"functionParams"
);
map
<
string
,
string
>
replacements
;
replacements
[
"COMPUTE_FORCE"
]
=
compute
.
str
();
string
source
=
cl
.
loadSourceFromFile
(
"customNonbonded.cl"
,
replacements
);
...
...
@@ -840,13 +928,20 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
globals
->
upload
(
globalParamValues
);
cl
.
getNonbondedUtilities
().
addArgument
(
OpenCLNonbondedUtilities
::
ParameterInfo
(
prefix
+
"globals"
,
"float"
,
sizeof
(
cl_float
),
globals
->
getDeviceBuffer
()));
}
map
<
string
,
Lepton
::
ParsedExpression
>
exceptionExpressions
;
stringstream
computeExceptions
;
computeExceptions
<<
"energy += "
<<
OpenCLExpressionUtilities
::
createExpression
(
energyExpression
,
exceptionVariables
)
<<
";
\n
"
;
computeExceptions
<<
"dEdR = "
<<
OpenCLExpressionUtilities
::
createExpression
(
forceExpression
,
exceptionVariables
)
<<
";
\n
"
;
exceptionExpressions
[
"energy += "
]
=
energyExpression
;
exceptionExpressions
[
"dEdR = "
]
=
forceExpression
;
computeExceptions
<<
OpenCLExpressionUtilities
::
createExpressions
(
exceptionExpressions
,
exceptionVariables
,
functionDefinitions
,
"temp"
,
prefix
+
"functionParams"
);
replacements
[
"COMPUTE_FORCE"
]
=
computeExceptions
.
str
();
replacements
[
"EXTRA_ARGUMENTS"
]
=
extraArguments
;
map
<
string
,
string
>
defines
;
if
(
globals
!=
NULL
)
defines
[
"HAS_GLOBALS"
]
=
"1"
;
defines
[
"CUTOFF_SQUARED"
]
=
doubleToString
(
force
.
getCutoffDistance
()
*
force
.
getCutoffDistance
());
Vec3
boxVectors
[
3
];
system
.
getPeriodicBoxVectors
(
boxVectors
[
0
],
boxVectors
[
1
],
boxVectors
[
2
]);
defines
[
"PERIODIC_BOX_SIZE_X"
]
=
doubleToString
(
boxVectors
[
0
][
0
]);
defines
[
"PERIODIC_BOX_SIZE_Y"
]
=
doubleToString
(
boxVectors
[
1
][
1
]);
defines
[
"PERIODIC_BOX_SIZE_Z"
]
=
doubleToString
(
boxVectors
[
2
][
2
]);
cl
::
Program
program
=
cl
.
createProgram
(
cl
.
loadSourceFromFile
(
"customNonbondedExceptions.cl"
,
replacements
),
defines
);
exceptionsKernel
=
cl
::
Kernel
(
program
,
"computeCustomNonbondedExceptions"
);
...
...
@@ -880,23 +975,22 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
maxBuffers
=
max
(
maxBuffers
,
forceBufferCounter
[
i
]);
}
cl
.
addForce
(
new
OpenCLCustomNonbondedForceInfo
(
maxBuffers
,
force
));
delete
fp
;
}
void
OpenCLCalcCustomNonbondedForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
exceptionParams
!=
NULL
)
{
if
(
!
has
Creat
edKernel
s
)
{
has
Creat
edKernel
s
=
true
;
if
(
!
has
Initializ
edKernel
)
{
has
Initializ
edKernel
=
true
;
exceptionsKernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
exceptionsKernel
.
setArg
<
cl_int
>
(
1
,
exceptionParams
->
getSize
());
exceptionsKernel
.
setArg
<
cl_float
>
(
2
,
cl
.
getNonbondedUtilities
().
getCutoffDistance
()
*
cl
.
getNonbondedUtilities
().
getCutoffDistance
());
exceptionsKernel
.
setArg
<
mm_float4
>
(
3
,
cl
.
getNonbondedUtilities
().
getPeriodicBoxSize
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
cl
.
getEnergyBuffer
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
cl
.
getPosq
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
exceptionParams
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
exceptionIndices
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
cl
.
getEnergyBuffer
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
exceptionParams
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
exceptionIndices
->
getDeviceBuffer
());
if
(
globals
!=
NULL
)
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
9
,
globals
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
globals
->
getDeviceBuffer
());
}
cl
.
executeKernel
(
exceptionsKernel
,
exceptionIndices
->
getSize
());
}
...
...
platforms/opencl/src/OpenCLKernels.h
View file @
72d59cbe
...
...
@@ -150,8 +150,8 @@ private:
*/
class
OpenCLCalcHarmonicBondForceKernel
:
public
CalcHarmonicBondForceKernel
{
public:
OpenCLCalcHarmonicBondForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcHarmonicBondForceKernel
(
name
,
platform
),
cl
(
cl
),
system
(
system
),
params
(
NULL
),
indices
(
NULL
)
{
OpenCLCalcHarmonicBondForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcHarmonicBondForceKernel
(
name
,
platform
),
hasInitializedKernel
(
false
),
cl
(
cl
),
system
(
system
),
params
(
NULL
),
indices
(
NULL
)
{
}
~
OpenCLCalcHarmonicBondForceKernel
();
/**
...
...
@@ -176,6 +176,7 @@ public:
double
executeEnergy
(
ContextImpl
&
context
);
private:
int
numBonds
;
bool
hasInitializedKernel
;
OpenCLContext
&
cl
;
System
&
system
;
OpenCLArray
<
mm_float2
>*
params
;
...
...
@@ -188,7 +189,8 @@ private:
*/
class
OpenCLCalcHarmonicAngleForceKernel
:
public
CalcHarmonicAngleForceKernel
{
public:
OpenCLCalcHarmonicAngleForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcHarmonicAngleForceKernel
(
name
,
platform
),
cl
(
cl
),
system
(
system
)
{
OpenCLCalcHarmonicAngleForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcHarmonicAngleForceKernel
(
name
,
platform
),
hasInitializedKernel
(
false
),
cl
(
cl
),
system
(
system
)
{
}
~
OpenCLCalcHarmonicAngleForceKernel
();
/**
...
...
@@ -213,6 +215,7 @@ public:
double
executeEnergy
(
ContextImpl
&
context
);
private:
int
numAngles
;
bool
hasInitializedKernel
;
OpenCLContext
&
cl
;
System
&
system
;
OpenCLArray
<
mm_float2
>*
params
;
...
...
@@ -225,7 +228,8 @@ private:
*/
class
OpenCLCalcPeriodicTorsionForceKernel
:
public
CalcPeriodicTorsionForceKernel
{
public:
OpenCLCalcPeriodicTorsionForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcPeriodicTorsionForceKernel
(
name
,
platform
),
cl
(
cl
),
system
(
system
)
{
OpenCLCalcPeriodicTorsionForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcPeriodicTorsionForceKernel
(
name
,
platform
),
hasInitializedKernel
(
false
),
cl
(
cl
),
system
(
system
)
{
}
~
OpenCLCalcPeriodicTorsionForceKernel
();
/**
...
...
@@ -250,6 +254,7 @@ public:
double
executeEnergy
(
ContextImpl
&
context
);
private:
int
numTorsions
;
bool
hasInitializedKernel
;
OpenCLContext
&
cl
;
System
&
system
;
OpenCLArray
<
mm_float4
>*
params
;
...
...
@@ -262,7 +267,8 @@ private:
*/
class
OpenCLCalcRBTorsionForceKernel
:
public
CalcRBTorsionForceKernel
{
public:
OpenCLCalcRBTorsionForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcRBTorsionForceKernel
(
name
,
platform
),
cl
(
cl
),
system
(
system
)
{
OpenCLCalcRBTorsionForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcRBTorsionForceKernel
(
name
,
platform
),
hasInitializedKernel
(
false
),
cl
(
cl
),
system
(
system
)
{
}
~
OpenCLCalcRBTorsionForceKernel
();
/**
...
...
@@ -287,6 +293,7 @@ public:
double
executeEnergy
(
ContextImpl
&
context
);
private:
int
numTorsions
;
bool
hasInitializedKernel
;
OpenCLContext
&
cl
;
System
&
system
;
OpenCLArray
<
mm_float8
>*
params
;
...
...
@@ -299,8 +306,8 @@ private:
*/
class
OpenCLCalcNonbondedForceKernel
:
public
CalcNonbondedForceKernel
{
public:
OpenCLCalcNonbondedForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcNonbondedForceKernel
(
name
,
platform
),
cl
(
cl
),
sigmaEpsilon
(
NULL
),
exceptionParams
(
NULL
),
exceptionIndices
(
NULL
),
cosSinSums
(
NULL
)
{
OpenCLCalcNonbondedForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcNonbondedForceKernel
(
name
,
platform
),
hasInitializedKernel
(
false
),
cl
(
cl
),
sigmaEpsilon
(
NULL
),
exceptionParams
(
NULL
),
exceptionIndices
(
NULL
),
cosSinSums
(
NULL
)
{
}
~
OpenCLCalcNonbondedForceKernel
();
/**
...
...
@@ -325,6 +332,7 @@ public:
double
executeEnergy
(
ContextImpl
&
context
);
private:
OpenCLContext
&
cl
;
bool
hasInitializedKernel
;
OpenCLArray
<
mm_float2
>*
sigmaEpsilon
;
OpenCLArray
<
mm_float4
>*
exceptionParams
;
OpenCLArray
<
mm_int4
>*
exceptionIndices
;
...
...
@@ -341,7 +349,7 @@ private:
class
OpenCLCalcCustomNonbondedForceKernel
:
public
CalcCustomNonbondedForceKernel
{
public:
OpenCLCalcCustomNonbondedForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcCustomNonbondedForceKernel
(
name
,
platform
),
has
Creat
edKernel
s
(
false
),
cl
(
cl
),
params
(
NULL
),
globals
(
NULL
),
exceptionParams
(
NULL
),
exceptionIndices
(
NULL
),
system
(
system
)
{
has
Initializ
edKernel
(
false
),
cl
(
cl
),
params
(
NULL
),
globals
(
NULL
),
exceptionParams
(
NULL
),
exceptionIndices
(
NULL
),
tabulatedFunctionParams
(
NULL
),
system
(
system
)
{
}
~
OpenCLCalcCustomNonbondedForceKernel
();
/**
...
...
@@ -365,15 +373,17 @@ public:
*/
double
executeEnergy
(
ContextImpl
&
context
);
private:
bool
has
Creat
edKernel
s
;
bool
has
Initializ
edKernel
;
OpenCLContext
&
cl
;
OpenCLArray
<
mm_float4
>*
params
;
OpenCLArray
<
cl_float
>*
globals
;
OpenCLArray
<
mm_float4
>*
exceptionParams
;
OpenCLArray
<
mm_int4
>*
exceptionIndices
;
OpenCLArray
<
mm_float4
>*
tabulatedFunctionParams
;
cl
::
Kernel
exceptionsKernel
;
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
cl_float
>
globalParamValues
;
std
::
vector
<
OpenCLArray
<
mm_float4
>*>
tabulatedFunctions
;
System
&
system
;
};
...
...
platforms/opencl/src/kernels/customNonbondedExceptions.cl
View file @
72d59cbe
...
...
@@ -2,13 +2,9 @@
*
Compute
custom
nonbonded
exceptions.
*/
__kernel
void
computeCustomNonbondedExceptions
(
int
numAtoms,
int
numExceptions,
float
cutoffSquared,
float4
periodicBoxSize,
__global
float4*
forceBuffers,
__global
float*
energyBuffer,
__kernel
void
computeCustomNonbondedExceptions
(
int
numAtoms,
int
numExceptions,
__global
float4*
forceBuffers,
__global
float*
energyBuffer,
__global
float4*
posq,
__global
float4*
params,
__global
int4*
indices
#
ifdef
HAS_GLOBALS
,
__constant
float*
globals
)
{
#
else
)
{
#
endif
EXTRA_ARGUMENTS
)
{
int
index
=
get_global_id
(
0
)
;
float
energy
=
0.0f
;
while
(
index
<
numExceptions
)
{
...
...
@@ -18,15 +14,15 @@ __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions,
float4
exceptionParams
=
params[index]
;
float4
delta
=
posq[atoms.y]-posq[atoms.x]
;
#
ifdef
USE_PERIODIC
delta.x
-=
floor
(
delta.x/
periodicBoxSize.x+0.5f
)
*periodicBoxSize.x
;
delta.y
-=
floor
(
delta.y/
periodicBoxSize.y+0.5f
)
*periodicBoxSize.y
;
delta.z
-=
floor
(
delta.z/
periodicBoxSize.z+0.5f
)
*periodicBoxSize.z
;
delta.x
-=
floor
(
delta.x/
PERIODIC_BOX_SIZE_X+0.5f
)
*PERIODIC_BOX_SIZE_X
;
delta.y
-=
floor
(
delta.y/
PERIODIC_BOX_SIZE_Y+0.5f
)
*PERIODIC_BOX_SIZE_Y
;
delta.z
-=
floor
(
delta.z/
PERIODIC_BOX_SIZE_Z+0.5f
)
*PERIODIC_BOX_SIZE_Z
;
#
endif
//
Compute
the
force.
float
r2
=
delta.x*delta.x
+
delta.y*delta.y
+
delta.z*delta.z
;
#
ifdef
USE_CUTOFF
if
(
r2
>
cutoffSquared
)
{
if
(
r2
>
CUTOFF_SQUARED
)
{
#
else
{
#
endif
...
...
platforms/opencl/tests/TestOpenCLCustomNonbondedForce.cpp
View file @
72d59cbe
...
...
@@ -241,8 +241,8 @@ int main() {
testExceptions
();
testCutoff
();
testPeriodic
();
//
testTabulatedFunction(true);
//
testTabulatedFunction(false);
testTabulatedFunction
(
true
);
testTabulatedFunction
(
false
);
}
catch
(
const
exception
&
e
)
{
cout
<<
"exception: "
<<
e
.
what
()
<<
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment