Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
35b9d787
"devtools/packaging/scripts/vscode:/vscode.git/clone" did not exist on "228083e3201d9b44b40ae34718f169372dafb7d4"
Commit
35b9d787
authored
Nov 22, 2013
by
peastman
Browse files
Workarounds for devices whose maximum workgroup size varies between kernels
parent
d59a34fb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
229 additions
and
203 deletions
+229
-203
platforms/opencl/src/OpenCLFFT3D.cpp
platforms/opencl/src/OpenCLFFT3D.cpp
+177
-166
platforms/opencl/src/OpenCLNonbondedUtilities.cpp
platforms/opencl/src/OpenCLNonbondedUtilities.cpp
+46
-34
platforms/opencl/src/OpenCLSort.cpp
platforms/opencl/src/OpenCLSort.cpp
+6
-3
No files found.
platforms/opencl/src/OpenCLFFT3D.cpp
View file @
35b9d787
...
@@ -74,178 +74,189 @@ int OpenCLFFT3D::findLegalDimension(int minimum) {
...
@@ -74,178 +74,189 @@ int OpenCLFFT3D::findLegalDimension(int minimum) {
}
}
cl
::
Kernel
OpenCLFFT3D
::
createKernel
(
int
xsize
,
int
ysize
,
int
zsize
,
int
&
threads
)
{
cl
::
Kernel
OpenCLFFT3D
::
createKernel
(
int
xsize
,
int
ysize
,
int
zsize
,
int
&
threads
)
{
bool
loopRequired
=
(
context
.
getDevice
().
getInfo
<
CL_DEVICE_TYPE
>
()
==
CL_DEVICE_TYPE_CPU
);
int
maxThreads
=
std
::
min
(
256
,
(
int
)
context
.
getDevice
().
getInfo
<
CL_DEVICE_MAX_WORK_GROUP_SIZE
>
());
stringstream
source
;
bool
isCPU
=
context
.
getDevice
().
getInfo
<
CL_DEVICE_TYPE
>
()
==
CL_DEVICE_TYPE_CPU
;
int
blocksPerGroup
=
(
loopRequired
?
1
:
max
(
1
,
256
/
zsize
));
while
(
true
)
{
int
stage
=
0
;
bool
loopRequired
=
(
zsize
>
maxThreads
||
isCPU
);
int
L
=
zsize
;
stringstream
source
;
int
m
=
1
;
int
blocksPerGroup
=
(
loopRequired
?
1
:
max
(
1
,
maxThreads
/
zsize
));
int
stage
=
0
;
int
L
=
zsize
;
int
m
=
1
;
// Factor zsize, generating an appropriate block of code for each factor.
while
(
L
>
1
)
{
int
input
=
stage
%
2
;
int
output
=
1
-
input
;
int
radix
;
if
(
L
%
7
==
0
)
radix
=
7
;
else
if
(
L
%
5
==
0
)
radix
=
5
;
else
if
(
L
%
4
==
0
)
radix
=
4
;
else
if
(
L
%
3
==
0
)
radix
=
3
;
else
if
(
L
%
2
==
0
)
radix
=
2
;
else
throw
OpenMMException
(
"Illegal size for FFT: "
+
context
.
intToString
(
zsize
));
source
<<
"{
\n
"
;
L
=
L
/
radix
;
source
<<
"// Pass "
<<
(
stage
+
1
)
<<
" (radix "
<<
radix
<<
")
\n
"
;
if
(
loopRequired
)
{
source
<<
"for (int i = get_local_id(0); i < "
<<
(
L
*
m
)
<<
"; i += get_local_size(0)) {
\n
"
;
source
<<
"int base = i;
\n
"
;
}
else
{
source
<<
"if (get_local_id(0) < "
<<
(
blocksPerGroup
*
L
*
m
)
<<
") {
\n
"
;
source
<<
"int block = get_local_id(0)/"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"int i = get_local_id(0)-block*"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"int base = i+block*"
<<
zsize
<<
";
\n
"
;
}
source
<<
"int j = i/"
<<
m
<<
";
\n
"
;
if
(
radix
==
7
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c5 = data"
<<
input
<<
"[base+"
<<
(
5
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c6 = data"
<<
input
<<
"[base+"
<<
(
6
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c6;
\n
"
;
source
<<
"real2 d1 = c1-c6;
\n
"
;
source
<<
"real2 d2 = c2+c5;
\n
"
;
source
<<
"real2 d3 = c2-c5;
\n
"
;
source
<<
"real2 d4 = c4+c3;
\n
"
;
source
<<
"real2 d5 = c4-c3;
\n
"
;
source
<<
"real2 d6 = d2+d0;
\n
"
;
source
<<
"real2 d7 = d5+d3;
\n
"
;
source
<<
"real2 b0 = c0+d6+d4;
\n
"
;
source
<<
"real2 b1 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
-
1
)
<<
"*(d6+d4);
\n
"
;
source
<<
"real2 b2 = "
<<
context
.
doubleToString
((
2
*
cos
(
2
*
M_PI
/
7
)
-
cos
(
4
*
M_PI
/
7
)
-
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d0-d4);
\n
"
;
source
<<
"real2 b3 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
-
2
*
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d4-d2);
\n
"
;
source
<<
"real2 b4 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
-
2
*
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d2-d0);
\n
"
;
source
<<
"real2 b5 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d7+d1);
\n
"
;
source
<<
"real2 b6 = -sign*"
<<
context
.
doubleToString
((
2
*
sin
(
2
*
M_PI
/
7
)
-
sin
(
4
*
M_PI
/
7
)
+
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d1-d5);
\n
"
;
source
<<
"real2 b7 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
-
2
*
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d5-d3);
\n
"
;
source
<<
"real2 b8 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
+
2
*
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d3-d1);
\n
"
;
source
<<
"real2 t0 = b0+b1;
\n
"
;
source
<<
"real2 t1 = b2+b3;
\n
"
;
source
<<
"real2 t2 = b4-b3;
\n
"
;
source
<<
"real2 t3 = -b2-b4;
\n
"
;
source
<<
"real2 t4 = b6+b7;
\n
"
;
source
<<
"real2 t5 = b8-b7;
\n
"
;
source
<<
"real2 t6 = -b8-b6;
\n
"
;
source
<<
"real2 t7 = t0+t1;
\n
"
;
source
<<
"real2 t8 = t0+t2;
\n
"
;
source
<<
"real2 t9 = t0+t3;
\n
"
;
source
<<
"real2 t10 = (real2) (t4.y+b5.y, -(t4.x+b5.x));
\n
"
;
source
<<
"real2 t11 = (real2) (t5.y+b5.y, -(t5.x+b5.x));
\n
"
;
source
<<
"real2 t12 = (real2) (t6.y+b5.y, -(t6.x+b5.x));
\n
"
;
source
<<
"data"
<<
output
<<
"[base+6*j*"
<<
m
<<
"] = b0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
7
*
L
)
<<
"], t7-t10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9-t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8+t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8-t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+5)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
5
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9+t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+6)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
6
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t7+t10);
\n
"
;
}
else
if
(
radix
==
5
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c4;
\n
"
;
source
<<
"real2 d1 = c2+c3;
\n
"
;
source
<<
"real2 d2 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c1-c4);
\n
"
;
source
<<
"real2 d3 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c2-c3);
\n
"
;
source
<<
"real2 d4 = d0+d1;
\n
"
;
source
<<
"real2 d5 = "
<<
context
.
doubleToString
(
0.25
*
sqrt
(
5.0
))
<<
"*(d0-d1);
\n
"
;
source
<<
"real2 d6 = c0-0.25f*d4;
\n
"
;
source
<<
"real2 d7 = d6+d5;
\n
"
;
source
<<
"real2 d8 = d6-d5;
\n
"
;
string
coeff
=
context
.
doubleToString
(
sin
(
0.2
*
M_PI
)
/
sin
(
0.4
*
M_PI
));
source
<<
"real2 d9 = sign*(real2) (d2.y+"
<<
coeff
<<
"*d3.y, -d2.x-"
<<
coeff
<<
"*d3.x);
\n
"
;
source
<<
"real2 d10 = sign*(real2) ("
<<
coeff
<<
"*d2.y-d3.y, d3.x-"
<<
coeff
<<
"*d2.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+4*j*"
<<
m
<<
"] = c0+d4;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
5
*
L
)
<<
"], d7+d9);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8+d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8-d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d7-d9);
\n
"
;
}
else
if
(
radix
==
4
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c0+c2;
\n
"
;
source
<<
"real2 d1 = c0-c2;
\n
"
;
source
<<
"real2 d2 = c1+c3;
\n
"
;
source
<<
"real2 d3 = sign*(real2) (c1.y-c3.y, c3.x-c1.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+3*j*"
<<
m
<<
"] = d0+d2;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
4
*
L
)
<<
"], d1+d3);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d0-d2);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d1-d3);
\n
"
;
}
else
if
(
radix
==
3
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c2;
\n
"
;
source
<<
"real2 d1 = c0-0.5f*d0;
\n
"
;
source
<<
"real2 d2 = sign*"
<<
context
.
doubleToString
(
sin
(
M_PI
/
3.0
))
<<
"*(real2) (c1.y-c2.y, c2.x-c1.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+2*j*"
<<
m
<<
"] = c0+d0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
3
*
L
)
<<
"], d1+d2);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
3
*
L
)
<<
"], d1-d2);
\n
"
;
}
else
if
(
radix
==
2
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"data"
<<
output
<<
"[base+j*"
<<
m
<<
"] = c0+c1;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
2
*
L
)
<<
"], c0-c1);
\n
"
;
}
source
<<
"}
\n
"
;
m
=
m
*
radix
;
source
<<
"barrier(CLK_LOCAL_MEM_FENCE);
\n
"
;
source
<<
"}
\n
"
;
++
stage
;
}
//
Factor zsize, generating an appropriate block of code for each factor
.
//
Create the kernel
.
while
(
L
>
1
)
{
int
input
=
stage
%
2
;
int
output
=
1
-
input
;
int
radix
;
if
(
L
%
7
==
0
)
radix
=
7
;
else
if
(
L
%
5
==
0
)
radix
=
5
;
else
if
(
L
%
4
==
0
)
radix
=
4
;
else
if
(
L
%
3
==
0
)
radix
=
3
;
else
if
(
L
%
2
==
0
)
radix
=
2
;
else
throw
OpenMMException
(
"Illegal size for FFT: "
+
context
.
intToString
(
zsize
));
source
<<
"{
\n
"
;
L
=
L
/
radix
;
source
<<
"// Pass "
<<
(
stage
+
1
)
<<
" (radix "
<<
radix
<<
")
\n
"
;
if
(
loopRequired
)
{
if
(
loopRequired
)
{
source
<<
"for (int
i
= get_local_id(0);
i
<
"
<<
(
L
*
m
)
<<
"
;
i
+= get_local_size(0))
{
\n
"
;
source
<<
"for (int
z
= get_local_id(0);
z
<
ZSIZE
;
z
+= get_local_size(0))
\n
"
;
source
<<
"
int base = i
;
\n
"
;
source
<<
"
out[y*(ZSIZE*XSIZE)+z*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[z]
;
\n
"
;
}
}
else
{
else
{
source
<<
"if (get_local_id(0) < "
<<
(
blocksPerGroup
*
L
*
m
)
<<
") {
\n
"
;
source
<<
"if (index < XSIZE*YSIZE)
\n
"
;
source
<<
"int block = get_local_id(0)/"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"out[y*(ZSIZE*XSIZE)+(get_local_id(0)%ZSIZE)*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[get_local_id(0)];
\n
"
;
source
<<
"int i = get_local_id(0)-block*"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"int base = i+block*"
<<
zsize
<<
";
\n
"
;
}
source
<<
"int j = i/"
<<
m
<<
";
\n
"
;
if
(
radix
==
7
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c5 = data"
<<
input
<<
"[base+"
<<
(
5
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c6 = data"
<<
input
<<
"[base+"
<<
(
6
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c6;
\n
"
;
source
<<
"real2 d1 = c1-c6;
\n
"
;
source
<<
"real2 d2 = c2+c5;
\n
"
;
source
<<
"real2 d3 = c2-c5;
\n
"
;
source
<<
"real2 d4 = c4+c3;
\n
"
;
source
<<
"real2 d5 = c4-c3;
\n
"
;
source
<<
"real2 d6 = d2+d0;
\n
"
;
source
<<
"real2 d7 = d5+d3;
\n
"
;
source
<<
"real2 b0 = c0+d6+d4;
\n
"
;
source
<<
"real2 b1 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
-
1
)
<<
"*(d6+d4);
\n
"
;
source
<<
"real2 b2 = "
<<
context
.
doubleToString
((
2
*
cos
(
2
*
M_PI
/
7
)
-
cos
(
4
*
M_PI
/
7
)
-
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d0-d4);
\n
"
;
source
<<
"real2 b3 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
-
2
*
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d4-d2);
\n
"
;
source
<<
"real2 b4 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
-
2
*
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d2-d0);
\n
"
;
source
<<
"real2 b5 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d7+d1);
\n
"
;
source
<<
"real2 b6 = -sign*"
<<
context
.
doubleToString
((
2
*
sin
(
2
*
M_PI
/
7
)
-
sin
(
4
*
M_PI
/
7
)
+
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d1-d5);
\n
"
;
source
<<
"real2 b7 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
-
2
*
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d5-d3);
\n
"
;
source
<<
"real2 b8 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
+
2
*
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d3-d1);
\n
"
;
source
<<
"real2 t0 = b0+b1;
\n
"
;
source
<<
"real2 t1 = b2+b3;
\n
"
;
source
<<
"real2 t2 = b4-b3;
\n
"
;
source
<<
"real2 t3 = -b2-b4;
\n
"
;
source
<<
"real2 t4 = b6+b7;
\n
"
;
source
<<
"real2 t5 = b8-b7;
\n
"
;
source
<<
"real2 t6 = -b8-b6;
\n
"
;
source
<<
"real2 t7 = t0+t1;
\n
"
;
source
<<
"real2 t8 = t0+t2;
\n
"
;
source
<<
"real2 t9 = t0+t3;
\n
"
;
source
<<
"real2 t10 = (real2) (t4.y+b5.y, -(t4.x+b5.x));
\n
"
;
source
<<
"real2 t11 = (real2) (t5.y+b5.y, -(t5.x+b5.x));
\n
"
;
source
<<
"real2 t12 = (real2) (t6.y+b5.y, -(t6.x+b5.x));
\n
"
;
source
<<
"data"
<<
output
<<
"[base+6*j*"
<<
m
<<
"] = b0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
7
*
L
)
<<
"], t7-t10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9-t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8+t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8-t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+5)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
5
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9+t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+6)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
6
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t7+t10);
\n
"
;
}
else
if
(
radix
==
5
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c4;
\n
"
;
source
<<
"real2 d1 = c2+c3;
\n
"
;
source
<<
"real2 d2 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c1-c4);
\n
"
;
source
<<
"real2 d3 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c2-c3);
\n
"
;
source
<<
"real2 d4 = d0+d1;
\n
"
;
source
<<
"real2 d5 = "
<<
context
.
doubleToString
(
0.25
*
sqrt
(
5.0
))
<<
"*(d0-d1);
\n
"
;
source
<<
"real2 d6 = c0-0.25f*d4;
\n
"
;
source
<<
"real2 d7 = d6+d5;
\n
"
;
source
<<
"real2 d8 = d6-d5;
\n
"
;
string
coeff
=
context
.
doubleToString
(
sin
(
0.2
*
M_PI
)
/
sin
(
0.4
*
M_PI
));
source
<<
"real2 d9 = sign*(real2) (d2.y+"
<<
coeff
<<
"*d3.y, -d2.x-"
<<
coeff
<<
"*d3.x);
\n
"
;
source
<<
"real2 d10 = sign*(real2) ("
<<
coeff
<<
"*d2.y-d3.y, d3.x-"
<<
coeff
<<
"*d2.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+4*j*"
<<
m
<<
"] = c0+d4;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
5
*
L
)
<<
"], d7+d9);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8+d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8-d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d7-d9);
\n
"
;
}
}
else
if
(
radix
==
4
)
{
source
<<
"barrier(CLK_GLOBAL_MEM_FENCE);"
;
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
map
<
string
,
string
>
replacements
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
replacements
[
"XSIZE"
]
=
context
.
intToString
(
xsize
);
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
replacements
[
"YSIZE"
]
=
context
.
intToString
(
ysize
);
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
replacements
[
"ZSIZE"
]
=
context
.
intToString
(
zsize
);
source
<<
"real2 d0 = c0+c2;
\n
"
;
replacements
[
"BLOCKS_PER_GROUP"
]
=
context
.
intToString
(
blocksPerGroup
);
source
<<
"real2 d1 = c0-c2;
\n
"
;
replacements
[
"M_PI"
]
=
context
.
doubleToString
(
M_PI
);
source
<<
"real2 d2 = c1+c3;
\n
"
;
replacements
[
"COMPUTE_FFT"
]
=
source
.
str
();
source
<<
"real2 d3 = sign*(real2) (c1.y-c3.y, c3.x-c1.x);
\n
"
;
replacements
[
"LOOP_REQUIRED"
]
=
(
loopRequired
?
"1"
:
"0"
);
source
<<
"data"
<<
output
<<
"[base+3*j*"
<<
m
<<
"] = d0+d2;
\n
"
;
cl
::
Program
program
=
context
.
createProgram
(
context
.
replaceStrings
(
OpenCLKernelSources
::
fft
,
replacements
));
source
<<
"data"
<<
output
<<
"[base+(3*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
4
*
L
)
<<
"], d1+d3);
\n
"
;
cl
::
Kernel
kernel
(
program
,
"execFFT"
);
source
<<
"data"
<<
output
<<
"[base+(3*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d0-d2);
\n
"
;
threads
=
(
isCPU
?
1
:
blocksPerGroup
*
zsize
);
source
<<
"data"
<<
output
<<
"[base+(3*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d1-d3);
\n
"
;
int
kernelMaxThreads
=
kernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
());
if
(
threads
>
kernelMaxThreads
)
{
// The device can't handle this block size, so reduce it.
maxThreads
=
kernelMaxThreads
;
continue
;
}
}
else
if
(
radix
==
3
)
{
int
bufferSize
=
blocksPerGroup
*
zsize
*
(
context
.
getUseDoublePrecision
()
?
sizeof
(
mm_double2
)
:
sizeof
(
mm_float2
));
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
kernel
.
setArg
(
3
,
bufferSize
,
NULL
);
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
kernel
.
setArg
(
4
,
bufferSize
,
NULL
);
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
kernel
.
setArg
(
5
,
bufferSize
,
NULL
);
source
<<
"real2 d0 = c1+c2;
\n
"
;
return
kernel
;
source
<<
"real2 d1 = c0-0.5f*d0;
\n
"
;
source
<<
"real2 d2 = sign*"
<<
context
.
doubleToString
(
sin
(
M_PI
/
3.0
))
<<
"*(real2) (c1.y-c2.y, c2.x-c1.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+2*j*"
<<
m
<<
"] = c0+d0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
3
*
L
)
<<
"], d1+d2);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
3
*
L
)
<<
"], d1-d2);
\n
"
;
}
else
if
(
radix
==
2
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"data"
<<
output
<<
"[base+j*"
<<
m
<<
"] = c0+c1;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
2
*
L
)
<<
"], c0-c1);
\n
"
;
}
source
<<
"}
\n
"
;
m
=
m
*
radix
;
source
<<
"barrier(CLK_LOCAL_MEM_FENCE);
\n
"
;
source
<<
"}
\n
"
;
++
stage
;
}
// Create the kernel.
if
(
loopRequired
)
{
source
<<
"for (int z = get_local_id(0); z < ZSIZE; z += get_local_size(0))
\n
"
;
source
<<
"out[y*(ZSIZE*XSIZE)+z*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[z];
\n
"
;
}
else
{
source
<<
"if (index < XSIZE*YSIZE)
\n
"
;
source
<<
"out[y*(ZSIZE*XSIZE)+(get_local_id(0)%ZSIZE)*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[get_local_id(0)];
\n
"
;
}
}
source
<<
"barrier(CLK_GLOBAL_MEM_FENCE);"
;
map
<
string
,
string
>
replacements
;
replacements
[
"XSIZE"
]
=
context
.
intToString
(
xsize
);
replacements
[
"YSIZE"
]
=
context
.
intToString
(
ysize
);
replacements
[
"ZSIZE"
]
=
context
.
intToString
(
zsize
);
replacements
[
"BLOCKS_PER_GROUP"
]
=
context
.
intToString
(
blocksPerGroup
);
replacements
[
"M_PI"
]
=
context
.
doubleToString
(
M_PI
);
replacements
[
"COMPUTE_FFT"
]
=
source
.
str
();
replacements
[
"LOOP_REQUIRED"
]
=
(
loopRequired
?
"1"
:
"0"
);
cl
::
Program
program
=
context
.
createProgram
(
context
.
replaceStrings
(
OpenCLKernelSources
::
fft
,
replacements
));
cl
::
Kernel
kernel
(
program
,
"execFFT"
);
int
bufferSize
=
blocksPerGroup
*
zsize
*
(
context
.
getUseDoublePrecision
()
?
sizeof
(
mm_double2
)
:
sizeof
(
mm_float2
));
kernel
.
setArg
(
3
,
bufferSize
,
NULL
);
kernel
.
setArg
(
4
,
bufferSize
,
NULL
);
kernel
.
setArg
(
5
,
bufferSize
,
NULL
);
threads
=
(
loopRequired
?
1
:
blocksPerGroup
*
zsize
);
return
kernel
;
}
}
platforms/opencl/src/OpenCLNonbondedUtilities.cpp
View file @
35b9d787
...
@@ -317,42 +317,54 @@ void OpenCLNonbondedUtilities::initialize(const System& system) {
...
@@ -317,42 +317,54 @@ void OpenCLNonbondedUtilities::initialize(const System& system) {
for
(
int
i
=
0
;
i
<
(
int
)
exclusionBlocksForBlock
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
exclusionBlocksForBlock
.
size
();
i
++
)
maxExclusions
=
(
maxExclusions
>
exclusionBlocksForBlock
[
i
].
size
()
?
maxExclusions
:
exclusionBlocksForBlock
[
i
].
size
());
maxExclusions
=
(
maxExclusions
>
exclusionBlocksForBlock
[
i
].
size
()
?
maxExclusions
:
exclusionBlocksForBlock
[
i
].
size
());
defines
[
"MAX_EXCLUSIONS"
]
=
context
.
intToString
(
maxExclusions
);
defines
[
"MAX_EXCLUSIONS"
]
=
context
.
intToString
(
maxExclusions
);
defines
[
"GROUP_SIZE"
]
=
(
deviceIsCpu
?
"32"
:
"128"
);
defines
[
"BUFFER_GROUPS"
]
=
(
deviceIsCpu
?
"4"
:
"2"
);
defines
[
"BUFFER_GROUPS"
]
=
(
deviceIsCpu
?
"4"
:
"2"
);
string
file
=
(
deviceIsCpu
?
OpenCLKernelSources
::
findInteractingBlocks_cpu
:
OpenCLKernelSources
::
findInteractingBlocks
);
string
file
=
(
deviceIsCpu
?
OpenCLKernelSources
::
findInteractingBlocks_cpu
:
OpenCLKernelSources
::
findInteractingBlocks
);
cl
::
Program
interactingBlocksProgram
=
context
.
createProgram
(
file
,
defines
);
int
groupSize
=
(
deviceIsCpu
?
32
:
128
);
findBlockBoundsKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlockBounds"
);
while
(
true
)
{
findBlockBoundsKernel
.
setArg
<
cl_int
>
(
0
,
context
.
getNumAtoms
());
defines
[
"GROUP_SIZE"
]
=
context
.
intToString
(
groupSize
);
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
context
.
getPosq
().
getDeviceBuffer
());
cl
::
Program
interactingBlocksProgram
=
context
.
createProgram
(
file
,
defines
);
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
blockCenter
->
getDeviceBuffer
());
findBlockBoundsKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlockBounds"
);
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
blockBoundingBox
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl_int
>
(
0
,
context
.
getNumAtoms
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
rebuildNeighborList
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
context
.
getPosq
().
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
sortedBlocks
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
blockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"sortBoxData"
);
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
blockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
sortedBlocks
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
rebuildNeighborList
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
blockCenter
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
sortedBlocks
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
blockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"sortBoxData"
);
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
sortedBlockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
sortedBlocks
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
blockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
blockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
oldPositions
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
sortedBlockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
interactionCount
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
rebuildNeighborList
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
findInteractingBlocksKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlocksWithInteractions"
);
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
oldPositions
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
interactionCount
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
interactionCount
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
interactingTiles
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
rebuildNeighborList
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
interactingAtoms
->
getDeviceBuffer
());
findInteractingBlocksKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlocksWithInteractions"
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
interactionCount
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
6
,
interactingTiles
->
getSize
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
interactingTiles
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
7
,
startBlockIndex
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
interactingAtoms
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
8
,
numBlocks
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
9
,
sortedBlocks
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
6
,
interactingTiles
->
getSize
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
10
,
sortedBlockCenter
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
7
,
startBlockIndex
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
11
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
8
,
numBlocks
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
12
,
exclusionIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
9
,
sortedBlocks
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
13
,
exclusionRowIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
10
,
sortedBlockCenter
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
14
,
oldPositions
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
11
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
15
,
rebuildNeighborList
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
12
,
exclusionIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
13
,
exclusionRowIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
14
,
oldPositions
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
15
,
rebuildNeighborList
->
getDeviceBuffer
());
if
(
findInteractingBlocksKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
())
<
groupSize
)
{
// The device can't handle this block size, so reduce it.
groupSize
-=
32
;
if
(
groupSize
<
32
)
throw
OpenMMException
(
"Failed to create findInteractingBlocks kernel"
);
continue
;
}
break
;
}
}
}
}
}
...
...
platforms/opencl/src/OpenCLSort.cpp
View file @
35b9d787
...
@@ -56,10 +56,13 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le
...
@@ -56,10 +56,13 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le
unsigned
int
maxGroupSize
=
std
::
min
(
256
,
(
int
)
context
.
getDevice
().
getInfo
<
CL_DEVICE_MAX_WORK_GROUP_SIZE
>
());
unsigned
int
maxGroupSize
=
std
::
min
(
256
,
(
int
)
context
.
getDevice
().
getInfo
<
CL_DEVICE_MAX_WORK_GROUP_SIZE
>
());
int
maxSharedMem
=
context
.
getDevice
().
getInfo
<
CL_DEVICE_LOCAL_MEM_SIZE
>
();
int
maxSharedMem
=
context
.
getDevice
().
getInfo
<
CL_DEVICE_LOCAL_MEM_SIZE
>
();
unsigned
int
maxLocalBuffer
=
(
unsigned
int
)
((
maxSharedMem
/
trait
->
getDataSize
())
/
2
);
unsigned
int
maxLocalBuffer
=
(
unsigned
int
)
((
maxSharedMem
/
trait
->
getDataSize
())
/
2
);
isShortList
=
(
length
<=
maxLocalBuffer
);
unsigned
int
maxRangeSize
=
std
::
min
(
maxGroupSize
,
(
unsigned
int
)
computeRangeKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
()));
for
(
rangeKernelSize
=
1
;
rangeKernelSize
*
2
<=
maxGroupSize
;
rangeKernelSize
*=
2
)
unsigned
int
maxPositionsSize
=
std
::
min
(
maxGroupSize
,
(
unsigned
int
)
computeBucketPositionsKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
()));
unsigned
int
maxShortListSize
=
shortListKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
());
isShortList
=
(
length
<=
maxLocalBuffer
&&
length
<
maxShortListSize
);
for
(
rangeKernelSize
=
1
;
rangeKernelSize
*
2
<=
maxRangeSize
;
rangeKernelSize
*=
2
)
;
;
positionsKernelSize
=
rangeKernelSize
;
positionsKernelSize
=
std
::
min
(
rangeKernelSize
,
maxPositionsSize
)
;
sortKernelSize
=
(
isShortList
?
rangeKernelSize
:
rangeKernelSize
/
2
);
sortKernelSize
=
(
isShortList
?
rangeKernelSize
:
rangeKernelSize
/
2
);
if
(
rangeKernelSize
>
length
)
if
(
rangeKernelSize
>
length
)
rangeKernelSize
=
length
;
rangeKernelSize
=
length
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment