"platforms/opencl/tests/TestOpenCLSettle.cpp" did not exist on "efc1083e34deb7a97d37786121809c28e5c275ff"
OpenCLFFT3D.cpp 13.6 KB
Newer Older
Peter Eastman's avatar
Peter Eastman committed
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2012 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLFFT3D.h"
#include "OpenCLExpressionUtilities.h"
29
#include "OpenCLKernelSources.h"
30
#include "SimTKOpenMMRealType.h"
Peter Eastman's avatar
Peter Eastman committed
31
32
33
34
35
36
37
38
#include <map>
#include <sstream>
#include <string>

using namespace OpenMM;
using namespace std;

OpenCLFFT3D::OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize) : context(context), xsize(xsize), ysize(ysize), zsize(zsize) {
    // Build one kernel per axis. createKernel() transforms along its THIRD size argument,
    // so the sizes are rotated for each axis. The work-group size each kernel must be
    // launched with is reported back through the final (reference) argument.

    // Pass 1: transform along z.
    zkernel = createKernel(xsize, ysize, zsize, zthreads);

    // Pass 2: transform along x.
    xkernel = createKernel(ysize, zsize, xsize, xthreads);

    // Pass 3: transform along y.
    ykernel = createKernel(zsize, xsize, ysize, ythreads);
}

44
void OpenCLFFT3D::execFFT(OpenCLArray& in, OpenCLArray& out, bool forward) {
    // Perform the 3D FFT as three successive 1D passes (z, then x, then y).
    //
    // The two buffers ping-pong between passes: in -> out, out -> in, in -> out.
    // NOTE(review): this assumes argument 0 of each generated kernel is its source and
    // argument 1 its destination (the argument list lives in the kernel template, not here).
    // If so, the final result lands in `out` and the contents of `in` are overwritten.
    //
    // Kernel argument 2 is +1 for a forward transform and -1 for an inverse transform.
    cl::Kernel* passKernels[] = {&zkernel, &xkernel, &ykernel};
    int passThreads[] = {zthreads, xthreads, ythreads};
    OpenCLArray* buffers[] = {&in, &out};
    cl_int direction = (forward ? 1 : -1);
    int gridPoints = xsize*ysize*zsize;
    for (int pass = 0; pass < 3; pass++) {
        cl::Kernel& kernel = *passKernels[pass];
        OpenCLArray& source = *buffers[pass%2];
        OpenCLArray& destination = *buffers[1-pass%2];
        kernel.setArg<cl::Buffer>(0, source.getDeviceBuffer());
        kernel.setArg<cl::Buffer>(1, destination.getDeviceBuffer());
        kernel.setArg<cl_int>(2, direction);
        context.executeKernel(kernel, gridPoints, passThreads[pass]);
    }
}

Peter Eastman's avatar
Peter Eastman committed
59
60
61
62
63
64
65
int OpenCLFFT3D::findLegalDimension(int minimum) {
    // Return the smallest size >= minimum whose prime factors are all 2, 3, 5 or 7.
    // Those are the only factors createKernel() can generate a radix pass for; it throws
    // for anything else. Values below 1 are mapped to 1.
    if (minimum < 1)
        return 1;
    static const int primeFactors[] = {2, 3, 5, 7};
    const int numPrimeFactors = sizeof(primeFactors)/sizeof(primeFactors[0]);
    for (int candidate = minimum; ; ++candidate) {
        // Strip every allowed prime; anything left over is an unsupported factor.
        int remaining = candidate;
        for (int k = 0; k < numPrimeFactors; ++k)
            while (remaining%primeFactors[k] == 0)
                remaining /= primeFactors[k];
        if (remaining == 1)
            return candidate;
    }
}

76
/**
 * Generate, compile and configure the kernel that performs a 1D FFT along the axis of
 * length zsize.
 *
 * zsize is factored into radix-7/5/4/3/2 passes, and OpenCL C code is emitted for each pass.
 * Passes alternate between two work buffers, data0 and data1: pass k reads data(k%2) and
 * writes data(1-k%2). After the last pass, element (x, y, z) is written to
 * out[y*ZSIZE*XSIZE + z*XSIZE + x].
 *
 * The generated code is substituted for COMPUTE_FFT in OpenCLKernelSources::fft, together
 * with XSIZE/YSIZE/ZSIZE, BLOCKS_PER_GROUP, M_PI and LOOP_REQUIRED.
 * NOTE(review): the names it relies on (w, sign, index, x, y, out, data0/data1) are defined
 * by that template, which is not visible here. `w` appears to be a twiddle-factor table and
 * `sign` the transform direction -- confirm against the template.
 *
 * @param xsize    size substituted for XSIZE in the kernel template
 * @param ysize    size substituted for YSIZE in the kernel template
 * @param zsize    length of the axis being transformed; must factor into 2, 3, 5 and 7
 * @param threads  [out] work-group size the returned kernel must be launched with
 * @return the compiled kernel, with its three local-memory arguments (3, 4, 5) already set
 * @throws OpenMMException if zsize is not positive or has a prime factor other than 2, 3, 5, 7
 */
cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int& threads) {
    // A non-positive size would divide by zero below (256/zsize) or allocate empty buffers.
    if (zsize < 1)
        throw OpenMMException("Illegal size for FFT: "+context.intToString(zsize));

    // CL_DEVICE_TYPE is a bitfield that may combine several flags (e.g. CPU|DEFAULT), so
    // test the CPU bit instead of comparing for equality. On CPU devices each work-group
    // loops over the elements of a line rather than using one work-item per element.
    bool loopRequired = ((context.getDevice().getInfo<CL_DEVICE_TYPE>() & CL_DEVICE_TYPE_CPU) != 0);
    stringstream source;

    // On non-CPU devices, several lines (blocks) are processed per work-group, targeting
    // roughly 256 work-items per group.
    int blocksPerGroup = (loopRequired ? 1 : max(1, 256/zsize));
    int stage = 0;
    int L = zsize;
    int m = 1;

    // Factor zsize, generating an appropriate block of code for each factor.

    while (L > 1) {
        int input = stage%2;
        int output = 1-input;
        int radix;
        if (L%7 == 0)
            radix = 7;
        else if (L%5 == 0)
            radix = 5;
        else if (L%4 == 0)
            radix = 4;
        else if (L%3 == 0)
            radix = 3;
        else if (L%2 == 0)
            radix = 2;
        else
            throw OpenMMException("Illegal size for FFT: "+context.intToString(zsize));
        source<<"{\n";
        L = L/radix;
        source<<"// Pass "<<(stage+1)<<" (radix "<<radix<<")\n";

        // Each work-item handles one butterfly of the pass: L*m butterflies per line.
        if (loopRequired) {
            source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n";
            source<<"int base = i;\n";
        }
        else {
            source<<"if (get_local_id(0) < "<<(blocksPerGroup*L*m)<<") {\n";
            source<<"int block = get_local_id(0)/"<<(L*m)<<";\n";
            source<<"int i = get_local_id(0)-block*"<<(L*m)<<";\n";
            source<<"int base = i+block*"<<zsize<<";\n";
        }
        source<<"int j = i/"<<m<<";\n";

        // The butterfly inputs are L*m apart. Output k is written to index
        // base+((radix-1)*j+k)*m, which equals (i%m)+(radix*j+k)*m, and is multiplied by
        // w[j*k*zsize/(radix*L)].
        if (radix == 7) {
            // Winograd-style 7-point DFT (cosine terms b1-b4, sine terms b5-b8).
            source<<"real2 c0 = data"<<input<<"[base];\n";
            source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
            source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n";
            source<<"real2 c3 = data"<<input<<"[base+"<<(3*L*m)<<"];\n";
            source<<"real2 c4 = data"<<input<<"[base+"<<(4*L*m)<<"];\n";
            source<<"real2 c5 = data"<<input<<"[base+"<<(5*L*m)<<"];\n";
            source<<"real2 c6 = data"<<input<<"[base+"<<(6*L*m)<<"];\n";
            source<<"real2 d0 = c1+c6;\n";
            source<<"real2 d1 = c1-c6;\n";
            source<<"real2 d2 = c2+c5;\n";
            source<<"real2 d3 = c2-c5;\n";
            source<<"real2 d4 = c4+c3;\n";
            source<<"real2 d5 = c4-c3;\n";
            source<<"real2 d6 = d2+d0;\n";
            source<<"real2 d7 = d5+d3;\n";
            source<<"real2 b0 = c0+d6+d4;\n";
            source<<"real2 b1 = "<<context.doubleToString((cos(2*M_PI/7)+cos(4*M_PI/7)+cos(6*M_PI/7))/3-1)<<"*(d6+d4);\n";
            source<<"real2 b2 = "<<context.doubleToString((2*cos(2*M_PI/7)-cos(4*M_PI/7)-cos(6*M_PI/7))/3)<<"*(d0-d4);\n";
            source<<"real2 b3 = "<<context.doubleToString((cos(2*M_PI/7)-2*cos(4*M_PI/7)+cos(6*M_PI/7))/3)<<"*(d4-d2);\n";
            source<<"real2 b4 = "<<context.doubleToString((cos(2*M_PI/7)+cos(4*M_PI/7)-2*cos(6*M_PI/7))/3)<<"*(d2-d0);\n";
            source<<"real2 b5 = -sign*"<<context.doubleToString((sin(2*M_PI/7)+sin(4*M_PI/7)-sin(6*M_PI/7))/3)<<"*(d7+d1);\n";
            source<<"real2 b6 = -sign*"<<context.doubleToString((2*sin(2*M_PI/7)-sin(4*M_PI/7)+sin(6*M_PI/7))/3)<<"*(d1-d5);\n";
            source<<"real2 b7 = -sign*"<<context.doubleToString((sin(2*M_PI/7)-2*sin(4*M_PI/7)-sin(6*M_PI/7))/3)<<"*(d5-d3);\n";
            source<<"real2 b8 = -sign*"<<context.doubleToString((sin(2*M_PI/7)+sin(4*M_PI/7)+2*sin(6*M_PI/7))/3)<<"*(d3-d1);\n";
            source<<"real2 t0 = b0+b1;\n";
            source<<"real2 t1 = b2+b3;\n";
            source<<"real2 t2 = b4-b3;\n";
            source<<"real2 t3 = -b2-b4;\n";
            source<<"real2 t4 = b6+b7;\n";
            source<<"real2 t5 = b8-b7;\n";
            source<<"real2 t6 = -b8-b6;\n";
            source<<"real2 t7 = t0+t1;\n";
            source<<"real2 t8 = t0+t2;\n";
            source<<"real2 t9 = t0+t3;\n";
            source<<"real2 t10 = (real2) (t4.y+b5.y, -(t4.x+b5.x));\n";
            source<<"real2 t11 = (real2) (t5.y+b5.y, -(t5.x+b5.x));\n";
            source<<"real2 t12 = (real2) (t6.y+b5.y, -(t6.x+b5.x));\n";
            source<<"data"<<output<<"[base+6*j*"<<m<<"] = b0;\n";
            source<<"data"<<output<<"[base+(6*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(7*L)<<"], t7-t10);\n";
            source<<"data"<<output<<"[base+(6*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(7*L)<<"], t9-t12);\n";
            source<<"data"<<output<<"[base+(6*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(7*L)<<"], t8+t11);\n";
            source<<"data"<<output<<"[base+(6*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*zsize)<<"/"<<(7*L)<<"], t8-t11);\n";
            source<<"data"<<output<<"[base+(6*j+5)*"<<m<<"] = multiplyComplex(w[j*"<<(5*zsize)<<"/"<<(7*L)<<"], t9+t12);\n";
            source<<"data"<<output<<"[base+(6*j+6)*"<<m<<"] = multiplyComplex(w[j*"<<(6*zsize)<<"/"<<(7*L)<<"], t7+t10);\n";
        }
        else if (radix == 5) {
            source<<"real2 c0 = data"<<input<<"[base];\n";
            source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
            source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n";
            source<<"real2 c3 = data"<<input<<"[base+"<<(3*L*m)<<"];\n";
            source<<"real2 c4 = data"<<input<<"[base+"<<(4*L*m)<<"];\n";
            source<<"real2 d0 = c1+c4;\n";
            source<<"real2 d1 = c2+c3;\n";
            source<<"real2 d2 = "<<context.doubleToString(sin(0.4*M_PI))<<"*(c1-c4);\n";
            source<<"real2 d3 = "<<context.doubleToString(sin(0.4*M_PI))<<"*(c2-c3);\n";
            source<<"real2 d4 = d0+d1;\n";
            source<<"real2 d5 = "<<context.doubleToString(0.25*sqrt(5.0))<<"*(d0-d1);\n";
            source<<"real2 d6 = c0-0.25f*d4;\n";
            source<<"real2 d7 = d6+d5;\n";
            source<<"real2 d8 = d6-d5;\n";
            string coeff = context.doubleToString(sin(0.2*M_PI)/sin(0.4*M_PI));
            source<<"real2 d9 = sign*(real2) (d2.y+"<<coeff<<"*d3.y, -d2.x-"<<coeff<<"*d3.x);\n";
            source<<"real2 d10 = sign*(real2) ("<<coeff<<"*d2.y-d3.y, d3.x-"<<coeff<<"*d2.x);\n";
            source<<"data"<<output<<"[base+4*j*"<<m<<"] = c0+d4;\n";
            source<<"data"<<output<<"[base+(4*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(5*L)<<"], d7+d9);\n";
            source<<"data"<<output<<"[base+(4*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(5*L)<<"], d8+d10);\n";
            source<<"data"<<output<<"[base+(4*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(5*L)<<"], d8-d10);\n";
            source<<"data"<<output<<"[base+(4*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*zsize)<<"/"<<(5*L)<<"], d7-d9);\n";
        }
        else if (radix == 4) {
            source<<"real2 c0 = data"<<input<<"[base];\n";
            source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
            source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n";
            source<<"real2 c3 = data"<<input<<"[base+"<<(3*L*m)<<"];\n";
            source<<"real2 d0 = c0+c2;\n";
            source<<"real2 d1 = c0-c2;\n";
            source<<"real2 d2 = c1+c3;\n";
            source<<"real2 d3 = sign*(real2) (c1.y-c3.y, c3.x-c1.x);\n";
            source<<"data"<<output<<"[base+3*j*"<<m<<"] = d0+d2;\n";
            source<<"data"<<output<<"[base+(3*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(4*L)<<"], d1+d3);\n";
            source<<"data"<<output<<"[base+(3*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(4*L)<<"], d0-d2);\n";
            source<<"data"<<output<<"[base+(3*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(4*L)<<"], d1-d3);\n";
        }
        else if (radix == 3) {
            source<<"real2 c0 = data"<<input<<"[base];\n";
            source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
            source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n";
            source<<"real2 d0 = c1+c2;\n";
            source<<"real2 d1 = c0-0.5f*d0;\n";
            source<<"real2 d2 = sign*"<<context.doubleToString(sin(M_PI/3.0))<<"*(real2) (c1.y-c2.y, c2.x-c1.x);\n";
            source<<"data"<<output<<"[base+2*j*"<<m<<"] = c0+d0;\n";
            source<<"data"<<output<<"[base+(2*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(3*L)<<"], d1+d2);\n";
            source<<"data"<<output<<"[base+(2*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(3*L)<<"], d1-d2);\n";
        }
        else if (radix == 2) {
            source<<"real2 c0 = data"<<input<<"[base];\n";
            source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
            source<<"data"<<output<<"[base+j*"<<m<<"] = c0+c1;\n";
            source<<"data"<<output<<"[base+(j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(2*L)<<"], c0-c1);\n";
        }
        source<<"}\n";
        m = m*radix;

        // All work-items must finish writing this pass before the next pass reads it.
        source<<"barrier(CLK_LOCAL_MEM_FENCE);\n";
        source<<"}\n";
        ++stage;
    }

    // Create the kernel.
    // The result of the last pass is in data(stage%2); write it out with element (x, y, z)
    // going to out[y*ZSIZE*XSIZE + z*XSIZE + x].

    if (loopRequired) {
        source<<"for (int z = get_local_id(0); z < ZSIZE; z += get_local_size(0))\n";
        source<<"out[y*(ZSIZE*XSIZE)+z*XSIZE+x] = data"<<(stage%2)<<"[z];\n";
    }
    else {
        source<<"if (index < XSIZE*YSIZE)\n";
        source<<"out[y*(ZSIZE*XSIZE)+(get_local_id(0)%ZSIZE)*XSIZE+x] = data"<<(stage%2)<<"[get_local_id(0)];\n";
    }
    source<<"barrier(CLK_GLOBAL_MEM_FENCE);";
    map<string, string> replacements;
    replacements["XSIZE"] = context.intToString(xsize);
    replacements["YSIZE"] = context.intToString(ysize);
    replacements["ZSIZE"] = context.intToString(zsize);
    replacements["BLOCKS_PER_GROUP"] = context.intToString(blocksPerGroup);
    replacements["M_PI"] = context.doubleToString(M_PI);
    replacements["COMPUTE_FFT"] = source.str();
    replacements["LOOP_REQUIRED"] = (loopRequired ? "1" : "0");
    cl::Program program = context.createProgram(context.replaceStrings(OpenCLKernelSources::fft, replacements));
    cl::Kernel kernel(program, "execFFT");

    // Arguments 3-5 are local-memory buffers, each holding blocksPerGroup*zsize complex
    // values at the context's precision.
    int bufferSize = blocksPerGroup*zsize*(context.getUseDoublePrecision() ? sizeof(mm_double2) : sizeof(mm_float2));
    kernel.setArg(3, bufferSize, NULL);
    kernel.setArg(4, bufferSize, NULL);
    kernel.setArg(5, bufferSize, NULL);
    threads = (loopRequired ? 1 : blocksPerGroup*zsize);
    return kernel;
}