Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
35b9d787
Commit
35b9d787
authored
Nov 22, 2013
by
peastman
Browse files
Workarounds for devices whose maximum workgroup size varies between kernels
parent
d59a34fb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
229 additions
and
203 deletions
+229
-203
platforms/opencl/src/OpenCLFFT3D.cpp
platforms/opencl/src/OpenCLFFT3D.cpp
+177
-166
platforms/opencl/src/OpenCLNonbondedUtilities.cpp
platforms/opencl/src/OpenCLNonbondedUtilities.cpp
+46
-34
platforms/opencl/src/OpenCLSort.cpp
platforms/opencl/src/OpenCLSort.cpp
+6
-3
No files found.
platforms/opencl/src/OpenCLFFT3D.cpp
View file @
35b9d787
...
@@ -74,178 +74,189 @@ int OpenCLFFT3D::findLegalDimension(int minimum) {
...
@@ -74,178 +74,189 @@ int OpenCLFFT3D::findLegalDimension(int minimum) {
}
}
cl
::
Kernel
OpenCLFFT3D
::
createKernel
(
int
xsize
,
int
ysize
,
int
zsize
,
int
&
threads
)
{
cl
::
Kernel
OpenCLFFT3D
::
createKernel
(
int
xsize
,
int
ysize
,
int
zsize
,
int
&
threads
)
{
bool
loopRequired
=
(
context
.
getDevice
().
getInfo
<
CL_DEVICE_TYPE
>
()
==
CL_DEVICE_TYPE_CPU
);
int
maxThreads
=
std
::
min
(
256
,
(
int
)
context
.
getDevice
().
getInfo
<
CL_DEVICE_MAX_WORK_GROUP_SIZE
>
());
stringstream
source
;
bool
isCPU
=
context
.
getDevice
().
getInfo
<
CL_DEVICE_TYPE
>
()
==
CL_DEVICE_TYPE_CPU
;
int
blocksPerGroup
=
(
loopRequired
?
1
:
max
(
1
,
256
/
zsize
));
while
(
true
)
{
int
stage
=
0
;
bool
loopRequired
=
(
zsize
>
maxThreads
||
isCPU
);
int
L
=
zsize
;
stringstream
source
;
int
m
=
1
;
int
blocksPerGroup
=
(
loopRequired
?
1
:
max
(
1
,
maxThreads
/
zsize
));
int
stage
=
0
;
int
L
=
zsize
;
int
m
=
1
;
// Factor zsize, generating an appropriate block of code for each factor.
while
(
L
>
1
)
{
int
input
=
stage
%
2
;
int
output
=
1
-
input
;
int
radix
;
if
(
L
%
7
==
0
)
radix
=
7
;
else
if
(
L
%
5
==
0
)
radix
=
5
;
else
if
(
L
%
4
==
0
)
radix
=
4
;
else
if
(
L
%
3
==
0
)
radix
=
3
;
else
if
(
L
%
2
==
0
)
radix
=
2
;
else
throw
OpenMMException
(
"Illegal size for FFT: "
+
context
.
intToString
(
zsize
));
source
<<
"{
\n
"
;
L
=
L
/
radix
;
source
<<
"// Pass "
<<
(
stage
+
1
)
<<
" (radix "
<<
radix
<<
")
\n
"
;
if
(
loopRequired
)
{
source
<<
"for (int i = get_local_id(0); i < "
<<
(
L
*
m
)
<<
"; i += get_local_size(0)) {
\n
"
;
source
<<
"int base = i;
\n
"
;
}
else
{
source
<<
"if (get_local_id(0) < "
<<
(
blocksPerGroup
*
L
*
m
)
<<
") {
\n
"
;
source
<<
"int block = get_local_id(0)/"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"int i = get_local_id(0)-block*"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"int base = i+block*"
<<
zsize
<<
";
\n
"
;
}
source
<<
"int j = i/"
<<
m
<<
";
\n
"
;
if
(
radix
==
7
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c5 = data"
<<
input
<<
"[base+"
<<
(
5
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c6 = data"
<<
input
<<
"[base+"
<<
(
6
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c6;
\n
"
;
source
<<
"real2 d1 = c1-c6;
\n
"
;
source
<<
"real2 d2 = c2+c5;
\n
"
;
source
<<
"real2 d3 = c2-c5;
\n
"
;
source
<<
"real2 d4 = c4+c3;
\n
"
;
source
<<
"real2 d5 = c4-c3;
\n
"
;
source
<<
"real2 d6 = d2+d0;
\n
"
;
source
<<
"real2 d7 = d5+d3;
\n
"
;
source
<<
"real2 b0 = c0+d6+d4;
\n
"
;
source
<<
"real2 b1 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
-
1
)
<<
"*(d6+d4);
\n
"
;
source
<<
"real2 b2 = "
<<
context
.
doubleToString
((
2
*
cos
(
2
*
M_PI
/
7
)
-
cos
(
4
*
M_PI
/
7
)
-
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d0-d4);
\n
"
;
source
<<
"real2 b3 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
-
2
*
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d4-d2);
\n
"
;
source
<<
"real2 b4 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
-
2
*
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d2-d0);
\n
"
;
source
<<
"real2 b5 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d7+d1);
\n
"
;
source
<<
"real2 b6 = -sign*"
<<
context
.
doubleToString
((
2
*
sin
(
2
*
M_PI
/
7
)
-
sin
(
4
*
M_PI
/
7
)
+
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d1-d5);
\n
"
;
source
<<
"real2 b7 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
-
2
*
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d5-d3);
\n
"
;
source
<<
"real2 b8 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
+
2
*
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d3-d1);
\n
"
;
source
<<
"real2 t0 = b0+b1;
\n
"
;
source
<<
"real2 t1 = b2+b3;
\n
"
;
source
<<
"real2 t2 = b4-b3;
\n
"
;
source
<<
"real2 t3 = -b2-b4;
\n
"
;
source
<<
"real2 t4 = b6+b7;
\n
"
;
source
<<
"real2 t5 = b8-b7;
\n
"
;
source
<<
"real2 t6 = -b8-b6;
\n
"
;
source
<<
"real2 t7 = t0+t1;
\n
"
;
source
<<
"real2 t8 = t0+t2;
\n
"
;
source
<<
"real2 t9 = t0+t3;
\n
"
;
source
<<
"real2 t10 = (real2) (t4.y+b5.y, -(t4.x+b5.x));
\n
"
;
source
<<
"real2 t11 = (real2) (t5.y+b5.y, -(t5.x+b5.x));
\n
"
;
source
<<
"real2 t12 = (real2) (t6.y+b5.y, -(t6.x+b5.x));
\n
"
;
source
<<
"data"
<<
output
<<
"[base+6*j*"
<<
m
<<
"] = b0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
7
*
L
)
<<
"], t7-t10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9-t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8+t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8-t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+5)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
5
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9+t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+6)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
6
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t7+t10);
\n
"
;
}
else
if
(
radix
==
5
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c4;
\n
"
;
source
<<
"real2 d1 = c2+c3;
\n
"
;
source
<<
"real2 d2 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c1-c4);
\n
"
;
source
<<
"real2 d3 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c2-c3);
\n
"
;
source
<<
"real2 d4 = d0+d1;
\n
"
;
source
<<
"real2 d5 = "
<<
context
.
doubleToString
(
0.25
*
sqrt
(
5.0
))
<<
"*(d0-d1);
\n
"
;
source
<<
"real2 d6 = c0-0.25f*d4;
\n
"
;
source
<<
"real2 d7 = d6+d5;
\n
"
;
source
<<
"real2 d8 = d6-d5;
\n
"
;
string
coeff
=
context
.
doubleToString
(
sin
(
0.2
*
M_PI
)
/
sin
(
0.4
*
M_PI
));
source
<<
"real2 d9 = sign*(real2) (d2.y+"
<<
coeff
<<
"*d3.y, -d2.x-"
<<
coeff
<<
"*d3.x);
\n
"
;
source
<<
"real2 d10 = sign*(real2) ("
<<
coeff
<<
"*d2.y-d3.y, d3.x-"
<<
coeff
<<
"*d2.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+4*j*"
<<
m
<<
"] = c0+d4;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
5
*
L
)
<<
"], d7+d9);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8+d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8-d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d7-d9);
\n
"
;
}
else
if
(
radix
==
4
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c0+c2;
\n
"
;
source
<<
"real2 d1 = c0-c2;
\n
"
;
source
<<
"real2 d2 = c1+c3;
\n
"
;
source
<<
"real2 d3 = sign*(real2) (c1.y-c3.y, c3.x-c1.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+3*j*"
<<
m
<<
"] = d0+d2;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
4
*
L
)
<<
"], d1+d3);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d0-d2);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(3*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d1-d3);
\n
"
;
}
else
if
(
radix
==
3
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c2;
\n
"
;
source
<<
"real2 d1 = c0-0.5f*d0;
\n
"
;
source
<<
"real2 d2 = sign*"
<<
context
.
doubleToString
(
sin
(
M_PI
/
3.0
))
<<
"*(real2) (c1.y-c2.y, c2.x-c1.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+2*j*"
<<
m
<<
"] = c0+d0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
3
*
L
)
<<
"], d1+d2);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
3
*
L
)
<<
"], d1-d2);
\n
"
;
}
else
if
(
radix
==
2
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"data"
<<
output
<<
"[base+j*"
<<
m
<<
"] = c0+c1;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
2
*
L
)
<<
"], c0-c1);
\n
"
;
}
source
<<
"}
\n
"
;
m
=
m
*
radix
;
source
<<
"barrier(CLK_LOCAL_MEM_FENCE);
\n
"
;
source
<<
"}
\n
"
;
++
stage
;
}
//
Factor zsize, generating an appropriate block of code for each factor
.
//
Create the kernel
.
while
(
L
>
1
)
{
int
input
=
stage
%
2
;
int
output
=
1
-
input
;
int
radix
;
if
(
L
%
7
==
0
)
radix
=
7
;
else
if
(
L
%
5
==
0
)
radix
=
5
;
else
if
(
L
%
4
==
0
)
radix
=
4
;
else
if
(
L
%
3
==
0
)
radix
=
3
;
else
if
(
L
%
2
==
0
)
radix
=
2
;
else
throw
OpenMMException
(
"Illegal size for FFT: "
+
context
.
intToString
(
zsize
));
source
<<
"{
\n
"
;
L
=
L
/
radix
;
source
<<
"// Pass "
<<
(
stage
+
1
)
<<
" (radix "
<<
radix
<<
")
\n
"
;
if
(
loopRequired
)
{
if
(
loopRequired
)
{
source
<<
"for (int
i
= get_local_id(0);
i
<
"
<<
(
L
*
m
)
<<
"
;
i
+= get_local_size(0))
{
\n
"
;
source
<<
"for (int
z
= get_local_id(0);
z
<
ZSIZE
;
z
+= get_local_size(0))
\n
"
;
source
<<
"
int base = i
;
\n
"
;
source
<<
"
out[y*(ZSIZE*XSIZE)+z*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[z]
;
\n
"
;
}
}
else
{
else
{
source
<<
"if (get_local_id(0) < "
<<
(
blocksPerGroup
*
L
*
m
)
<<
") {
\n
"
;
source
<<
"if (index < XSIZE*YSIZE)
\n
"
;
source
<<
"int block = get_local_id(0)/"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"out[y*(ZSIZE*XSIZE)+(get_local_id(0)%ZSIZE)*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[get_local_id(0)];
\n
"
;
source
<<
"int i = get_local_id(0)-block*"
<<
(
L
*
m
)
<<
";
\n
"
;
source
<<
"int base = i+block*"
<<
zsize
<<
";
\n
"
;
}
source
<<
"int j = i/"
<<
m
<<
";
\n
"
;
if
(
radix
==
7
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c5 = data"
<<
input
<<
"[base+"
<<
(
5
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c6 = data"
<<
input
<<
"[base+"
<<
(
6
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c6;
\n
"
;
source
<<
"real2 d1 = c1-c6;
\n
"
;
source
<<
"real2 d2 = c2+c5;
\n
"
;
source
<<
"real2 d3 = c2-c5;
\n
"
;
source
<<
"real2 d4 = c4+c3;
\n
"
;
source
<<
"real2 d5 = c4-c3;
\n
"
;
source
<<
"real2 d6 = d2+d0;
\n
"
;
source
<<
"real2 d7 = d5+d3;
\n
"
;
source
<<
"real2 b0 = c0+d6+d4;
\n
"
;
source
<<
"real2 b1 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
-
1
)
<<
"*(d6+d4);
\n
"
;
source
<<
"real2 b2 = "
<<
context
.
doubleToString
((
2
*
cos
(
2
*
M_PI
/
7
)
-
cos
(
4
*
M_PI
/
7
)
-
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d0-d4);
\n
"
;
source
<<
"real2 b3 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
-
2
*
cos
(
4
*
M_PI
/
7
)
+
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d4-d2);
\n
"
;
source
<<
"real2 b4 = "
<<
context
.
doubleToString
((
cos
(
2
*
M_PI
/
7
)
+
cos
(
4
*
M_PI
/
7
)
-
2
*
cos
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d2-d0);
\n
"
;
source
<<
"real2 b5 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d7+d1);
\n
"
;
source
<<
"real2 b6 = -sign*"
<<
context
.
doubleToString
((
2
*
sin
(
2
*
M_PI
/
7
)
-
sin
(
4
*
M_PI
/
7
)
+
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d1-d5);
\n
"
;
source
<<
"real2 b7 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
-
2
*
sin
(
4
*
M_PI
/
7
)
-
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d5-d3);
\n
"
;
source
<<
"real2 b8 = -sign*"
<<
context
.
doubleToString
((
sin
(
2
*
M_PI
/
7
)
+
sin
(
4
*
M_PI
/
7
)
+
2
*
sin
(
6
*
M_PI
/
7
))
/
3
)
<<
"*(d3-d1);
\n
"
;
source
<<
"real2 t0 = b0+b1;
\n
"
;
source
<<
"real2 t1 = b2+b3;
\n
"
;
source
<<
"real2 t2 = b4-b3;
\n
"
;
source
<<
"real2 t3 = -b2-b4;
\n
"
;
source
<<
"real2 t4 = b6+b7;
\n
"
;
source
<<
"real2 t5 = b8-b7;
\n
"
;
source
<<
"real2 t6 = -b8-b6;
\n
"
;
source
<<
"real2 t7 = t0+t1;
\n
"
;
source
<<
"real2 t8 = t0+t2;
\n
"
;
source
<<
"real2 t9 = t0+t3;
\n
"
;
source
<<
"real2 t10 = (real2) (t4.y+b5.y, -(t4.x+b5.x));
\n
"
;
source
<<
"real2 t11 = (real2) (t5.y+b5.y, -(t5.x+b5.x));
\n
"
;
source
<<
"real2 t12 = (real2) (t6.y+b5.y, -(t6.x+b5.x));
\n
"
;
source
<<
"data"
<<
output
<<
"[base+6*j*"
<<
m
<<
"] = b0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
7
*
L
)
<<
"], t7-t10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9-t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8+t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t8-t11);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+5)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
5
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t9+t12);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(6*j+6)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
6
*
zsize
)
<<
"/"
<<
(
7
*
L
)
<<
"], t7+t10);
\n
"
;
}
else
if
(
radix
==
5
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 c4 = data"
<<
input
<<
"[base+"
<<
(
4
*
L
*
m
)
<<
"];
\n
"
;
source
<<
"real2 d0 = c1+c4;
\n
"
;
source
<<
"real2 d1 = c2+c3;
\n
"
;
source
<<
"real2 d2 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c1-c4);
\n
"
;
source
<<
"real2 d3 = "
<<
context
.
doubleToString
(
sin
(
0.4
*
M_PI
))
<<
"*(c2-c3);
\n
"
;
source
<<
"real2 d4 = d0+d1;
\n
"
;
source
<<
"real2 d5 = "
<<
context
.
doubleToString
(
0.25
*
sqrt
(
5.0
))
<<
"*(d0-d1);
\n
"
;
source
<<
"real2 d6 = c0-0.25f*d4;
\n
"
;
source
<<
"real2 d7 = d6+d5;
\n
"
;
source
<<
"real2 d8 = d6-d5;
\n
"
;
string
coeff
=
context
.
doubleToString
(
sin
(
0.2
*
M_PI
)
/
sin
(
0.4
*
M_PI
));
source
<<
"real2 d9 = sign*(real2) (d2.y+"
<<
coeff
<<
"*d3.y, -d2.x-"
<<
coeff
<<
"*d3.x);
\n
"
;
source
<<
"real2 d10 = sign*(real2) ("
<<
coeff
<<
"*d2.y-d3.y, d3.x-"
<<
coeff
<<
"*d2.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+4*j*"
<<
m
<<
"] = c0+d4;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
5
*
L
)
<<
"], d7+d9);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8+d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d8-d10);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(4*j+4)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
4
*
zsize
)
<<
"/"
<<
(
5
*
L
)
<<
"], d7-d9);
\n
"
;
}
}
else
if
(
radix
==
4
)
{
source
<<
"barrier(CLK_GLOBAL_MEM_FENCE);"
;
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
map
<
string
,
string
>
replacements
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
replacements
[
"XSIZE"
]
=
context
.
intToString
(
xsize
);
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
replacements
[
"YSIZE"
]
=
context
.
intToString
(
ysize
);
source
<<
"real2 c3 = data"
<<
input
<<
"[base+"
<<
(
3
*
L
*
m
)
<<
"];
\n
"
;
replacements
[
"ZSIZE"
]
=
context
.
intToString
(
zsize
);
source
<<
"real2 d0 = c0+c2;
\n
"
;
replacements
[
"BLOCKS_PER_GROUP"
]
=
context
.
intToString
(
blocksPerGroup
);
source
<<
"real2 d1 = c0-c2;
\n
"
;
replacements
[
"M_PI"
]
=
context
.
doubleToString
(
M_PI
);
source
<<
"real2 d2 = c1+c3;
\n
"
;
replacements
[
"COMPUTE_FFT"
]
=
source
.
str
();
source
<<
"real2 d3 = sign*(real2) (c1.y-c3.y, c3.x-c1.x);
\n
"
;
replacements
[
"LOOP_REQUIRED"
]
=
(
loopRequired
?
"1"
:
"0"
);
source
<<
"data"
<<
output
<<
"[base+3*j*"
<<
m
<<
"] = d0+d2;
\n
"
;
cl
::
Program
program
=
context
.
createProgram
(
context
.
replaceStrings
(
OpenCLKernelSources
::
fft
,
replacements
));
source
<<
"data"
<<
output
<<
"[base+(3*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
4
*
L
)
<<
"], d1+d3);
\n
"
;
cl
::
Kernel
kernel
(
program
,
"execFFT"
);
source
<<
"data"
<<
output
<<
"[base+(3*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d0-d2);
\n
"
;
threads
=
(
isCPU
?
1
:
blocksPerGroup
*
zsize
);
source
<<
"data"
<<
output
<<
"[base+(3*j+3)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
3
*
zsize
)
<<
"/"
<<
(
4
*
L
)
<<
"], d1-d3);
\n
"
;
int
kernelMaxThreads
=
kernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
());
if
(
threads
>
kernelMaxThreads
)
{
// The device can't handle this block size, so reduce it.
maxThreads
=
kernelMaxThreads
;
continue
;
}
}
else
if
(
radix
==
3
)
{
int
bufferSize
=
blocksPerGroup
*
zsize
*
(
context
.
getUseDoublePrecision
()
?
sizeof
(
mm_double2
)
:
sizeof
(
mm_float2
));
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
kernel
.
setArg
(
3
,
bufferSize
,
NULL
);
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
kernel
.
setArg
(
4
,
bufferSize
,
NULL
);
source
<<
"real2 c2 = data"
<<
input
<<
"[base+"
<<
(
2
*
L
*
m
)
<<
"];
\n
"
;
kernel
.
setArg
(
5
,
bufferSize
,
NULL
);
source
<<
"real2 d0 = c1+c2;
\n
"
;
return
kernel
;
source
<<
"real2 d1 = c0-0.5f*d0;
\n
"
;
source
<<
"real2 d2 = sign*"
<<
context
.
doubleToString
(
sin
(
M_PI
/
3.0
))
<<
"*(real2) (c1.y-c2.y, c2.x-c1.x);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+2*j*"
<<
m
<<
"] = c0+d0;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
3
*
L
)
<<
"], d1+d2);
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(2*j+2)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
(
2
*
zsize
)
<<
"/"
<<
(
3
*
L
)
<<
"], d1-d2);
\n
"
;
}
else
if
(
radix
==
2
)
{
source
<<
"real2 c0 = data"
<<
input
<<
"[base];
\n
"
;
source
<<
"real2 c1 = data"
<<
input
<<
"[base+"
<<
(
L
*
m
)
<<
"];
\n
"
;
source
<<
"data"
<<
output
<<
"[base+j*"
<<
m
<<
"] = c0+c1;
\n
"
;
source
<<
"data"
<<
output
<<
"[base+(j+1)*"
<<
m
<<
"] = multiplyComplex(w[j*"
<<
zsize
<<
"/"
<<
(
2
*
L
)
<<
"], c0-c1);
\n
"
;
}
source
<<
"}
\n
"
;
m
=
m
*
radix
;
source
<<
"barrier(CLK_LOCAL_MEM_FENCE);
\n
"
;
source
<<
"}
\n
"
;
++
stage
;
}
// Create the kernel.
if
(
loopRequired
)
{
source
<<
"for (int z = get_local_id(0); z < ZSIZE; z += get_local_size(0))
\n
"
;
source
<<
"out[y*(ZSIZE*XSIZE)+z*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[z];
\n
"
;
}
else
{
source
<<
"if (index < XSIZE*YSIZE)
\n
"
;
source
<<
"out[y*(ZSIZE*XSIZE)+(get_local_id(0)%ZSIZE)*XSIZE+x] = data"
<<
(
stage
%
2
)
<<
"[get_local_id(0)];
\n
"
;
}
}
source
<<
"barrier(CLK_GLOBAL_MEM_FENCE);"
;
map
<
string
,
string
>
replacements
;
replacements
[
"XSIZE"
]
=
context
.
intToString
(
xsize
);
replacements
[
"YSIZE"
]
=
context
.
intToString
(
ysize
);
replacements
[
"ZSIZE"
]
=
context
.
intToString
(
zsize
);
replacements
[
"BLOCKS_PER_GROUP"
]
=
context
.
intToString
(
blocksPerGroup
);
replacements
[
"M_PI"
]
=
context
.
doubleToString
(
M_PI
);
replacements
[
"COMPUTE_FFT"
]
=
source
.
str
();
replacements
[
"LOOP_REQUIRED"
]
=
(
loopRequired
?
"1"
:
"0"
);
cl
::
Program
program
=
context
.
createProgram
(
context
.
replaceStrings
(
OpenCLKernelSources
::
fft
,
replacements
));
cl
::
Kernel
kernel
(
program
,
"execFFT"
);
int
bufferSize
=
blocksPerGroup
*
zsize
*
(
context
.
getUseDoublePrecision
()
?
sizeof
(
mm_double2
)
:
sizeof
(
mm_float2
));
kernel
.
setArg
(
3
,
bufferSize
,
NULL
);
kernel
.
setArg
(
4
,
bufferSize
,
NULL
);
kernel
.
setArg
(
5
,
bufferSize
,
NULL
);
threads
=
(
loopRequired
?
1
:
blocksPerGroup
*
zsize
);
return
kernel
;
}
}
platforms/opencl/src/OpenCLNonbondedUtilities.cpp
View file @
35b9d787
...
@@ -317,42 +317,54 @@ void OpenCLNonbondedUtilities::initialize(const System& system) {
...
@@ -317,42 +317,54 @@ void OpenCLNonbondedUtilities::initialize(const System& system) {
for
(
int
i
=
0
;
i
<
(
int
)
exclusionBlocksForBlock
.
size
();
i
++
)
for
(
int
i
=
0
;
i
<
(
int
)
exclusionBlocksForBlock
.
size
();
i
++
)
maxExclusions
=
(
maxExclusions
>
exclusionBlocksForBlock
[
i
].
size
()
?
maxExclusions
:
exclusionBlocksForBlock
[
i
].
size
());
maxExclusions
=
(
maxExclusions
>
exclusionBlocksForBlock
[
i
].
size
()
?
maxExclusions
:
exclusionBlocksForBlock
[
i
].
size
());
defines
[
"MAX_EXCLUSIONS"
]
=
context
.
intToString
(
maxExclusions
);
defines
[
"MAX_EXCLUSIONS"
]
=
context
.
intToString
(
maxExclusions
);
defines
[
"GROUP_SIZE"
]
=
(
deviceIsCpu
?
"32"
:
"128"
);
defines
[
"BUFFER_GROUPS"
]
=
(
deviceIsCpu
?
"4"
:
"2"
);
defines
[
"BUFFER_GROUPS"
]
=
(
deviceIsCpu
?
"4"
:
"2"
);
string
file
=
(
deviceIsCpu
?
OpenCLKernelSources
::
findInteractingBlocks_cpu
:
OpenCLKernelSources
::
findInteractingBlocks
);
string
file
=
(
deviceIsCpu
?
OpenCLKernelSources
::
findInteractingBlocks_cpu
:
OpenCLKernelSources
::
findInteractingBlocks
);
cl
::
Program
interactingBlocksProgram
=
context
.
createProgram
(
file
,
defines
);
int
groupSize
=
(
deviceIsCpu
?
32
:
128
);
findBlockBoundsKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlockBounds"
);
while
(
true
)
{
findBlockBoundsKernel
.
setArg
<
cl_int
>
(
0
,
context
.
getNumAtoms
());
defines
[
"GROUP_SIZE"
]
=
context
.
intToString
(
groupSize
);
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
context
.
getPosq
().
getDeviceBuffer
());
cl
::
Program
interactingBlocksProgram
=
context
.
createProgram
(
file
,
defines
);
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
blockCenter
->
getDeviceBuffer
());
findBlockBoundsKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlockBounds"
);
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
blockBoundingBox
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl_int
>
(
0
,
context
.
getNumAtoms
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
rebuildNeighborList
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
context
.
getPosq
().
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
sortedBlocks
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
blockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"sortBoxData"
);
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
blockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
sortedBlocks
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
rebuildNeighborList
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
blockCenter
->
getDeviceBuffer
());
findBlockBoundsKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
sortedBlocks
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
blockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"sortBoxData"
);
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
sortedBlockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
0
,
sortedBlocks
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
1
,
blockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
blockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
oldPositions
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
sortedBlockCenter
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
interactionCount
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
rebuildNeighborList
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
findInteractingBlocksKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlocksWithInteractions"
);
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
6
,
oldPositions
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
interactionCount
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
7
,
interactionCount
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
interactingTiles
->
getDeviceBuffer
());
sortBoxDataKernel
.
setArg
<
cl
::
Buffer
>
(
8
,
rebuildNeighborList
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
interactingAtoms
->
getDeviceBuffer
());
findInteractingBlocksKernel
=
cl
::
Kernel
(
interactingBlocksProgram
,
"findBlocksWithInteractions"
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
2
,
interactionCount
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
6
,
interactingTiles
->
getSize
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
3
,
interactingTiles
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
7
,
startBlockIndex
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
4
,
interactingAtoms
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
8
,
numBlocks
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
5
,
context
.
getPosq
().
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
9
,
sortedBlocks
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
6
,
interactingTiles
->
getSize
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
10
,
sortedBlockCenter
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
7
,
startBlockIndex
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
11
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl_uint
>
(
8
,
numBlocks
);
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
12
,
exclusionIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
9
,
sortedBlocks
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
13
,
exclusionRowIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
10
,
sortedBlockCenter
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
14
,
oldPositions
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
11
,
sortedBlockBoundingBox
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
15
,
rebuildNeighborList
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
12
,
exclusionIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
13
,
exclusionRowIndices
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
14
,
oldPositions
->
getDeviceBuffer
());
findInteractingBlocksKernel
.
setArg
<
cl
::
Buffer
>
(
15
,
rebuildNeighborList
->
getDeviceBuffer
());
if
(
findInteractingBlocksKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
())
<
groupSize
)
{
// The device can't handle this block size, so reduce it.
groupSize
-=
32
;
if
(
groupSize
<
32
)
throw
OpenMMException
(
"Failed to create findInteractingBlocks kernel"
);
continue
;
}
break
;
}
}
}
}
}
...
...
platforms/opencl/src/OpenCLSort.cpp
View file @
35b9d787
...
@@ -56,10 +56,13 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le
...
@@ -56,10 +56,13 @@ OpenCLSort::OpenCLSort(OpenCLContext& context, SortTrait* trait, unsigned int le
unsigned
int
maxGroupSize
=
std
::
min
(
256
,
(
int
)
context
.
getDevice
().
getInfo
<
CL_DEVICE_MAX_WORK_GROUP_SIZE
>
());
unsigned
int
maxGroupSize
=
std
::
min
(
256
,
(
int
)
context
.
getDevice
().
getInfo
<
CL_DEVICE_MAX_WORK_GROUP_SIZE
>
());
int
maxSharedMem
=
context
.
getDevice
().
getInfo
<
CL_DEVICE_LOCAL_MEM_SIZE
>
();
int
maxSharedMem
=
context
.
getDevice
().
getInfo
<
CL_DEVICE_LOCAL_MEM_SIZE
>
();
unsigned
int
maxLocalBuffer
=
(
unsigned
int
)
((
maxSharedMem
/
trait
->
getDataSize
())
/
2
);
unsigned
int
maxLocalBuffer
=
(
unsigned
int
)
((
maxSharedMem
/
trait
->
getDataSize
())
/
2
);
isShortList
=
(
length
<=
maxLocalBuffer
);
unsigned
int
maxRangeSize
=
std
::
min
(
maxGroupSize
,
(
unsigned
int
)
computeRangeKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
()));
for
(
rangeKernelSize
=
1
;
rangeKernelSize
*
2
<=
maxGroupSize
;
rangeKernelSize
*=
2
)
unsigned
int
maxPositionsSize
=
std
::
min
(
maxGroupSize
,
(
unsigned
int
)
computeBucketPositionsKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
()));
unsigned
int
maxShortListSize
=
shortListKernel
.
getWorkGroupInfo
<
CL_KERNEL_WORK_GROUP_SIZE
>
(
context
.
getDevice
());
isShortList
=
(
length
<=
maxLocalBuffer
&&
length
<
maxShortListSize
);
for
(
rangeKernelSize
=
1
;
rangeKernelSize
*
2
<=
maxRangeSize
;
rangeKernelSize
*=
2
)
;
;
positionsKernelSize
=
rangeKernelSize
;
positionsKernelSize
=
std
::
min
(
rangeKernelSize
,
maxPositionsSize
)
;
sortKernelSize
=
(
isShortList
?
rangeKernelSize
:
rangeKernelSize
/
2
);
sortKernelSize
=
(
isShortList
?
rangeKernelSize
:
rangeKernelSize
/
2
);
if
(
rangeKernelSize
>
length
)
if
(
rangeKernelSize
>
length
)
rangeKernelSize
=
length
;
rangeKernelSize
=
length
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment