/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2009-2011 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLFFT3D.h"
#include "OpenCLExpressionUtilities.h"
29
#include "OpenCLKernelSources.h"
Peter Eastman's avatar
Peter Eastman committed
30
31
32
33
34
35
36
37
38
#include "../src/SimTKUtilities/SimTKOpenMMRealType.h"
#include <map>
#include <sstream>
#include <string>

using namespace OpenMM;
using namespace std;

/**
 * Create an FFT object for a grid of xsize*ysize*zsize points.
 *
 * Three kernels are built, one per pass of execFFT().  createKernel() generates
 * the transform along the dimension given as its third argument, so the grid
 * sizes are passed in rotated order for each successive kernel.
 *
 * @param context  the context in which the kernels are compiled and executed
 * @param xsize    the grid size along x
 * @param ysize    the grid size along y
 * @param zsize    the grid size along z
 * @throws OpenMMException (from createKernel()) if a size has a prime factor
 *         other than 2, 3, or 5
 */
OpenCLFFT3D::OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize)
        : context(context), xsize(xsize), ysize(ysize), zsize(zsize) {
    zkernel = createKernel(xsize, ysize, zsize); // transform of length zsize
    xkernel = createKernel(ysize, zsize, xsize); // transform of length xsize
    ykernel = createKernel(zsize, xsize, ysize); // transform of length ysize
}

Peter Eastman's avatar
Peter Eastman committed
44
/**
 * Run the 3D transform as three successive kernel launches (z, then x, then y).
 *
 * The z pass is given `in` as argument 0 and `out` as argument 1, the x pass is
 * given them swapped, and the y pass is given them in the original order again.
 * NOTE(review): assuming argument 0 is the kernel's source and argument 1 its
 * destination, `in` is overwritten as scratch space and the final result ends
 * up in `out` -- confirm against the kernel template.
 *
 * @param in       buffer holding the data to transform
 * @param out      buffer that receives the result of the last pass
 * @param forward  selects the sign passed to the kernels: 1.0f if true,
 *                 -1.0f if false
 */
void OpenCLFFT3D::execFFT(OpenCLArray<mm_float2>& in, OpenCLArray<mm_float2>& out, bool forward) {
    const bool isCpu = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    const cl_float sign = (forward ? 1.0f : -1.0f);
    const int numPoints = xsize*ysize*zsize;

    // Each kernel is generated and compiled separately, so each one can report a
    // different CL_KERNEL_WORK_GROUP_SIZE.  Query every kernel for its own limit
    // rather than applying the x kernel's limit to all three.  On CPU devices a
    // single work-item per group is used, as before.
    // NOTE(review): on non-CPU devices createKernel() emits code that uses one
    // work-item per element of the transformed dimension; if a limit below is
    // smaller than that dimension the result is presumably incomplete -- confirm.
    const int zLimit = (isCpu ? 1 : (int) zkernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()));
    const int xLimit = (isCpu ? 1 : (int) xkernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()));
    const int yLimit = (isCpu ? 1 : (int) ykernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()));

    // Pass 1: transform along z.
    zkernel.setArg<cl::Buffer>(0, in.getDeviceBuffer());
    zkernel.setArg<cl::Buffer>(1, out.getDeviceBuffer());
    zkernel.setArg<cl_float>(2, sign);
    context.executeKernel(zkernel, numPoints, std::min(zsize, zLimit));

    // Pass 2: transform along x, with the two buffers swapped.
    xkernel.setArg<cl::Buffer>(0, out.getDeviceBuffer());
    xkernel.setArg<cl::Buffer>(1, in.getDeviceBuffer());
    xkernel.setArg<cl_float>(2, sign);
    context.executeKernel(xkernel, numPoints, std::min(xsize, xLimit));

    // Pass 3: transform along y, with the buffers back in their original roles.
    ykernel.setArg<cl::Buffer>(0, in.getDeviceBuffer());
    ykernel.setArg<cl::Buffer>(1, out.getDeviceBuffer());
    ykernel.setArg<cl_float>(2, sign);
    context.executeKernel(ykernel, numPoints, std::min(ysize, yLimit));
}

Peter Eastman's avatar
Peter Eastman committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/**
 * Find the smallest size that is at least `minimum` and whose only prime
 * factors are 2, 3, and 5 (the sizes createKernel() can generate code for).
 *
 * @param minimum  the smallest acceptable size
 * @return the first such size that is >= minimum, or 1 if minimum is less than 1
 */
int OpenCLFFT3D::findLegalDimension(int minimum) {
    if (minimum < 1)
        return 1;
    static const int primes[] = {2, 3, 5};
    for (int candidate = minimum; ; ++candidate) {
        // Strip every supported prime factor; a legal size reduces to exactly 1.
        int remainder = candidate;
        for (int p = 0; p < 3; ++p) {
            while (remainder%primes[p] == 0)
                remainder /= primes[p];
        }
        if (remainder == 1)
            return candidate;
    }
}

Peter Eastman's avatar
Peter Eastman committed
79
80
cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize) {
    bool loopRequired = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
Peter Eastman's avatar
Peter Eastman committed
81
82
    stringstream source;
    int stage = 0;
Peter Eastman's avatar
Peter Eastman committed
83
    int L = zsize;
Peter Eastman's avatar
Peter Eastman committed
84
    int m = 1;
Peter Eastman's avatar
Peter Eastman committed
85

Peter Eastman's avatar
Peter Eastman committed
86
    // Factor zsize, generating an appropriate block of code for each factor.
Peter Eastman's avatar
Peter Eastman committed
87

Peter Eastman's avatar
Peter Eastman committed
88
    while (L > 1) {
Peter Eastman's avatar
Peter Eastman committed
89
90
91
        int input = stage%2;
        int output = 1-input;
        source<<"{\n";
Peter Eastman's avatar
Peter Eastman committed
92
        if (L%5 == 0) {
Peter Eastman's avatar
Peter Eastman committed
93
94
            L = L/5;
            source<<"// Pass "<<(stage+1)<<" (radix 5)\n";
Peter Eastman's avatar
Peter Eastman committed
95
96
97
98
99
100
            if (loopRequired)
                source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n";
            else {
                source<<"if (get_local_id(0) < "<<(L*m)<<") {\n";
                source<<"int i = get_local_id(0);\n";
            }
Peter Eastman's avatar
Peter Eastman committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
            source<<"int j = i/"<<m<<";\n";
            source<<"float2 c0 = data"<<input<<"[i];\n";
            source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
            source<<"float2 c2 = data"<<input<<"[i+"<<(2*L*m)<<"];\n";
            source<<"float2 c3 = data"<<input<<"[i+"<<(3*L*m)<<"];\n";
            source<<"float2 c4 = data"<<input<<"[i+"<<(4*L*m)<<"];\n";
            source<<"float2 d0 = c1+c4;\n";
            source<<"float2 d1 = c2+c3;\n";
            source<<"float2 d2 = "<<OpenCLExpressionUtilities::doubleToString(sin(0.4*M_PI))<<"*(c1-c4);\n";
            source<<"float2 d3 = "<<OpenCLExpressionUtilities::doubleToString(sin(0.4*M_PI))<<"*(c2-c3);\n";
            source<<"float2 d4 = d0+d1;\n";
            source<<"float2 d5 = "<<OpenCLExpressionUtilities::doubleToString(0.25*sqrt(5.0))<<"*(d0-d1);\n";
            source<<"float2 d6 = c0-0.25f*d4;\n";
            source<<"float2 d7 = d6+d5;\n";
            source<<"float2 d8 = d6-d5;\n";
            string coeff = OpenCLExpressionUtilities::doubleToString(sin(0.2*M_PI)/sin(0.4*M_PI));
            source<<"float2 d9 = sign*(float2) (d2.y+"<<coeff<<"*d3.y, -d2.x-"<<coeff<<"*d3.x);\n";
            source<<"float2 d10 = sign*(float2) ("<<coeff<<"*d2.y-d3.y, d3.x-"<<coeff<<"*d2.x);\n";
            source<<"data"<<output<<"[i+4*j*"<<m<<"] = c0+d4;\n";
Peter Eastman's avatar
Peter Eastman committed
120
121
122
123
            source<<"data"<<output<<"[i+(4*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(5*L)<<"], d7+d9);\n";
            source<<"data"<<output<<"[i+(4*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(5*L)<<"], d8+d10);\n";
            source<<"data"<<output<<"[i+(4*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(5*L)<<"], d8-d10);\n";
            source<<"data"<<output<<"[i+(4*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*zsize)<<"/"<<(5*L)<<"], d7-d9);\n";
124
            source<<"}\n";
Peter Eastman's avatar
Peter Eastman committed
125
126
            m = m*5;
        }
Peter Eastman's avatar
Peter Eastman committed
127
        else if (L%4 == 0) {
Peter Eastman's avatar
Peter Eastman committed
128
129
            L = L/4;
            source<<"// Pass "<<(stage+1)<<" (radix 4)\n";
Peter Eastman's avatar
Peter Eastman committed
130
131
132
133
134
135
            if (loopRequired)
                source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n";
            else {
                source<<"if (get_local_id(0) < "<<(L*m)<<") {\n";
                source<<"int i = get_local_id(0);\n";
            }
Peter Eastman's avatar
Peter Eastman committed
136
137
138
139
140
141
142
143
144
145
            source<<"int j = i/"<<m<<";\n";
            source<<"float2 c0 = data"<<input<<"[i];\n";
            source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
            source<<"float2 c2 = data"<<input<<"[i+"<<(2*L*m)<<"];\n";
            source<<"float2 c3 = data"<<input<<"[i+"<<(3*L*m)<<"];\n";
            source<<"float2 d0 = c0+c2;\n";
            source<<"float2 d1 = c0-c2;\n";
            source<<"float2 d2 = c1+c3;\n";
            source<<"float2 d3 = sign*(float2) (c1.y-c3.y, c3.x-c1.x);\n";
            source<<"data"<<output<<"[i+3*j*"<<m<<"] = d0+d2;\n";
Peter Eastman's avatar
Peter Eastman committed
146
147
148
            source<<"data"<<output<<"[i+(3*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(4*L)<<"], d1+d3);\n";
            source<<"data"<<output<<"[i+(3*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(4*L)<<"], d0-d2);\n";
            source<<"data"<<output<<"[i+(3*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(4*L)<<"], d1-d3);\n";
149
            source<<"}\n";
Peter Eastman's avatar
Peter Eastman committed
150
151
            m = m*4;
        }
Peter Eastman's avatar
Peter Eastman committed
152
        else if (L%3 == 0) {
Peter Eastman's avatar
Peter Eastman committed
153
154
            L = L/3;
            source<<"// Pass "<<(stage+1)<<" (radix 3)\n";
Peter Eastman's avatar
Peter Eastman committed
155
156
157
158
159
160
            if (loopRequired)
                source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n";
            else {
                source<<"if (get_local_id(0) < "<<(L*m)<<") {\n";
                source<<"int i = get_local_id(0);\n";
            }
Peter Eastman's avatar
Peter Eastman committed
161
162
163
164
165
166
167
168
            source<<"int j = i/"<<m<<";\n";
            source<<"float2 c0 = data"<<input<<"[i];\n";
            source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
            source<<"float2 c2 = data"<<input<<"[i+"<<(2*L*m)<<"];\n";
            source<<"float2 d0 = c1+c2;\n";
            source<<"float2 d1 = c0-0.5f*d0;\n";
            source<<"float2 d2 = sign*"<<OpenCLExpressionUtilities::doubleToString(sin(M_PI/3.0))<<"*(float2) (c1.y-c2.y, c2.x-c1.x);\n";
            source<<"data"<<output<<"[i+2*j*"<<m<<"] = c0+d0;\n";
Peter Eastman's avatar
Peter Eastman committed
169
170
            source<<"data"<<output<<"[i+(2*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(3*L)<<"], d1+d2);\n";
            source<<"data"<<output<<"[i+(2*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(3*L)<<"], d1-d2);\n";
171
            source<<"}\n";
Peter Eastman's avatar
Peter Eastman committed
172
173
            m = m*3;
        }
Peter Eastman's avatar
Peter Eastman committed
174
        else if (L%2 == 0) {
Peter Eastman's avatar
Peter Eastman committed
175
176
            L = L/2;
            source<<"// Pass "<<(stage+1)<<" (radix 2)\n";
Peter Eastman's avatar
Peter Eastman committed
177
178
179
180
181
182
            if (loopRequired)
                source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n";
            else {
                source<<"if (get_local_id(0) < "<<(L*m)<<") {\n";
                source<<"int i = get_local_id(0);\n";
            }
Peter Eastman's avatar
Peter Eastman committed
183
184
185
186
            source<<"int j = i/"<<m<<";\n";
            source<<"float2 c0 = data"<<input<<"[i];\n";
            source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
            source<<"data"<<output<<"[i+j*"<<m<<"] = c0+c1;\n";
Peter Eastman's avatar
Peter Eastman committed
187
            source<<"data"<<output<<"[i+(j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(2*L)<<"], c0-c1);\n";
188
            source<<"}\n";
Peter Eastman's avatar
Peter Eastman committed
189
190
191
            m = m*2;
        }
        else
Peter Eastman's avatar
Peter Eastman committed
192
            throw OpenMMException("Illegal size for FFT: "+OpenCLExpressionUtilities::intToString(zsize));
Peter Eastman's avatar
Peter Eastman committed
193
194
195
196
        source<<"barrier(CLK_LOCAL_MEM_FENCE);\n";
        source<<"}\n";
        ++stage;
    }
Peter Eastman's avatar
Peter Eastman committed
197
198
199

    // Create the kernel.

Peter Eastman's avatar
Peter Eastman committed
200
201
202
203
204
205
    if (loopRequired) {
        source<<"for (int z = get_local_id(0); z < ZSIZE; z += get_local_size(0))\n";
        source<<"out[y*(ZSIZE*XSIZE)+z*XSIZE+x] = data"<<(stage%2)<<"[z];\n";
    }
    else
        source<<"out[y*(ZSIZE*XSIZE)+get_local_id(0)*XSIZE+x] = data"<<(stage%2)<<"[get_local_id(0)];\n";
206
    source<<"barrier(CLK_GLOBAL_MEM_FENCE);";
Peter Eastman's avatar
Peter Eastman committed
207
208
209
210
211
    map<string, string> replacements;
    replacements["XSIZE"] = OpenCLExpressionUtilities::intToString(xsize);
    replacements["YSIZE"] = OpenCLExpressionUtilities::intToString(ysize);
    replacements["ZSIZE"] = OpenCLExpressionUtilities::intToString(zsize);
    replacements["M_PI"] = OpenCLExpressionUtilities::doubleToString(M_PI);
Peter Eastman's avatar
Peter Eastman committed
212
    replacements["COMPUTE_FFT"] = source.str();
Peter Eastman's avatar
Peter Eastman committed
213
214
    if (loopRequired)
        replacements["LOOP_REQUIRED"] = "1";
215
    cl::Program program = context.createProgram(context.replaceStrings(OpenCLKernelSources::fft, replacements));
Peter Eastman's avatar
Peter Eastman committed
216
    cl::Kernel kernel(program, "execFFT");
Peter Eastman's avatar
Peter Eastman committed
217
218
219
    kernel.setArg(3, zsize*sizeof(mm_float2), NULL);
    kernel.setArg(4, zsize*sizeof(mm_float2), NULL);
    kernel.setArg(5, zsize*sizeof(mm_float2), NULL);
Peter Eastman's avatar
Peter Eastman committed
220
221
    return kernel;
}