Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
72d59cbe
Commit
72d59cbe
authored
Nov 02, 2009
by
Peter Eastman
Browse files
Finished OpenCL implementation of CustomNonbondedForce. Also implemented a few optimizations.
parent
2127b8dd
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
366 additions
and
123 deletions
+366
-123
libraries/lepton/include/lepton/ExpressionTreeNode.h
libraries/lepton/include/lepton/ExpressionTreeNode.h
+2
-0
libraries/lepton/include/lepton/Operation.h
libraries/lepton/include/lepton/Operation.h
+30
-0
libraries/lepton/include/lepton/ParsedExpression.h
libraries/lepton/include/lepton/ParsedExpression.h
+5
-0
libraries/lepton/src/ExpressionTreeNode.cpp
libraries/lepton/src/ExpressionTreeNode.cpp
+13
-0
libraries/lepton/src/ParsedExpression.cpp
libraries/lepton/src/ParsedExpression.cpp
+5
-0
platforms/opencl/src/OpenCLExpressionUtilities.cpp
platforms/opencl/src/OpenCLExpressionUtilities.cpp
+113
-31
platforms/opencl/src/OpenCLExpressionUtilities.h
platforms/opencl/src/OpenCLExpressionUtilities.h
+8
-2
platforms/opencl/src/OpenCLKernels.cpp
platforms/opencl/src/OpenCLKernels.cpp
+163
-69
platforms/opencl/src/OpenCLKernels.h
platforms/opencl/src/OpenCLKernels.h
+19
-9
platforms/opencl/src/kernels/customNonbondedExceptions.cl
platforms/opencl/src/kernels/customNonbondedExceptions.cl
+6
-10
platforms/opencl/tests/TestOpenCLCustomNonbondedForce.cpp
platforms/opencl/tests/TestOpenCLCustomNonbondedForce.cpp
+2
-2
No files found.
libraries/lepton/include/lepton/ExpressionTreeNode.h
View file @
72d59cbe
...
@@ -84,6 +84,8 @@ public:
...
@@ -84,6 +84,8 @@ public:
ExpressionTreeNode
(
const
ExpressionTreeNode
&
node
);
ExpressionTreeNode
(
const
ExpressionTreeNode
&
node
);
ExpressionTreeNode
();
ExpressionTreeNode
();
~
ExpressionTreeNode
();
~
ExpressionTreeNode
();
bool
operator
==
(
const
ExpressionTreeNode
&
node
)
const
;
bool
operator
!=
(
const
ExpressionTreeNode
&
node
)
const
;
ExpressionTreeNode
&
operator
=
(
const
ExpressionTreeNode
&
node
);
ExpressionTreeNode
&
operator
=
(
const
ExpressionTreeNode
&
node
);
/**
/**
* Get the Operation performed by this node.
* Get the Operation performed by this node.
...
...
libraries/lepton/include/lepton/Operation.h
View file @
72d59cbe
...
@@ -95,6 +95,12 @@ public:
...
@@ -95,6 +95,12 @@ public:
* @param variable the variable with respect to which the derivate should be taken
* @param variable the variable with respect to which the derivate should be taken
*/
*/
virtual
ExpressionTreeNode
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
=
0
;
virtual
ExpressionTreeNode
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
=
0
;
virtual
bool
operator
!=
(
const
Operation
&
op
)
const
{
return
op
.
getId
()
!=
getId
();
}
virtual
bool
operator
==
(
const
Operation
&
op
)
const
{
return
!
(
*
this
!=
op
);
}
class
Constant
;
class
Constant
;
class
Variable
;
class
Variable
;
class
Custom
;
class
Custom
;
...
@@ -149,6 +155,10 @@ public:
...
@@ -149,6 +155,10 @@ public:
double
getValue
()
const
{
double
getValue
()
const
{
return
value
;
return
value
;
}
}
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
Constant
*
o
=
dynamic_cast
<
const
Constant
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
value
!=
value
);
}
private:
private:
double
value
;
double
value
;
};
};
...
@@ -176,6 +186,10 @@ public:
...
@@ -176,6 +186,10 @@ public:
return
iter
->
second
;
return
iter
->
second
;
}
}
ExpressionTreeNode
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
;
ExpressionTreeNode
differentiate
(
const
std
::
vector
<
ExpressionTreeNode
>&
children
,
const
std
::
vector
<
ExpressionTreeNode
>&
childDerivs
,
const
std
::
string
&
variable
)
const
;
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
Variable
*
o
=
dynamic_cast
<
const
Variable
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
name
!=
name
);
}
private:
private:
std
::
string
name
;
std
::
string
name
;
};
};
...
@@ -214,6 +228,10 @@ public:
...
@@ -214,6 +228,10 @@ public:
const
std
::
vector
<
int
>&
getDerivOrder
()
const
{
const
std
::
vector
<
int
>&
getDerivOrder
()
const
{
return
derivOrder
;
return
derivOrder
;
}
}
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
Custom
*
o
=
dynamic_cast
<
const
Custom
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
name
!=
name
||
o
->
isDerivative
!=
isDerivative
||
o
->
derivOrder
!=
derivOrder
);
}
private:
private:
std
::
string
name
;
std
::
string
name
;
CustomFunction
*
function
;
CustomFunction
*
function
;
...
@@ -708,6 +726,10 @@ public:
...
@@ -708,6 +726,10 @@ public:
double
getValue
()
const
{
double
getValue
()
const
{
return
value
;
return
value
;
}
}
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
AddConstant
*
o
=
dynamic_cast
<
const
AddConstant
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
value
!=
value
);
}
private:
private:
double
value
;
double
value
;
};
};
...
@@ -737,6 +759,10 @@ public:
...
@@ -737,6 +759,10 @@ public:
double
getValue
()
const
{
double
getValue
()
const
{
return
value
;
return
value
;
}
}
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
MultiplyConstant
*
o
=
dynamic_cast
<
const
MultiplyConstant
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
value
!=
value
);
}
private:
private:
double
value
;
double
value
;
};
};
...
@@ -766,6 +792,10 @@ public:
...
@@ -766,6 +792,10 @@ public:
double
getValue
()
const
{
double
getValue
()
const
{
return
value
;
return
value
;
}
}
bool
operator
!=
(
const
Operation
&
op
)
const
{
const
PowerConstant
*
o
=
dynamic_cast
<
const
PowerConstant
*>
(
&
op
);
return
(
o
==
NULL
||
o
->
value
!=
value
);
}
private:
private:
double
value
;
double
value
;
};
};
...
...
libraries/lepton/include/lepton/ParsedExpression.h
View file @
72d59cbe
...
@@ -48,6 +48,11 @@ class ExpressionProgram;
...
@@ -48,6 +48,11 @@ class ExpressionProgram;
class
LEPTON_EXPORT
ParsedExpression
{
class
LEPTON_EXPORT
ParsedExpression
{
public:
public:
/**
* Create an uninitialized ParsedExpression. This exists so that ParsedExpressions can be put in STL containers.
* Doing anything with it will produce an exception.
*/
ParsedExpression
();
/**
/**
* Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class
* Create a ParsedExpression. Normally you will not call this directly. Instead, use the Parser class
* to parse expression.
* to parse expression.
...
...
libraries/lepton/src/ExpressionTreeNode.cpp
View file @
72d59cbe
...
@@ -70,6 +70,19 @@ ExpressionTreeNode::~ExpressionTreeNode() {
...
@@ -70,6 +70,19 @@ ExpressionTreeNode::~ExpressionTreeNode() {
delete
operation
;
delete
operation
;
}
}
bool
ExpressionTreeNode
::
operator
!=
(
const
ExpressionTreeNode
&
node
)
const
{
if
(
node
.
getOperation
()
!=
getOperation
())
return
true
;
for
(
int
i
=
0
;
i
<
(
int
)
getChildren
().
size
();
i
++
)
if
(
getChildren
()[
i
]
!=
node
.
getChildren
()[
i
])
return
true
;
return
false
;
}
bool
ExpressionTreeNode
::
operator
==
(
const
ExpressionTreeNode
&
node
)
const
{
return
!
(
*
this
!=
node
);
}
ExpressionTreeNode
&
ExpressionTreeNode
::
operator
=
(
const
ExpressionTreeNode
&
node
)
{
ExpressionTreeNode
&
ExpressionTreeNode
::
operator
=
(
const
ExpressionTreeNode
&
node
)
{
if
(
operation
!=
NULL
)
if
(
operation
!=
NULL
)
delete
operation
;
delete
operation
;
...
...
libraries/lepton/src/ParsedExpression.cpp
View file @
72d59cbe
...
@@ -38,10 +38,15 @@
...
@@ -38,10 +38,15 @@
using
namespace
Lepton
;
using
namespace
Lepton
;
using
namespace
std
;
using
namespace
std
;
ParsedExpression
::
ParsedExpression
()
:
rootNode
(
ExpressionTreeNode
())
{
}
ParsedExpression
::
ParsedExpression
(
const
ExpressionTreeNode
&
rootNode
)
:
rootNode
(
rootNode
)
{
ParsedExpression
::
ParsedExpression
(
const
ExpressionTreeNode
&
rootNode
)
:
rootNode
(
rootNode
)
{
}
}
const
ExpressionTreeNode
&
ParsedExpression
::
getRootNode
()
const
{
const
ExpressionTreeNode
&
ParsedExpression
::
getRootNode
()
const
{
if
(
&
rootNode
.
getOperation
()
==
NULL
)
throw
Exception
(
"Illegal call to an initialized ParsedExpression"
);
return
rootNode
;
return
rootNode
;
}
}
...
...
platforms/opencl/src/OpenCLExpressionUtilities.cpp
View file @
72d59cbe
...
@@ -27,7 +27,6 @@
...
@@ -27,7 +27,6 @@
#include "OpenCLExpressionUtilities.h"
#include "OpenCLExpressionUtilities.h"
#include "openmm/OpenMMException.h"
#include "openmm/OpenMMException.h"
#include "lepton/Operation.h"
#include "lepton/Operation.h"
#include <sstream>
using
namespace
OpenMM
;
using
namespace
OpenMM
;
using
namespace
Lepton
;
using
namespace
Lepton
;
...
@@ -46,69 +45,152 @@ static string intToString(int value) {
...
@@ -46,69 +45,152 @@ static string intToString(int value) {
return
s
.
str
();
return
s
.
str
();
}
}
string
OpenCLExpressionUtilities
::
createExpression
(
const
ParsedExpression
&
expression
,
const
map
<
string
,
string
>&
variables
)
{
string
OpenCLExpressionUtilities
::
createExpressions
(
const
map
<
string
,
ParsedExpression
>&
expressions
,
const
map
<
string
,
string
>&
variables
,
return
processExpression
(
expression
.
getRootNode
(),
variables
);
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
)
{
stringstream
out
;
vector
<
pair
<
ExpressionTreeNode
,
string
>
>
temps
;
for
(
map
<
string
,
ParsedExpression
>::
const_iterator
iter
=
expressions
.
begin
();
iter
!=
expressions
.
end
();
++
iter
)
{
processExpression
(
out
,
iter
->
second
.
getRootNode
(),
temps
,
variables
,
functions
,
prefix
,
functionParams
);
out
<<
iter
->
first
<<
getTempName
(
iter
->
second
.
getRootNode
(),
temps
)
<<
";
\n
"
;
}
return
out
.
str
();
}
}
string
OpenCLExpressionUtilities
::
processExpression
(
const
ExpressionTreeNode
&
node
,
const
map
<
string
,
string
>&
variables
)
{
void
OpenCLExpressionUtilities
::
processExpression
(
stringstream
&
out
,
const
ExpressionTreeNode
&
node
,
vector
<
pair
<
ExpressionTreeNode
,
string
>
>&
temps
,
const
map
<
string
,
string
>&
variables
,
const
vector
<
pair
<
string
,
string
>
>&
functions
,
const
string
&
prefix
,
const
string
&
functionParams
)
{
for
(
int
i
=
0
;
i
<
(
int
)
temps
.
size
();
i
++
)
if
(
temps
[
i
].
first
==
node
)
return
;
for
(
int
i
=
0
;
i
<
(
int
)
node
.
getChildren
().
size
();
i
++
)
processExpression
(
out
,
node
.
getChildren
()[
i
],
temps
,
variables
,
functions
,
prefix
,
functionParams
);
string
name
=
prefix
+
intToString
(
temps
.
size
());
out
<<
"float "
<<
name
<<
" = "
;
switch
(
node
.
getOperation
().
getId
())
{
switch
(
node
.
getOperation
().
getId
())
{
case
Operation
::
CONSTANT
:
case
Operation
::
CONSTANT
:
return
doubleToString
(
dynamic_cast
<
const
Operation
::
Constant
*>
(
&
node
.
getOperation
())
->
getValue
());
out
<<
doubleToString
(
dynamic_cast
<
const
Operation
::
Constant
*>
(
&
node
.
getOperation
())
->
getValue
());
break
;
case
Operation
::
VARIABLE
:
case
Operation
::
VARIABLE
:
{
{
map
<
string
,
string
>::
const_iterator
iter
=
variables
.
find
(
node
.
getOperation
().
getName
());
map
<
string
,
string
>::
const_iterator
iter
=
variables
.
find
(
node
.
getOperation
().
getName
());
if
(
iter
==
variables
.
end
())
if
(
iter
==
variables
.
end
())
throw
OpenMMException
(
"Unknown variable in expression: "
+
node
.
getOperation
().
getName
());
throw
OpenMMException
(
"Unknown variable in expression: "
+
node
.
getOperation
().
getName
());
return
iter
->
second
;
out
<<
iter
->
second
;
break
;
}
case
Operation
::
CUSTOM
:
{
int
i
;
for
(
i
=
0
;
i
<
(
int
)
functions
.
size
()
&&
functions
[
i
].
first
!=
node
.
getOperation
().
getName
();
i
++
)
;
if
(
i
==
functions
.
size
())
throw
OpenMMException
(
"Unknown function in expression: "
+
node
.
getOperation
().
getName
());
out
<<
"0.0f;
\n
"
;
out
<<
"{
\n
"
;
out
<<
"float4 params = "
<<
functionParams
<<
"["
<<
i
<<
"];
\n
"
;
out
<<
"float x = "
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
";
\n
"
;
out
<<
"if (x >= params.x && x <= params.y) {
\n
"
;
out
<<
"int index = (int) (floor((x-params.x)*params.z));
\n
"
;
out
<<
"float4 coeff = "
<<
functions
[
i
].
second
<<
"[index];
\n
"
;
out
<<
"x = (x-params.x)*params.z-index;
\n
"
;
if
(
dynamic_cast
<
const
Operation
::
Custom
*>
(
&
node
.
getOperation
())
->
getDerivOrder
()[
0
]
==
0
)
out
<<
name
<<
" = coeff.x+x*(coeff.y+x*(coeff.z+x*coeff.w));
\n
"
;
else
out
<<
name
<<
" = (coeff.y+x*(2.0f*coeff.z+x*3.0f*coeff.w))*params.z;
\n
"
;
out
<<
"}
\n
"
;
out
<<
"}"
;
break
;
}
}
case
Operation
::
ADD
:
case
Operation
::
ADD
:
return
"("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")+("
+
processExpression
(
node
.
getChildren
()[
1
],
variables
)
+
")"
;
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"+"
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
);
break
;
case
Operation
::
SUBTRACT
:
case
Operation
::
SUBTRACT
:
return
"("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")-("
+
processExpression
(
node
.
getChildren
()[
1
],
variables
)
+
")"
;
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"-"
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
);
break
;
case
Operation
::
MULTIPLY
:
case
Operation
::
MULTIPLY
:
return
"("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")*("
+
processExpression
(
node
.
getChildren
()[
1
],
variables
)
+
")"
;
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"*"
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
);
break
;
case
Operation
::
DIVIDE
:
case
Operation
::
DIVIDE
:
return
"("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")/("
+
processExpression
(
node
.
getChildren
()[
1
],
variables
)
+
")"
;
out
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
"/"
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
);
break
;
case
Operation
::
POWER
:
case
Operation
::
POWER
:
return
"pow(("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
"), ("
+
processExpression
(
node
.
getChildren
()[
1
],
variables
)
+
"))"
;
out
<<
"pow("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
", "
<<
getTempName
(
node
.
getChildren
()[
1
],
temps
)
<<
")"
;
break
;
case
Operation
::
NEGATE
:
case
Operation
::
NEGATE
:
return
"-("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"-"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
break
;
case
Operation
::
SQRT
:
case
Operation
::
SQRT
:
return
"sqrt("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"sqrt("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
EXP
:
case
Operation
::
EXP
:
return
"exp("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"exp("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
LOG
:
case
Operation
::
LOG
:
return
"log("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"log("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
SIN
:
case
Operation
::
SIN
:
return
"sin("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"sin("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
COS
:
case
Operation
::
COS
:
return
"cos("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"cos("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
SEC
:
case
Operation
::
SEC
:
return
"1.0f/cos("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"1.0f/cos("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
CSC
:
case
Operation
::
CSC
:
return
"1.0f/sin("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"1.0f/sin("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
TAN
:
case
Operation
::
TAN
:
return
"tan("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"tan("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
COT
:
case
Operation
::
COT
:
return
"1.0f/tan("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"1.0f/tan("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
ASIN
:
case
Operation
::
ASIN
:
return
"asin("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"asin("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
ACOS
:
case
Operation
::
ACOS
:
return
"acos("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"acos("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
ATAN
:
case
Operation
::
ATAN
:
return
"atan("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"atan("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
")"
;
break
;
case
Operation
::
SQUARE
:
case
Operation
::
SQUARE
:
return
"pow(("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
"), 2.0f)"
;
{
string
arg
=
getTempName
(
node
.
getChildren
()[
0
],
temps
);
out
<<
arg
<<
"*"
<<
arg
;
break
;
}
case
Operation
::
CUBE
:
case
Operation
::
CUBE
:
return
"pow(("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
"), 3.0f)"
;
{
string
arg
=
getTempName
(
node
.
getChildren
()[
0
],
temps
);
out
<<
arg
<<
"*"
<<
arg
<<
"*"
<<
arg
;
break
;
}
case
Operation
::
RECIPROCAL
:
case
Operation
::
RECIPROCAL
:
return
"1.0f/("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
"1.0f/"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
break
;
case
Operation
::
ADD_CONSTANT
:
case
Operation
::
ADD_CONSTANT
:
return
doubleToString
(
dynamic_cast
<
const
Operation
::
AddConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
+
"+("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
doubleToString
(
dynamic_cast
<
const
Operation
::
AddConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
<<
"+"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
break
;
case
Operation
::
MULTIPLY_CONSTANT
:
case
Operation
::
MULTIPLY_CONSTANT
:
return
doubleToString
(
dynamic_cast
<
const
Operation
::
MultiplyConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
+
"*("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
")"
;
out
<<
doubleToString
(
dynamic_cast
<
const
Operation
::
MultiplyConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
<<
"*"
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
);
break
;
case
Operation
::
POWER_CONSTANT
:
case
Operation
::
POWER_CONSTANT
:
return
"pow(("
+
processExpression
(
node
.
getChildren
()[
0
],
variables
)
+
"), "
+
doubleToString
(
dynamic_cast
<
const
Operation
::
PowerConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
+
")"
;
out
<<
"pow("
<<
getTempName
(
node
.
getChildren
()[
0
],
temps
)
<<
", "
<<
doubleToString
(
dynamic_cast
<
const
Operation
::
PowerConstant
*>
(
&
node
.
getOperation
())
->
getValue
())
<<
")"
;
}
break
;
default:
throw
OpenMMException
(
"Internal error: Unknown operation in user-defined expression: "
+
node
.
getOperation
().
getName
());
throw
OpenMMException
(
"Internal error: Unknown operation in user-defined expression: "
+
node
.
getOperation
().
getName
());
}
out
<<
";
\n
"
;
temps
.
push_back
(
make_pair
(
node
,
name
));
}
string
OpenCLExpressionUtilities
::
getTempName
(
const
ExpressionTreeNode
&
node
,
const
vector
<
pair
<
ExpressionTreeNode
,
string
>
>&
temps
)
{
for
(
int
i
=
0
;
i
<
(
int
)
temps
.
size
();
i
++
)
if
(
temps
[
i
].
first
==
node
)
return
temps
[
i
].
second
;
stringstream
out
;
out
<<
"Internal error: No temporary variable for expression node: "
<<
node
;
throw
OpenMMException
(
out
.
str
());
}
}
platforms/opencl/src/OpenCLExpressionUtilities.h
View file @
72d59cbe
...
@@ -30,7 +30,9 @@
...
@@ -30,7 +30,9 @@
#include "lepton/ExpressionTreeNode.h"
#include "lepton/ExpressionTreeNode.h"
#include "lepton/ParsedExpression.h"
#include "lepton/ParsedExpression.h"
#include <map>
#include <map>
#include <sstream>
#include <string>
#include <string>
#include <utility>
namespace
OpenMM
{
namespace
OpenMM
{
...
@@ -41,9 +43,13 @@ namespace OpenMM {
...
@@ -41,9 +43,13 @@ namespace OpenMM {
class
OpenCLExpressionUtilities
{
class
OpenCLExpressionUtilities
{
public:
public:
static
std
::
string
createExpression
(
const
Lepton
::
ParsedExpression
&
expression
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
);
static
std
::
string
createExpressions
(
const
std
::
map
<
std
::
string
,
Lepton
::
ParsedExpression
>&
expressions
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
);
private:
private:
static
std
::
string
processExpression
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
);
static
void
processExpression
(
std
::
stringstream
&
out
,
const
Lepton
::
ExpressionTreeNode
&
node
,
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
variables
,
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
functions
,
const
std
::
string
&
prefix
,
const
std
::
string
&
functionParams
);
static
std
::
string
getTempName
(
const
Lepton
::
ExpressionTreeNode
&
node
,
const
std
::
vector
<
std
::
pair
<
Lepton
::
ExpressionTreeNode
,
std
::
string
>
>&
temps
);
};
};
}
// namespace OpenMM
}
// namespace OpenMM
...
...
platforms/opencl/src/OpenCLKernels.cpp
View file @
72d59cbe
...
@@ -33,6 +33,7 @@
...
@@ -33,6 +33,7 @@
#include "OpenCLExpressionUtilities.h"
#include "OpenCLExpressionUtilities.h"
#include "OpenCLIntegrationUtilities.h"
#include "OpenCLIntegrationUtilities.h"
#include "OpenCLNonbondedUtilities.h"
#include "OpenCLNonbondedUtilities.h"
#include "lepton/CustomFunction.h"
#include "lepton/Parser.h"
#include "lepton/Parser.h"
#include "lepton/ParsedExpression.h"
#include "lepton/ParsedExpression.h"
#include <cmath>
#include <cmath>
...
@@ -231,6 +232,8 @@ void OpenCLCalcHarmonicBondForceKernel::initialize(const System& system, const H
...
@@ -231,6 +232,8 @@ void OpenCLCalcHarmonicBondForceKernel::initialize(const System& system, const H
}
}
void
OpenCLCalcHarmonicBondForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
void
OpenCLCalcHarmonicBondForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
!
hasInitializedKernel
)
{
hasInitializedKernel
=
true
;
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
1
,
numBonds
);
kernel
.
setArg
<
cl_int
>
(
1
,
numBonds
);
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
...
@@ -238,6 +241,7 @@ void OpenCLCalcHarmonicBondForceKernel::executeForces(ContextImpl& context) {
...
@@ -238,6 +241,7 @@ void OpenCLCalcHarmonicBondForceKernel::executeForces(ContextImpl& context) {
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
}
cl
.
executeKernel
(
kernel
,
numBonds
);
cl
.
executeKernel
(
kernel
,
numBonds
);
}
}
...
@@ -307,6 +311,8 @@ void OpenCLCalcHarmonicAngleForceKernel::initialize(const System& system, const
...
@@ -307,6 +311,8 @@ void OpenCLCalcHarmonicAngleForceKernel::initialize(const System& system, const
}
}
void
OpenCLCalcHarmonicAngleForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
void
OpenCLCalcHarmonicAngleForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
!
hasInitializedKernel
)
{
hasInitializedKernel
=
true
;
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
1
,
numAngles
);
kernel
.
setArg
<
cl_int
>
(
1
,
numAngles
);
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
...
@@ -314,6 +320,7 @@ void OpenCLCalcHarmonicAngleForceKernel::executeForces(ContextImpl& context) {
...
@@ -314,6 +320,7 @@ void OpenCLCalcHarmonicAngleForceKernel::executeForces(ContextImpl& context) {
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
}
cl
.
executeKernel
(
kernel
,
numAngles
);
cl
.
executeKernel
(
kernel
,
numAngles
);
}
}
...
@@ -384,6 +391,8 @@ void OpenCLCalcPeriodicTorsionForceKernel::initialize(const System& system, cons
...
@@ -384,6 +391,8 @@ void OpenCLCalcPeriodicTorsionForceKernel::initialize(const System& system, cons
}
}
void
OpenCLCalcPeriodicTorsionForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
void
OpenCLCalcPeriodicTorsionForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
!
hasInitializedKernel
)
{
hasInitializedKernel
=
true
;
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
1
,
numTorsions
);
kernel
.
setArg
<
cl_int
>
(
1
,
numTorsions
);
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
...
@@ -391,6 +400,7 @@ void OpenCLCalcPeriodicTorsionForceKernel::executeForces(ContextImpl& context) {
...
@@ -391,6 +400,7 @@ void OpenCLCalcPeriodicTorsionForceKernel::executeForces(ContextImpl& context) {
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
}
cl
.
executeKernel
(
kernel
,
numTorsions
);
cl
.
executeKernel
(
kernel
,
numTorsions
);
}
}
...
@@ -461,6 +471,8 @@ void OpenCLCalcRBTorsionForceKernel::initialize(const System& system, const RBTo
...
@@ -461,6 +471,8 @@ void OpenCLCalcRBTorsionForceKernel::initialize(const System& system, const RBTo
}
}
void
OpenCLCalcRBTorsionForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
void
OpenCLCalcRBTorsionForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
!
hasInitializedKernel
)
{
hasInitializedKernel
=
true
;
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
kernel
.
setArg
<
cl_int
>
(
1
,
numTorsions
);
kernel
.
setArg
<
cl_int
>
(
1
,
numTorsions
);
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
...
@@ -468,6 +480,7 @@ void OpenCLCalcRBTorsionForceKernel::executeForces(ContextImpl& context) {
...
@@ -468,6 +480,7 @@ void OpenCLCalcRBTorsionForceKernel::executeForces(ContextImpl& context) {
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
5
,
params
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
kernel
.
setArg
<
cl
::
Buffer
>
(
6
,
indices
->
getDeviceBuffer
());
}
cl
.
executeKernel
(
kernel
,
numTorsions
);
cl
.
executeKernel
(
kernel
,
numTorsions
);
}
}
...
@@ -639,6 +652,8 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
...
@@ -639,6 +652,8 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
}
}
void
OpenCLCalcNonbondedForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
void
OpenCLCalcNonbondedForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
!
hasInitializedKernel
)
{
hasInitializedKernel
=
true
;
if
(
exceptionIndices
!=
NULL
)
{
if
(
exceptionIndices
!=
NULL
)
{
int
numExceptions
=
exceptionIndices
->
getSize
();
int
numExceptions
=
exceptionIndices
->
getSize
();
exceptionsKernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
exceptionsKernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
...
@@ -650,16 +665,20 @@ void OpenCLCalcNonbondedForceKernel::executeForces(ContextImpl& context) {
...
@@ -650,16 +665,20 @@ void OpenCLCalcNonbondedForceKernel::executeForces(ContextImpl& context) {
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
cl
.
getPosq
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
cl
.
getPosq
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
exceptionParams
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
exceptionParams
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
exceptionIndices
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
exceptionIndices
->
getDeviceBuffer
());
cl
.
executeKernel
(
exceptionsKernel
,
numExceptions
);
}
}
if
(
cosSinSums
!=
NULL
)
{
if
(
cosSinSums
!=
NULL
)
{
ewaldSumsKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
cl
.
getEnergyBuffer
().
getDeviceBuffer
());
ewaldSumsKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
cl
.
getEnergyBuffer
().
getDeviceBuffer
());
ewaldSumsKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
cl
.
getPosq
().
getDeviceBuffer
());
ewaldSumsKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
cl
.
getPosq
().
getDeviceBuffer
());
ewaldSumsKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cosSinSums
->
getDeviceBuffer
());
ewaldSumsKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cosSinSums
->
getDeviceBuffer
());
cl
.
executeKernel
(
ewaldSumsKernel
,
cosSinSums
->
getSize
());
ewaldForcesKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
ewaldForcesKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
ewaldForcesKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
cl
.
getPosq
().
getDeviceBuffer
());
ewaldForcesKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
cl
.
getPosq
().
getDeviceBuffer
());
ewaldForcesKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cosSinSums
->
getDeviceBuffer
());
ewaldForcesKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cosSinSums
->
getDeviceBuffer
());
}
}
if
(
exceptionIndices
!=
NULL
)
cl
.
executeKernel
(
exceptionsKernel
,
exceptionIndices
->
getSize
());
if
(
cosSinSums
!=
NULL
)
{
cl
.
executeKernel
(
ewaldSumsKernel
,
cosSinSums
->
getSize
());
cl
.
executeKernel
(
ewaldForcesKernel
,
cl
.
getNumAtoms
());
cl
.
executeKernel
(
ewaldForcesKernel
,
cl
.
getNumAtoms
());
}
}
}
}
...
@@ -718,6 +737,10 @@ OpenCLCalcCustomNonbondedForceKernel::~OpenCLCalcCustomNonbondedForceKernel() {
...
@@ -718,6 +737,10 @@ OpenCLCalcCustomNonbondedForceKernel::~OpenCLCalcCustomNonbondedForceKernel() {
delete
exceptionParams
;
delete
exceptionParams
;
if
(
exceptionIndices
!=
NULL
)
if
(
exceptionIndices
!=
NULL
)
delete
exceptionIndices
;
delete
exceptionIndices
;
if
(
tabulatedFunctionParams
!=
NULL
)
delete
tabulatedFunctionParams
;
for
(
int
i
=
0
;
i
<
(
int
)
tabulatedFunctions
.
size
();
i
++
)
delete
tabulatedFunctions
[
i
];
}
}
void
OpenCLCalcCustomNonbondedForceKernel
::
initialize
(
const
System
&
system
,
const
CustomNonbondedForce
&
force
)
{
void
OpenCLCalcCustomNonbondedForceKernel
::
initialize
(
const
System
&
system
,
const
CustomNonbondedForce
&
force
)
{
...
@@ -746,9 +769,12 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
...
@@ -746,9 +769,12 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
// Record parameters and exclusions.
// Record parameters and exclusions.
int
numParticles
=
force
.
getNumParticles
();
int
numParticles
=
force
.
getNumParticles
();
string
extraArguments
;
params
=
new
OpenCLArray
<
mm_float4
>
(
cl
,
numParticles
,
"customNonbondedParameters"
);
params
=
new
OpenCLArray
<
mm_float4
>
(
cl
,
numParticles
,
"customNonbondedParameters"
);
if
(
force
.
getNumGlobalParameters
()
>
0
)
if
(
force
.
getNumGlobalParameters
()
>
0
)
{
globals
=
new
OpenCLArray
<
cl_float
>
(
cl
,
force
.
getNumGlobalParameters
(),
"customNonbondedGlobals"
,
false
,
CL_MEM_READ_ONLY
);
globals
=
new
OpenCLArray
<
cl_float
>
(
cl
,
force
.
getNumGlobalParameters
(),
"customNonbondedGlobals"
,
false
,
CL_MEM_READ_ONLY
);
extraArguments
+=
", __constant float* globals"
;
}
vector
<
mm_float4
>
paramVec
(
numParticles
);
vector
<
mm_float4
>
paramVec
(
numParticles
);
vector
<
vector
<
int
>
>
exclusionList
(
numParticles
);
vector
<
vector
<
int
>
>
exclusionList
(
numParticles
);
for
(
int
i
=
0
;
i
<
numParticles
;
i
++
)
{
for
(
int
i
=
0
;
i
<
numParticles
;
i
++
)
{
...
@@ -764,21 +790,80 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
...
@@ -764,21 +790,80 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
paramVec
[
i
].
w
=
(
cl_float
)
parameters
[
3
];
paramVec
[
i
].
w
=
(
cl_float
)
parameters
[
3
];
exclusionList
[
i
].
push_back
(
i
);
exclusionList
[
i
].
push_back
(
i
);
}
}
for
(
int
i
=
0
;
i
<
(
int
)
exclusions
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
(
int
)
exclusions
.
size
();
i
++
)
{
exclusionList
[
exclusions
[
i
].
first
].
push_back
(
exclusions
[
i
].
second
);
exclusionList
[
exclusions
[
i
].
first
].
push_back
(
exclusions
[
i
].
second
);
exclusionList
[
exclusions
[
i
].
second
].
push_back
(
exclusions
[
i
].
first
);
exclusionList
[
exclusions
[
i
].
second
].
push_back
(
exclusions
[
i
].
first
);
}
}
params
->
upload
(
paramVec
);
params
->
upload
(
paramVec
);
// This class serves as a placeholder for custom functions in expressions.
class
FunctionPlaceholder
:
public
Lepton
::
CustomFunction
{
public:
int
getNumArguments
()
const
{
return
1
;
}
double
evaluate
(
const
double
*
arguments
)
const
{
return
0.0
;
}
double
evaluateDerivative
(
const
double
*
arguments
,
const
int
*
derivOrder
)
const
{
return
0.0
;
}
CustomFunction
*
clone
()
const
{
return
new
FunctionPlaceholder
();
}
};
// Record the tabulated functions.
// Record the tabulated functions.
FunctionPlaceholder
*
fp
=
new
FunctionPlaceholder
();
map
<
string
,
Lepton
::
CustomFunction
*>
functions
;
vector
<
pair
<
string
,
string
>
>
functionDefinitions
;
vector
<
mm_float4
>
tabulatedFunctionParamsVec
(
force
.
getNumFunctions
());
for
(
int
i
=
0
;
i
<
force
.
getNumFunctions
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumFunctions
();
i
++
)
{
string
name
;
string
name
;
vector
<
double
>
values
;
vector
<
double
>
values
;
double
min
,
max
;
double
min
,
max
;
bool
interpolating
;
bool
interpolating
;
force
.
getFunctionParameters
(
i
,
name
,
values
,
min
,
max
,
interpolating
);
force
.
getFunctionParameters
(
i
,
name
,
values
,
min
,
max
,
interpolating
);
// gpuSetTabulatedFunction(gpu, i, name, values, min, max, interpolating);
string
arrayName
=
prefix
+
"table"
+
intToString
(
i
);
functionDefinitions
.
push_back
(
make_pair
(
name
,
arrayName
));
functions
[
name
]
=
fp
;
tabulatedFunctionParamsVec
[
i
]
=
(
mm_float4
)
{(
float
)
min
,
(
float
)
max
,
(
float
)
((
values
.
size
()
-
1
)
/
(
max
-
min
)),
0.0
f
};
// First create a padded set of function values.
vector
<
double
>
padded
(
values
.
size
()
+
2
);
padded
[
0
]
=
2
*
values
[
0
]
-
values
[
1
];
for
(
int
i
=
0
;
i
<
(
int
)
values
.
size
();
i
++
)
padded
[
i
+
1
]
=
values
[
i
];
padded
[
padded
.
size
()
-
1
]
=
2
*
values
[
values
.
size
()
-
1
]
-
values
[
values
.
size
()
-
2
];
// Now compute the spline coefficients.
vector
<
mm_float4
>
f
(
values
.
size
()
-
1
);
for
(
int
i
=
0
;
i
<
(
int
)
values
.
size
()
-
1
;
i
++
)
{
if
(
interpolating
)
f
[
i
]
=
(
mm_float4
)
{(
cl_float
)
padded
[
i
+
1
],
(
cl_float
)
(
0.5
*
(
-
padded
[
i
]
+
padded
[
i
+
2
])),
(
cl_float
)
(
0.5
*
(
2.0
*
padded
[
i
]
-
5.0
*
padded
[
i
+
1
]
+
4.0
*
padded
[
i
+
2
]
-
padded
[
i
+
3
])),
(
cl_float
)
(
0.5
*
(
-
padded
[
i
]
+
3.0
*
padded
[
i
+
1
]
-
3.0
*
padded
[
i
+
2
]
+
padded
[
i
+
3
]))};
else
f
[
i
]
=
(
mm_float4
)
{(
cl_float
)
((
padded
[
i
]
+
4.0
*
padded
[
i
+
1
]
+
padded
[
i
+
2
])
/
6.0
),
(
cl_float
)
((
-
3.0
*
padded
[
i
]
+
3.0
*
padded
[
i
+
2
])
/
6.0
),
(
cl_float
)
((
3.0
*
padded
[
i
]
-
6.0
*
padded
[
i
+
1
]
+
3.0
*
padded
[
i
+
2
])
/
6.0
),
(
cl_float
)
((
-
padded
[
i
]
+
3.0
*
padded
[
i
+
1
]
-
3.0
*
padded
[
i
+
2
]
+
padded
[
i
+
3
])
/
6.0
)};
}
tabulatedFunctions
.
push_back
(
new
OpenCLArray
<
mm_float4
>
(
cl
,
values
.
size
()
-
1
,
"TabulatedFunction"
));
tabulatedFunctions
[
tabulatedFunctions
.
size
()
-
1
]
->
upload
(
f
);
cl
.
getNonbondedUtilities
().
addArgument
(
OpenCLNonbondedUtilities
::
ParameterInfo
(
arrayName
,
"float4"
,
sizeof
(
cl_float4
),
tabulatedFunctions
[
tabulatedFunctions
.
size
()
-
1
]
->
getDeviceBuffer
()));
extraArguments
+=
", __constant float4* "
+
arrayName
;
}
if
(
force
.
getNumFunctions
()
>
0
)
{
tabulatedFunctionParams
=
new
OpenCLArray
<
mm_float4
>
(
cl
,
tabulatedFunctionParamsVec
.
size
(),
"tabulatedFunctionParameters"
,
false
,
CL_MEM_READ_ONLY
);
tabulatedFunctionParams
->
upload
(
tabulatedFunctionParamsVec
);
cl
.
getNonbondedUtilities
().
addArgument
(
OpenCLNonbondedUtilities
::
ParameterInfo
(
prefix
+
"functionParams"
,
"float4"
,
sizeof
(
cl_float4
),
tabulatedFunctionParams
->
getDeviceBuffer
()));
extraArguments
+=
", __constant float4* "
+
prefix
+
"functionParams"
;
}
}
// Record information for the expressions.
// Record information for the expressions.
...
@@ -799,8 +884,11 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
...
@@ -799,8 +884,11 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
globals
->
upload
(
globalParamValues
);
globals
->
upload
(
globalParamValues
);
bool
useCutoff
=
(
force
.
getNonbondedMethod
()
!=
CustomNonbondedForce
::
NoCutoff
);
bool
useCutoff
=
(
force
.
getNonbondedMethod
()
!=
CustomNonbondedForce
::
NoCutoff
);
bool
usePeriodic
=
(
force
.
getNonbondedMethod
()
!=
CustomNonbondedForce
::
NoCutoff
&&
force
.
getNonbondedMethod
()
!=
CustomNonbondedForce
::
CutoffNonPeriodic
);
bool
usePeriodic
=
(
force
.
getNonbondedMethod
()
!=
CustomNonbondedForce
::
NoCutoff
&&
force
.
getNonbondedMethod
()
!=
CustomNonbondedForce
::
CutoffNonPeriodic
);
Lepton
::
ParsedExpression
energyExpression
=
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
()).
optimize
();
Lepton
::
ParsedExpression
energyExpression
=
Lepton
::
Parser
::
parse
(
force
.
getEnergyFunction
()
,
functions
).
optimize
();
Lepton
::
ParsedExpression
forceExpression
=
energyExpression
.
differentiate
(
"r"
).
optimize
();
Lepton
::
ParsedExpression
forceExpression
=
energyExpression
.
differentiate
(
"r"
).
optimize
();
map
<
string
,
Lepton
::
ParsedExpression
>
forceExpressions
;
forceExpressions
[
"tempEnergy += "
]
=
energyExpression
;
forceExpressions
[
"tempForce -= "
]
=
forceExpression
;
// Create the kernels.
// Create the kernels.
...
@@ -824,13 +912,13 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
...
@@ -824,13 +912,13 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
forceVariables
[
name
]
=
prefix
+
value
;
forceVariables
[
name
]
=
prefix
+
value
;
exceptionVariables
[
name
]
=
value
;
exceptionVariables
[
name
]
=
value
;
}
}
string
stream
compute
;
map
<
string
,
Lepton
::
ParsedExpression
>
paramExpressions
;
for
(
int
i
=
0
;
i
<
force
.
getNumParameters
();
i
++
)
{
for
(
int
i
=
0
;
i
<
force
.
getNumParameters
();
i
++
)
{
Lepton
::
ParsedExpression
expression
=
Lepton
::
Parser
::
parse
(
force
.
getParameterCombiningRule
(
i
)).
optimize
();
paramExpressions
[
"float "
+
prefix
+
force
.
getParameterName
(
i
)
+
" = "
]
=
Lepton
::
Parser
::
parse
(
force
.
getParameterCombiningRule
(
i
)).
optimize
();
compute
<<
"float "
<<
prefix
<<
force
.
getParameterName
(
i
)
<<
" = "
<<
OpenCLExpressionUtilities
::
createExpression
(
expression
,
paramVariables
)
<<
";
\n
"
;
}
}
compute
<<
"tempEnergy += "
<<
OpenCLExpressionUtilities
::
createExpression
(
energyExpression
,
forceVariables
)
<<
";
\n
"
;
stringstream
compute
;
compute
<<
"tempForce -= "
<<
OpenCLExpressionUtilities
::
createExpression
(
forceExpression
,
forceVariables
)
<<
";
\n
"
;
compute
<<
OpenCLExpressionUtilities
::
createExpressions
(
paramExpressions
,
paramVariables
,
functionDefinitions
,
prefix
+
"param_temp"
,
prefix
+
"functionParams"
);
compute
<<
OpenCLExpressionUtilities
::
createExpressions
(
forceExpressions
,
forceVariables
,
functionDefinitions
,
prefix
+
"force_temp"
,
prefix
+
"functionParams"
);
map
<
string
,
string
>
replacements
;
map
<
string
,
string
>
replacements
;
replacements
[
"COMPUTE_FORCE"
]
=
compute
.
str
();
replacements
[
"COMPUTE_FORCE"
]
=
compute
.
str
();
string
source
=
cl
.
loadSourceFromFile
(
"customNonbonded.cl"
,
replacements
);
string
source
=
cl
.
loadSourceFromFile
(
"customNonbonded.cl"
,
replacements
);
...
@@ -840,13 +928,20 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
...
@@ -840,13 +928,20 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
globals
->
upload
(
globalParamValues
);
globals
->
upload
(
globalParamValues
);
cl
.
getNonbondedUtilities
().
addArgument
(
OpenCLNonbondedUtilities
::
ParameterInfo
(
prefix
+
"globals"
,
"float"
,
sizeof
(
cl_float
),
globals
->
getDeviceBuffer
()));
cl
.
getNonbondedUtilities
().
addArgument
(
OpenCLNonbondedUtilities
::
ParameterInfo
(
prefix
+
"globals"
,
"float"
,
sizeof
(
cl_float
),
globals
->
getDeviceBuffer
()));
}
}
map
<
string
,
Lepton
::
ParsedExpression
>
exceptionExpressions
;
stringstream
computeExceptions
;
stringstream
computeExceptions
;
computeExceptions
<<
"energy += "
<<
OpenCLExpressionUtilities
::
createExpression
(
energyExpression
,
exceptionVariables
)
<<
";
\n
"
;
exceptionExpressions
[
"energy += "
]
=
energyExpression
;
computeExceptions
<<
"dEdR = "
<<
OpenCLExpressionUtilities
::
createExpression
(
forceExpression
,
exceptionVariables
)
<<
";
\n
"
;
exceptionExpressions
[
"dEdR = "
]
=
forceExpression
;
computeExceptions
<<
OpenCLExpressionUtilities
::
createExpressions
(
exceptionExpressions
,
exceptionVariables
,
functionDefinitions
,
"temp"
,
prefix
+
"functionParams"
);
replacements
[
"COMPUTE_FORCE"
]
=
computeExceptions
.
str
();
replacements
[
"COMPUTE_FORCE"
]
=
computeExceptions
.
str
();
replacements
[
"EXTRA_ARGUMENTS"
]
=
extraArguments
;
map
<
string
,
string
>
defines
;
map
<
string
,
string
>
defines
;
if
(
globals
!=
NULL
)
defines
[
"CUTOFF_SQUARED"
]
=
doubleToString
(
force
.
getCutoffDistance
()
*
force
.
getCutoffDistance
());
defines
[
"HAS_GLOBALS"
]
=
"1"
;
Vec3
boxVectors
[
3
];
system
.
getPeriodicBoxVectors
(
boxVectors
[
0
],
boxVectors
[
1
],
boxVectors
[
2
]);
defines
[
"PERIODIC_BOX_SIZE_X"
]
=
doubleToString
(
boxVectors
[
0
][
0
]);
defines
[
"PERIODIC_BOX_SIZE_Y"
]
=
doubleToString
(
boxVectors
[
1
][
1
]);
defines
[
"PERIODIC_BOX_SIZE_Z"
]
=
doubleToString
(
boxVectors
[
2
][
2
]);
cl
::
Program
program
=
cl
.
createProgram
(
cl
.
loadSourceFromFile
(
"customNonbondedExceptions.cl"
,
replacements
),
defines
);
cl
::
Program
program
=
cl
.
createProgram
(
cl
.
loadSourceFromFile
(
"customNonbondedExceptions.cl"
,
replacements
),
defines
);
exceptionsKernel
=
cl
::
Kernel
(
program
,
"computeCustomNonbondedExceptions"
);
exceptionsKernel
=
cl
::
Kernel
(
program
,
"computeCustomNonbondedExceptions"
);
...
@@ -880,23 +975,22 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
...
@@ -880,23 +975,22 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
maxBuffers
=
max
(
maxBuffers
,
forceBufferCounter
[
i
]);
maxBuffers
=
max
(
maxBuffers
,
forceBufferCounter
[
i
]);
}
}
cl
.
addForce
(
new
OpenCLCustomNonbondedForceInfo
(
maxBuffers
,
force
));
cl
.
addForce
(
new
OpenCLCustomNonbondedForceInfo
(
maxBuffers
,
force
));
delete
fp
;
}
}
void
OpenCLCalcCustomNonbondedForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
void
OpenCLCalcCustomNonbondedForceKernel
::
executeForces
(
ContextImpl
&
context
)
{
if
(
exceptionParams
!=
NULL
)
{
if
(
exceptionParams
!=
NULL
)
{
if
(
!
has
Creat
edKernel
s
)
{
if
(
!
has
Initializ
edKernel
)
{
has
Creat
edKernel
s
=
true
;
has
Initializ
edKernel
=
true
;
exceptionsKernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
exceptionsKernel
.
setArg
<
cl_int
>
(
0
,
cl
.
getPaddedNumAtoms
());
exceptionsKernel
.
setArg
<
cl_int
>
(
1
,
exceptionParams
->
getSize
());
exceptionsKernel
.
setArg
<
cl_int
>
(
1
,
exceptionParams
->
getSize
());
exceptionsKernel
.
setArg
<
cl_float
>
(
2
,
cl
.
getNonbondedUtilities
().
getCutoffDistance
()
*
cl
.
getNonbondedUtilities
().
getCutoffDistance
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
mm_float4
>
(
3
,
cl
.
getNonbondedUtilities
().
getPeriodicBoxSize
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
cl
.
getEnergyBuffer
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getForceBuffers
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
cl
.
getPosq
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
cl
.
getEnergyBuffer
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
exceptionParams
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
cl
.
getPosq
().
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
exceptionIndices
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
exceptionParams
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
exceptionIndices
->
getDeviceBuffer
());
if
(
globals
!=
NULL
)
if
(
globals
!=
NULL
)
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
9
,
globals
->
getDeviceBuffer
());
exceptionsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
globals
->
getDeviceBuffer
());
}
}
cl
.
executeKernel
(
exceptionsKernel
,
exceptionIndices
->
getSize
());
cl
.
executeKernel
(
exceptionsKernel
,
exceptionIndices
->
getSize
());
}
}
...
...
platforms/opencl/src/OpenCLKernels.h
View file @
72d59cbe
...
@@ -150,8 +150,8 @@ private:
...
@@ -150,8 +150,8 @@ private:
*/
*/
class
OpenCLCalcHarmonicBondForceKernel
:
public
CalcHarmonicBondForceKernel
{
class
OpenCLCalcHarmonicBondForceKernel
:
public
CalcHarmonicBondForceKernel
{
public:
public:
OpenCLCalcHarmonicBondForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
OpenCLCalcHarmonicBondForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcHarmonicBondForceKernel
(
name
,
platform
),
CalcHarmonicBondForceKernel
(
name
,
platform
),
cl
(
cl
),
system
(
system
),
params
(
NULL
),
indices
(
NULL
)
{
hasInitializedKernel
(
false
),
cl
(
cl
),
system
(
system
),
params
(
NULL
),
indices
(
NULL
)
{
}
}
~
OpenCLCalcHarmonicBondForceKernel
();
~
OpenCLCalcHarmonicBondForceKernel
();
/**
/**
...
@@ -176,6 +176,7 @@ public:
...
@@ -176,6 +176,7 @@ public:
double
executeEnergy
(
ContextImpl
&
context
);
double
executeEnergy
(
ContextImpl
&
context
);
private:
private:
int
numBonds
;
int
numBonds
;
bool
hasInitializedKernel
;
OpenCLContext
&
cl
;
OpenCLContext
&
cl
;
System
&
system
;
System
&
system
;
OpenCLArray
<
mm_float2
>*
params
;
OpenCLArray
<
mm_float2
>*
params
;
...
@@ -188,7 +189,8 @@ private:
...
@@ -188,7 +189,8 @@ private:
*/
*/
class
OpenCLCalcHarmonicAngleForceKernel
:
public
CalcHarmonicAngleForceKernel
{
class
OpenCLCalcHarmonicAngleForceKernel
:
public
CalcHarmonicAngleForceKernel
{
public:
public:
OpenCLCalcHarmonicAngleForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcHarmonicAngleForceKernel
(
name
,
platform
),
cl
(
cl
),
system
(
system
)
{
OpenCLCalcHarmonicAngleForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcHarmonicAngleForceKernel
(
name
,
platform
),
hasInitializedKernel
(
false
),
cl
(
cl
),
system
(
system
)
{
}
}
~
OpenCLCalcHarmonicAngleForceKernel
();
~
OpenCLCalcHarmonicAngleForceKernel
();
/**
/**
...
@@ -213,6 +215,7 @@ public:
...
@@ -213,6 +215,7 @@ public:
double
executeEnergy
(
ContextImpl
&
context
);
double
executeEnergy
(
ContextImpl
&
context
);
private:
private:
int
numAngles
;
int
numAngles
;
bool
hasInitializedKernel
;
OpenCLContext
&
cl
;
OpenCLContext
&
cl
;
System
&
system
;
System
&
system
;
OpenCLArray
<
mm_float2
>*
params
;
OpenCLArray
<
mm_float2
>*
params
;
...
@@ -225,7 +228,8 @@ private:
...
@@ -225,7 +228,8 @@ private:
*/
*/
class
OpenCLCalcPeriodicTorsionForceKernel
:
public
CalcPeriodicTorsionForceKernel
{
class
OpenCLCalcPeriodicTorsionForceKernel
:
public
CalcPeriodicTorsionForceKernel
{
public:
public:
OpenCLCalcPeriodicTorsionForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcPeriodicTorsionForceKernel
(
name
,
platform
),
cl
(
cl
),
system
(
system
)
{
OpenCLCalcPeriodicTorsionForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcPeriodicTorsionForceKernel
(
name
,
platform
),
hasInitializedKernel
(
false
),
cl
(
cl
),
system
(
system
)
{
}
}
~
OpenCLCalcPeriodicTorsionForceKernel
();
~
OpenCLCalcPeriodicTorsionForceKernel
();
/**
/**
...
@@ -250,6 +254,7 @@ public:
...
@@ -250,6 +254,7 @@ public:
double
executeEnergy
(
ContextImpl
&
context
);
double
executeEnergy
(
ContextImpl
&
context
);
private:
private:
int
numTorsions
;
int
numTorsions
;
bool
hasInitializedKernel
;
OpenCLContext
&
cl
;
OpenCLContext
&
cl
;
System
&
system
;
System
&
system
;
OpenCLArray
<
mm_float4
>*
params
;
OpenCLArray
<
mm_float4
>*
params
;
...
@@ -262,7 +267,8 @@ private:
...
@@ -262,7 +267,8 @@ private:
*/
*/
class
OpenCLCalcRBTorsionForceKernel
:
public
CalcRBTorsionForceKernel
{
class
OpenCLCalcRBTorsionForceKernel
:
public
CalcRBTorsionForceKernel
{
public:
public:
OpenCLCalcRBTorsionForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcRBTorsionForceKernel
(
name
,
platform
),
cl
(
cl
),
system
(
system
)
{
OpenCLCalcRBTorsionForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcRBTorsionForceKernel
(
name
,
platform
),
hasInitializedKernel
(
false
),
cl
(
cl
),
system
(
system
)
{
}
}
~
OpenCLCalcRBTorsionForceKernel
();
~
OpenCLCalcRBTorsionForceKernel
();
/**
/**
...
@@ -287,6 +293,7 @@ public:
...
@@ -287,6 +293,7 @@ public:
double
executeEnergy
(
ContextImpl
&
context
);
double
executeEnergy
(
ContextImpl
&
context
);
private:
private:
int
numTorsions
;
int
numTorsions
;
bool
hasInitializedKernel
;
OpenCLContext
&
cl
;
OpenCLContext
&
cl
;
System
&
system
;
System
&
system
;
OpenCLArray
<
mm_float8
>*
params
;
OpenCLArray
<
mm_float8
>*
params
;
...
@@ -299,8 +306,8 @@ private:
...
@@ -299,8 +306,8 @@ private:
*/
*/
class
OpenCLCalcNonbondedForceKernel
:
public
CalcNonbondedForceKernel
{
class
OpenCLCalcNonbondedForceKernel
:
public
CalcNonbondedForceKernel
{
public:
public:
OpenCLCalcNonbondedForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcNonbondedForceKernel
(
name
,
platform
),
cl
(
cl
),
OpenCLCalcNonbondedForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcNonbondedForceKernel
(
name
,
platform
),
sigmaEpsilon
(
NULL
),
exceptionParams
(
NULL
),
exceptionIndices
(
NULL
),
cosSinSums
(
NULL
)
{
hasInitializedKernel
(
false
),
cl
(
cl
),
sigmaEpsilon
(
NULL
),
exceptionParams
(
NULL
),
exceptionIndices
(
NULL
),
cosSinSums
(
NULL
)
{
}
}
~
OpenCLCalcNonbondedForceKernel
();
~
OpenCLCalcNonbondedForceKernel
();
/**
/**
...
@@ -325,6 +332,7 @@ public:
...
@@ -325,6 +332,7 @@ public:
double
executeEnergy
(
ContextImpl
&
context
);
double
executeEnergy
(
ContextImpl
&
context
);
private:
private:
OpenCLContext
&
cl
;
OpenCLContext
&
cl
;
bool
hasInitializedKernel
;
OpenCLArray
<
mm_float2
>*
sigmaEpsilon
;
OpenCLArray
<
mm_float2
>*
sigmaEpsilon
;
OpenCLArray
<
mm_float4
>*
exceptionParams
;
OpenCLArray
<
mm_float4
>*
exceptionParams
;
OpenCLArray
<
mm_int4
>*
exceptionIndices
;
OpenCLArray
<
mm_int4
>*
exceptionIndices
;
...
@@ -341,7 +349,7 @@ private:
...
@@ -341,7 +349,7 @@ private:
class
OpenCLCalcCustomNonbondedForceKernel
:
public
CalcCustomNonbondedForceKernel
{
class
OpenCLCalcCustomNonbondedForceKernel
:
public
CalcCustomNonbondedForceKernel
{
public:
public:
OpenCLCalcCustomNonbondedForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcCustomNonbondedForceKernel
(
name
,
platform
),
OpenCLCalcCustomNonbondedForceKernel
(
std
::
string
name
,
const
Platform
&
platform
,
OpenCLContext
&
cl
,
System
&
system
)
:
CalcCustomNonbondedForceKernel
(
name
,
platform
),
has
Creat
edKernel
s
(
false
),
cl
(
cl
),
params
(
NULL
),
globals
(
NULL
),
exceptionParams
(
NULL
),
exceptionIndices
(
NULL
),
system
(
system
)
{
has
Initializ
edKernel
(
false
),
cl
(
cl
),
params
(
NULL
),
globals
(
NULL
),
exceptionParams
(
NULL
),
exceptionIndices
(
NULL
),
tabulatedFunctionParams
(
NULL
),
system
(
system
)
{
}
}
~
OpenCLCalcCustomNonbondedForceKernel
();
~
OpenCLCalcCustomNonbondedForceKernel
();
/**
/**
...
@@ -365,15 +373,17 @@ public:
...
@@ -365,15 +373,17 @@ public:
*/
*/
double
executeEnergy
(
ContextImpl
&
context
);
double
executeEnergy
(
ContextImpl
&
context
);
private:
private:
bool
has
Creat
edKernel
s
;
bool
has
Initializ
edKernel
;
OpenCLContext
&
cl
;
OpenCLContext
&
cl
;
OpenCLArray
<
mm_float4
>*
params
;
OpenCLArray
<
mm_float4
>*
params
;
OpenCLArray
<
cl_float
>*
globals
;
OpenCLArray
<
cl_float
>*
globals
;
OpenCLArray
<
mm_float4
>*
exceptionParams
;
OpenCLArray
<
mm_float4
>*
exceptionParams
;
OpenCLArray
<
mm_int4
>*
exceptionIndices
;
OpenCLArray
<
mm_int4
>*
exceptionIndices
;
OpenCLArray
<
mm_float4
>*
tabulatedFunctionParams
;
cl
::
Kernel
exceptionsKernel
;
cl
::
Kernel
exceptionsKernel
;
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
std
::
string
>
globalParamNames
;
std
::
vector
<
cl_float
>
globalParamValues
;
std
::
vector
<
cl_float
>
globalParamValues
;
std
::
vector
<
OpenCLArray
<
mm_float4
>*>
tabulatedFunctions
;
System
&
system
;
System
&
system
;
};
};
...
...
platforms/opencl/src/kernels/customNonbondedExceptions.cl
View file @
72d59cbe
...
@@ -2,13 +2,9 @@
...
@@ -2,13 +2,9 @@
*
Compute
custom
nonbonded
exceptions.
*
Compute
custom
nonbonded
exceptions.
*/
*/
__kernel
void
computeCustomNonbondedExceptions
(
int
numAtoms,
int
numExceptions,
float
cutoffSquared,
float4
periodicBoxSize,
__global
float4*
forceBuffers,
__global
float*
energyBuffer,
__kernel
void
computeCustomNonbondedExceptions
(
int
numAtoms,
int
numExceptions,
__global
float4*
forceBuffers,
__global
float*
energyBuffer,
__global
float4*
posq,
__global
float4*
params,
__global
int4*
indices
__global
float4*
posq,
__global
float4*
params,
__global
int4*
indices
#
ifdef
HAS_GLOBALS
EXTRA_ARGUMENTS
)
{
,
__constant
float*
globals
)
{
#
else
)
{
#
endif
int
index
=
get_global_id
(
0
)
;
int
index
=
get_global_id
(
0
)
;
float
energy
=
0.0f
;
float
energy
=
0.0f
;
while
(
index
<
numExceptions
)
{
while
(
index
<
numExceptions
)
{
...
@@ -18,15 +14,15 @@ __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions,
...
@@ -18,15 +14,15 @@ __kernel void computeCustomNonbondedExceptions(int numAtoms, int numExceptions,
float4
exceptionParams
=
params[index]
;
float4
exceptionParams
=
params[index]
;
float4
delta
=
posq[atoms.y]-posq[atoms.x]
;
float4
delta
=
posq[atoms.y]-posq[atoms.x]
;
#
ifdef
USE_PERIODIC
#
ifdef
USE_PERIODIC
delta.x
-=
floor
(
delta.x/
periodicBoxSize.x+0.5f
)
*periodicBoxSize.x
;
delta.x
-=
floor
(
delta.x/
PERIODIC_BOX_SIZE_X+0.5f
)
*PERIODIC_BOX_SIZE_X
;
delta.y
-=
floor
(
delta.y/
periodicBoxSize.y+0.5f
)
*periodicBoxSize.y
;
delta.y
-=
floor
(
delta.y/
PERIODIC_BOX_SIZE_Y+0.5f
)
*PERIODIC_BOX_SIZE_Y
;
delta.z
-=
floor
(
delta.z/
periodicBoxSize.z+0.5f
)
*periodicBoxSize.z
;
delta.z
-=
floor
(
delta.z/
PERIODIC_BOX_SIZE_Z+0.5f
)
*PERIODIC_BOX_SIZE_Z
;
#
endif
#
endif
//
Compute
the
force.
//
Compute
the
force.
float
r2
=
delta.x*delta.x
+
delta.y*delta.y
+
delta.z*delta.z
;
float
r2
=
delta.x*delta.x
+
delta.y*delta.y
+
delta.z*delta.z
;
#
ifdef
USE_CUTOFF
#
ifdef
USE_CUTOFF
if
(
r2
>
cutoffSquared
)
{
if
(
r2
>
CUTOFF_SQUARED
)
{
#
else
#
else
{
{
#
endif
#
endif
...
...
platforms/opencl/tests/TestOpenCLCustomNonbondedForce.cpp
View file @
72d59cbe
...
@@ -241,8 +241,8 @@ int main() {
...
@@ -241,8 +241,8 @@ int main() {
testExceptions
();
testExceptions
();
testCutoff
();
testCutoff
();
testPeriodic
();
testPeriodic
();
//
testTabulatedFunction(true);
testTabulatedFunction
(
true
);
//
testTabulatedFunction(false);
testTabulatedFunction
(
false
);
}
}
catch
(
const
exception
&
e
)
{
catch
(
const
exception
&
e
)
{
cout
<<
"exception: "
<<
e
.
what
()
<<
endl
;
cout
<<
"exception: "
<<
e
.
what
()
<<
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment