Commit 8316aba0 authored by Peter Eastman's avatar Peter Eastman
Browse files

Completed OpenCL FFT

parent 7901ce20
...@@ -36,29 +36,49 @@ using namespace OpenMM; ...@@ -36,29 +36,49 @@ using namespace OpenMM;
using namespace std; using namespace std;
OpenCLFFT3D::OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize) : context(context), xsize(xsize), ysize(ysize), zsize(zsize) { OpenCLFFT3D::OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize) : context(context), xsize(xsize), ysize(ysize), zsize(zsize) {
xkernel = createKernel(xsize); xkernel = createKernel(xsize, ysize, zsize, ysize*zsize, zsize, 1);
ykernel = createKernel(ysize); ykernel = createKernel(ysize, zsize, xsize, zsize, 1, ysize*zsize);
zkernel = createKernel(zsize); zkernel = createKernel(zsize, xsize, ysize, 1, ysize*zsize, zsize);
}
OpenCLFFT3D::~OpenCLFFT3D() {
} }
void OpenCLFFT3D::execFFT(OpenCLArray<mm_float2>& data, bool forward) { void OpenCLFFT3D::execFFT(OpenCLArray<mm_float2>& data, bool forward) {
xkernel.setArg<cl::Buffer>(0, data.getDeviceBuffer()); xkernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
xkernel.setArg<cl_float>(1, forward ? 1.0f : -1.0f); xkernel.setArg<cl_float>(1, forward ? 1.0f : -1.0f);
context.executeKernel(xkernel, xsize, xsize); context.executeKernel(xkernel, xsize*ysize*zsize, xsize);
ykernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
ykernel.setArg<cl_float>(1, forward ? 1.0f : -1.0f);
context.executeKernel(ykernel, xsize*ysize*zsize, ysize);
zkernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
zkernel.setArg<cl_float>(1, forward ? 1.0f : -1.0f);
context.executeKernel(zkernel, xsize*ysize*zsize, zsize);
} }
cl::Kernel OpenCLFFT3D::createKernel(int size) { int OpenCLFFT3D::findLegalDimension(int minimum) {
map<string, string> replacements; if (minimum < 1)
replacements["SIZE"] = OpenCLExpressionUtilities::intToString(size); return 1;
replacements["M_PI"] = OpenCLExpressionUtilities::doubleToString(M_PI); while (true) {
// Attempt to factor the current value.
int unfactored = minimum;
for (int factor = 2; factor < 6; factor++) {
while (unfactored > 1 && unfactored%factor == 0)
unfactored /= factor;
}
if (unfactored == 1)
return minimum;
minimum++;
}
}
cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int xmult, int ymult, int zmult) {
stringstream source; stringstream source;
int unfactored = size; int unfactored = xsize;
int stage = 0; int stage = 0;
int L = size; int L = xsize;
int m = 1; int m = 1;
// Factor xsize, generating an appropriate block of code for each factor.
while (unfactored > 1) { while (unfactored > 1) {
int input = stage%2; int input = stage%2;
int output = 1-input; int output = 1-input;
...@@ -85,10 +105,10 @@ cl::Kernel OpenCLFFT3D::createKernel(int size) { ...@@ -85,10 +105,10 @@ cl::Kernel OpenCLFFT3D::createKernel(int size) {
source<<"float2 d9 = sign*(float2) (d2.y+"<<coeff<<"*d3.y, -d2.x-"<<coeff<<"*d3.x);\n"; source<<"float2 d9 = sign*(float2) (d2.y+"<<coeff<<"*d3.y, -d2.x-"<<coeff<<"*d3.x);\n";
source<<"float2 d10 = sign*(float2) ("<<coeff<<"*d2.y-d3.y, d3.x-"<<coeff<<"*d2.x);\n"; source<<"float2 d10 = sign*(float2) ("<<coeff<<"*d2.y-d3.y, d3.x-"<<coeff<<"*d2.x);\n";
source<<"data"<<output<<"[i+4*j*"<<m<<"] = c0+d4;\n"; source<<"data"<<output<<"[i+4*j*"<<m<<"] = c0+d4;\n";
source<<"data"<<output<<"[i+(4*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<size<<"/"<<(5*L)<<"], d7+d9);\n"; source<<"data"<<output<<"[i+(4*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<xsize<<"/"<<(5*L)<<"], d7+d9);\n";
source<<"data"<<output<<"[i+(4*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*size)<<"/"<<(5*L)<<"], d8+d10);\n"; source<<"data"<<output<<"[i+(4*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*xsize)<<"/"<<(5*L)<<"], d8+d10);\n";
source<<"data"<<output<<"[i+(4*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*size)<<"/"<<(5*L)<<"], d8-d10);\n"; source<<"data"<<output<<"[i+(4*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*xsize)<<"/"<<(5*L)<<"], d8-d10);\n";
source<<"data"<<output<<"[i+(4*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*size)<<"/"<<(5*L)<<"], d7-d9);\n"; source<<"data"<<output<<"[i+(4*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*xsize)<<"/"<<(5*L)<<"], d7-d9);\n";
m = m*5; m = m*5;
unfactored /= 5; unfactored /= 5;
} }
...@@ -105,9 +125,9 @@ cl::Kernel OpenCLFFT3D::createKernel(int size) { ...@@ -105,9 +125,9 @@ cl::Kernel OpenCLFFT3D::createKernel(int size) {
source<<"float2 d2 = c1+c3;\n"; source<<"float2 d2 = c1+c3;\n";
source<<"float2 d3 = sign*(float2) (c1.y-c3.y, c3.x-c1.x);\n"; source<<"float2 d3 = sign*(float2) (c1.y-c3.y, c3.x-c1.x);\n";
source<<"data"<<output<<"[i+3*j*"<<m<<"] = d0+d2;\n"; source<<"data"<<output<<"[i+3*j*"<<m<<"] = d0+d2;\n";
source<<"data"<<output<<"[i+(3*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<size<<"/"<<(4*L)<<"], d1+d3);\n"; source<<"data"<<output<<"[i+(3*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<xsize<<"/"<<(4*L)<<"], d1+d3);\n";
source<<"data"<<output<<"[i+(3*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*size)<<"/"<<(4*L)<<"], d0-d2);\n"; source<<"data"<<output<<"[i+(3*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*xsize)<<"/"<<(4*L)<<"], d0-d2);\n";
source<<"data"<<output<<"[i+(3*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*size)<<"/"<<(4*L)<<"], d1-d3);\n"; source<<"data"<<output<<"[i+(3*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*xsize)<<"/"<<(4*L)<<"], d1-d3);\n";
m = m*4; m = m*4;
unfactored /= 4; unfactored /= 4;
} }
...@@ -122,8 +142,8 @@ cl::Kernel OpenCLFFT3D::createKernel(int size) { ...@@ -122,8 +142,8 @@ cl::Kernel OpenCLFFT3D::createKernel(int size) {
source<<"float2 d1 = c0-0.5f*d0;\n"; source<<"float2 d1 = c0-0.5f*d0;\n";
source<<"float2 d2 = sign*"<<OpenCLExpressionUtilities::doubleToString(sin(M_PI/3.0))<<"*(float2) (c1.y-c2.y, c2.x-c1.x);\n"; source<<"float2 d2 = sign*"<<OpenCLExpressionUtilities::doubleToString(sin(M_PI/3.0))<<"*(float2) (c1.y-c2.y, c2.x-c1.x);\n";
source<<"data"<<output<<"[i+2*j*"<<m<<"] = c0+d0;\n"; source<<"data"<<output<<"[i+2*j*"<<m<<"] = c0+d0;\n";
source<<"data"<<output<<"[i+(2*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<size<<"/"<<(3*L)<<"], d1+d2);\n"; source<<"data"<<output<<"[i+(2*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<xsize<<"/"<<(3*L)<<"], d1+d2);\n";
source<<"data"<<output<<"[i+(2*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*size)<<"/"<<(3*L)<<"], d1-d2);\n"; source<<"data"<<output<<"[i+(2*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*xsize)<<"/"<<(3*L)<<"], d1-d2);\n";
m = m*3; m = m*3;
unfactored /= 3; unfactored /= 3;
} }
...@@ -134,22 +154,33 @@ cl::Kernel OpenCLFFT3D::createKernel(int size) { ...@@ -134,22 +154,33 @@ cl::Kernel OpenCLFFT3D::createKernel(int size) {
source<<"float2 c0 = data"<<input<<"[i];\n"; source<<"float2 c0 = data"<<input<<"[i];\n";
source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
source<<"data"<<output<<"[i+j*"<<m<<"] = c0+c1;\n"; source<<"data"<<output<<"[i+j*"<<m<<"] = c0+c1;\n";
source<<"data"<<output<<"[i+(j+1)*"<<m<<"] = multiplyComplex(w[j*"<<size<<"/"<<(2*L)<<"], c0-c1);\n"; source<<"data"<<output<<"[i+(j+1)*"<<m<<"] = multiplyComplex(w[j*"<<xsize<<"/"<<(2*L)<<"], c0-c1);\n";
m = m*2; m = m*2;
unfactored /= 2; unfactored /= 2;
} }
else else
throw OpenMMException("Illegal size for FFT: "+OpenCLExpressionUtilities::intToString(size)); throw OpenMMException("Illegal size for FFT: "+OpenCLExpressionUtilities::intToString(xsize));
source<<"barrier(CLK_LOCAL_MEM_FENCE);\n"; source<<"barrier(CLK_LOCAL_MEM_FENCE);\n";
source<<"}\n"; source<<"}\n";
++stage; ++stage;
} }
source<<"matrix[i] = data"<<(stage%2)<<"[i];";
// Create the kernel.
source<<"matrix[element] = data"<<(stage%2)<<"[i];";
map<string, string> replacements;
replacements["XSIZE"] = OpenCLExpressionUtilities::intToString(xsize);
replacements["YSIZE"] = OpenCLExpressionUtilities::intToString(ysize);
replacements["ZSIZE"] = OpenCLExpressionUtilities::intToString(zsize);
replacements["XMULT"] = OpenCLExpressionUtilities::intToString(xmult);
replacements["YMULT"] = OpenCLExpressionUtilities::intToString(ymult);
replacements["ZMULT"] = OpenCLExpressionUtilities::intToString(zmult);
replacements["M_PI"] = OpenCLExpressionUtilities::doubleToString(M_PI);
replacements["COMPUTE_FFT"] = source.str(); replacements["COMPUTE_FFT"] = source.str();
cl::Program program = context.createProgram(context.loadSourceFromFile("fft.cl", replacements)); cl::Program program = context.createProgram(context.loadSourceFromFile("fft.cl", replacements));
cl::Kernel kernel(program, "execFFT"); cl::Kernel kernel(program, "execFFT");
kernel.setArg(2, size*sizeof(mm_float2), NULL); kernel.setArg(2, xsize*sizeof(mm_float2), NULL);
kernel.setArg(3, size*sizeof(mm_float2), NULL); kernel.setArg(3, xsize*sizeof(mm_float2), NULL);
kernel.setArg(4, size*sizeof(mm_float2), NULL); kernel.setArg(4, xsize*sizeof(mm_float2), NULL);
return kernel; return kernel;
} }
...@@ -31,13 +31,54 @@ ...@@ -31,13 +31,54 @@
namespace OpenMM { namespace OpenMM {
/**
* This class performs three dimensional Fast Fourier Transforms. It is based on the
* mixed radix algorithm described in
* <p>
* Takahashi, D. and Kanada, Y., "High-Performance Radix-2, 3 and 5 Parallel 1-D Complex
* FFT Algorithms for Distributed-Memory Parallel Computers." Journal of Supercomputing,
* 15, 207–228 (2000).
* <p>
* This class places certain restrictions on the allowed dimensions of the grid. First,
* the size of each dimension may have no prime factors other than 2, 3, and 5. You
* can call findLegalDimension() to determine the smallest size that satisfies this
* requirement and is greater than or equal to a specified minimum size. Second, the size
* of each dimension must be small enough to compute each 1D transform entirely in local
* memory with one work unit per data point. This will vary between platforms, but is
* typically at least 512.
* <p>
* Note that this class performs an unnormalized transform. That means that if you perform
* a forward transform followed immediately by an inverse transform, the effect is to
* multiply every value of the original data set by the total number of data points.
*/
class OpenCLFFT3D { class OpenCLFFT3D {
public: public:
/**
* Create an OpenCLFFT3D object for performing transforms of a particular size.
*
* @param context the context in which to perform calculations
* @param xsize the first dimension of the data sets on which FFTs will be performed
* @param ysize the second dimension of the data sets on which FFTs will be performed
* @param zsize the third dimension of the data sets on which FFTs will be performed
*/
OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize); OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize);
~OpenCLFFT3D(); /**
* Perform an in-place Fourier transform.
*
* @param data the data to transform, ordered such that data[x*ysize*zsize + y*zsize + z] contains element (x, y, z)
* @param forward true to perform a forward transform, false to perform an inverse transform
*/
void execFFT(OpenCLArray<mm_float2>& data, bool forward = true); void execFFT(OpenCLArray<mm_float2>& data, bool forward = true);
/**
* Get the smallest legal size for a dimension of the grid (that is, a size with no prime
* factors other than 2, 3, and 5).
*
* @param minimum the minimum size the return value must be greater than or equal to
*/
static int findLegalDimension(int minimum);
private: private:
cl::Kernel createKernel(int size); cl::Kernel createKernel(int xsize, int ysize, int zsize, int xmult, int ymult, int zmult);
int xsize, ysize, zsize; int xsize, ysize, zsize;
OpenCLContext& context; OpenCLContext& context;
cl::Kernel xkernel, ykernel, zkernel; cl::Kernel xkernel, ykernel, zkernel;
......
...@@ -8,8 +8,15 @@ float2 multiplyComplex(float2 c1, float2 c2) { ...@@ -8,8 +8,15 @@ float2 multiplyComplex(float2 c1, float2 c2) {
__kernel void execFFT(__global float2* matrix, float sign, __local float2* w, __local float2* data0, __local float2* data1) { __kernel void execFFT(__global float2* matrix, float sign, __local float2* w, __local float2* data0, __local float2* data1) {
const int i = get_local_id(0); const int i = get_local_id(0);
w[i] = (float2) (cos(-sign*i*2*M_PI/SIZE), sin(-sign*i*2*M_PI/SIZE)); w[i] = (float2) (cos(-sign*i*2*M_PI/XSIZE), sin(-sign*i*2*M_PI/XSIZE));
data0[i] = matrix[i]; int index = get_group_id(0);
while (index < YSIZE*ZSIZE) {
int z = index/YSIZE;
int y = index-z*YSIZE;
int element = i*XMULT+y*YMULT+z*ZMULT;
data0[i] = matrix[element];
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
COMPUTE_FFT COMPUTE_FFT
index += get_num_groups(0);
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment