Commit 79b85ad7 authored by peastman's avatar peastman
Browse files

Merge pull request #870 from peastman/r2c

Implemented an optimized real-to-complex FFT
parents e531310f 803fe433
......@@ -83,13 +83,14 @@ public:
*/
static int findLegalDimension(int minimum);
private:
cl::Kernel createKernel(int xsize, int ysize, int zsize, int& threads, int axis, bool forward);
cl::Kernel createKernel(int xsize, int ysize, int zsize, int& threads, int axis, bool forward, bool inputIsReal);
int xsize, ysize, zsize;
int xthreads, ythreads, zthreads;
bool realToComplex;
bool packRealAsComplex;
OpenCLContext& context;
cl::Kernel xkernel, ykernel, zkernel;
cl::Kernel invxkernel, invykernel, invzkernel;
cl::Kernel packForwardKernel, unpackForwardKernel, packBackwardKernel, unpackBackwardKernel;
};
} // namespace OpenMM
......
......@@ -36,19 +36,97 @@ using namespace OpenMM;
using namespace std;
// Builds the kernels needed to run 3D FFTs on an xsize x ysize x zsize grid.
// When realToComplex is set and at least one axis has even length, the real
// grid is packed into a complex grid half as large along that axis
// (packRealAsComplex); the FFT kernels are then created for the packed
// dimensions and separate pack/unpack kernels convert to and from the full grid.
OpenCLFFT3D::OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize, bool realToComplex) :
context(context), xsize(xsize), ysize(ysize), zsize(zsize), realToComplex(realToComplex) {
zkernel = createKernel(xsize, ysize, zsize, zthreads, 0, true);
xkernel = createKernel(ysize, zsize, xsize, xthreads, 1, true);
ykernel = createKernel(zsize, xsize, ysize, ythreads, 2, true);
invzkernel = createKernel(xsize, ysize, zsize, zthreads, 0, false);
invxkernel = createKernel(ysize, zsize, xsize, xthreads, 1, false);
invykernel = createKernel(zsize, xsize, ysize, ythreads, 2, false);
context(context), xsize(xsize), ysize(ysize), zsize(zsize) {
// Packing stays disabled unless an even-length axis is found below.
packRealAsComplex = false;
// Dimensions the FFT kernels are built for; one of them is halved when packing is used.
int packedXSize = xsize;
int packedYSize = ysize;
int packedZSize = zsize;
if (realToComplex) {
// If any axis size is even, we can pack the real values into a complex grid that is only half as large.
// Look for an appropriate axis.
packRealAsComplex = true;
// packedAxis is 0, 1 or 2 for x, y or z; bufferSize is the halved length of that axis.
// Both are assigned only when an even axis exists, and are read only in that case.
int packedAxis, bufferSize;
// Axes are tried in x, y, z order; the first even one is used.
if (xsize%2 == 0) {
packedAxis = 0;
packedXSize /= 2;
bufferSize = packedXSize;
}
else if (ysize%2 == 0) {
packedAxis = 1;
packedYSize /= 2;
bufferSize = packedYSize;
}
else if (zsize%2 == 0) {
packedAxis = 2;
packedZSize /= 2;
bufferSize = packedZSize;
}
else
// All three sizes are odd, so the grid cannot be packed.
packRealAsComplex = false;
if (packRealAsComplex) {
// Build the kernels for packing and unpacking the data.
// The full and packed sizes and the packed axis are passed to the kernel source as preprocessor defines.
map<string, string> defines;
defines["XSIZE"] = context.intToString(xsize);
defines["YSIZE"] = context.intToString(ysize);
defines["ZSIZE"] = context.intToString(zsize);
defines["PACKED_AXIS"] = context.intToString(packedAxis);
defines["PACKED_XSIZE"] = context.intToString(packedXSize);
defines["PACKED_YSIZE"] = context.intToString(packedYSize);
defines["PACKED_ZSIZE"] = context.intToString(packedZSize);
cl::Program program = context.createProgram(OpenCLKernelSources::fftR2C, defines);
packForwardKernel = cl::Kernel(program, "packForwardData");
unpackForwardKernel = cl::Kernel(program, "unpackForwardData");
// Argument 2 of unpackForwardData and packBackwardData is a __local real2 array that holds
// one phase factor per index of the halved axis. Passing a byte count with a NULL pointer
// requests that much local memory; the element size follows the context's precision.
unpackForwardKernel.setArg(2, bufferSize*(context.getUseDoublePrecision() ? sizeof(mm_double2) : sizeof(mm_float2)), NULL);
packBackwardKernel = cl::Kernel(program, "packBackwardData");
packBackwardKernel.setArg(2, bufferSize*(context.getUseDoublePrecision() ? sizeof(mm_double2) : sizeof(mm_float2)), NULL);
unpackBackwardKernel = cl::Kernel(program, "unpackBackwardData");
}
}
// The FFT kernels only handle real data themselves when the transform is real-to-complex
// and the packing path is not available.
bool inputIsReal = (realToComplex && !packRealAsComplex);
// The three size arguments are rotated for each kernel (x,y,z / y,z,x / z,x,y) - presumably so the
// axis being transformed is always the last one; confirm against createKernel. The integer after
// the thread count is the axis index and the following flag selects forward or inverse.
zkernel = createKernel(packedXSize, packedYSize, packedZSize, zthreads, 0, true, inputIsReal);
xkernel = createKernel(packedYSize, packedZSize, packedXSize, xthreads, 1, true, inputIsReal);
ykernel = createKernel(packedZSize, packedXSize, packedYSize, ythreads, 2, true, inputIsReal);
invzkernel = createKernel(packedXSize, packedYSize, packedZSize, zthreads, 0, false, inputIsReal);
invxkernel = createKernel(packedYSize, packedZSize, packedXSize, xthreads, 1, false, inputIsReal);
invykernel = createKernel(packedZSize, packedXSize, packedYSize, ythreads, 2, false, inputIsReal);
}
void OpenCLFFT3D::execFFT(OpenCLArray& in, OpenCLArray& out, bool forward) {
cl::Kernel kernel1 = (forward ? zkernel : invzkernel);
cl::Kernel kernel2 = (forward ? xkernel : invxkernel);
cl::Kernel kernel3 = (forward ? ykernel : invykernel);
if (packRealAsComplex) {
cl::Kernel packKernel = (forward ? packForwardKernel : packBackwardKernel);
cl::Kernel unpackKernel = (forward ? unpackForwardKernel : unpackBackwardKernel);
int gridSize = xsize*ysize*zsize/2;
// Pack the data into a half sized grid.
packKernel.setArg<cl::Buffer>(0, in.getDeviceBuffer());
packKernel.setArg<cl::Buffer>(1, out.getDeviceBuffer());
context.executeKernel(packKernel, gridSize);
// Perform the FFT.
kernel1.setArg<cl::Buffer>(0, out.getDeviceBuffer());
kernel1.setArg<cl::Buffer>(1, in.getDeviceBuffer());
context.executeKernel(kernel1, gridSize, zthreads);
kernel2.setArg<cl::Buffer>(0, in.getDeviceBuffer());
kernel2.setArg<cl::Buffer>(1, out.getDeviceBuffer());
context.executeKernel(kernel2, gridSize, xthreads);
kernel3.setArg<cl::Buffer>(0, out.getDeviceBuffer());
kernel3.setArg<cl::Buffer>(1, in.getDeviceBuffer());
context.executeKernel(kernel3, gridSize, ythreads);
// Unpack the data.
unpackKernel.setArg<cl::Buffer>(0, in.getDeviceBuffer());
unpackKernel.setArg<cl::Buffer>(1, out.getDeviceBuffer());
context.executeKernel(unpackKernel, gridSize);
}
else {
kernel1.setArg<cl::Buffer>(0, in.getDeviceBuffer());
kernel1.setArg<cl::Buffer>(1, out.getDeviceBuffer());
context.executeKernel(kernel1, xsize*ysize*zsize, zthreads);
......@@ -58,6 +136,7 @@ void OpenCLFFT3D::execFFT(OpenCLArray& in, OpenCLArray& out, bool forward) {
kernel3.setArg<cl::Buffer>(0, in.getDeviceBuffer());
kernel3.setArg<cl::Buffer>(1, out.getDeviceBuffer());
context.executeKernel(kernel3, xsize*ysize*zsize, ythreads);
}
}
int OpenCLFFT3D::findLegalDimension(int minimum) {
......@@ -77,7 +156,7 @@ int OpenCLFFT3D::findLegalDimension(int minimum) {
}
}
cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int& threads, int axis, bool forward) {
cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int& threads, int axis, bool forward, bool inputIsReal) {
int maxThreads = std::min(256, (int) context.getDevice().getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>());
bool isCPU = context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU;
while (true) {
......@@ -230,7 +309,7 @@ cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int& threa
// Create the kernel.
bool outputIsReal = (realToComplex && axis == 2 && !forward);
bool outputIsReal = (inputIsReal && axis == 2 && !forward);
string outputSuffix = (outputIsReal ? ".x" : "");
if (loopRequired) {
source<<"for (int z = get_local_id(0); z < ZSIZE; z += get_local_size(0))\n";
......@@ -249,9 +328,9 @@ cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int& threa
replacements["COMPUTE_FFT"] = source.str();
replacements["LOOP_REQUIRED"] = (loopRequired ? "1" : "0");
replacements["SIGN"] = (forward ? "1" : "-1");
replacements["INPUT_TYPE"] = (realToComplex && axis == 0 && forward ? "real" : "real2");
replacements["INPUT_TYPE"] = (inputIsReal && axis == 0 && forward ? "real" : "real2");
replacements["OUTPUT_TYPE"] = (outputIsReal ? "real" : "real2");
replacements["INPUT_IS_REAL"] = (realToComplex && axis == 0 && forward ? "1" : "0");
replacements["INPUT_IS_REAL"] = (inputIsReal && axis == 0 && forward ? "1" : "0");
cl::Program program = context.createProgram(context.replaceStrings(OpenCLKernelSources::fft, replacements));
cl::Kernel kernel(program, "execFFT");
threads = (isCPU ? 1 : blocksPerGroup*zsize);
......
/**
 * Pack a real XSIZE x YSIZE x ZSIZE grid into a complex grid with half as many points.
 * Every output element combines two input values that are neighbours along PACKED_AXIS:
 * the even-indexed one is stored in .x and the odd-indexed one in .y.
 */
__kernel void packForwardData(__global const real* restrict in, __global real2* restrict out) {
    const int planeSize = PACKED_YSIZE*PACKED_ZSIZE;
    const int numPacked = PACKED_XSIZE*planeSize;
    for (int packedIndex = get_global_id(0); packedIndex < numPacked; packedIndex += get_global_size(0)) {
        // Split the flat index into coordinates on the packed grid.
        const int px = packedIndex/planeSize;
        const int inPlane = packedIndex-px*planeSize;
        const int py = inPlane/PACKED_ZSIZE;
        const int pz = inPlane-py*PACKED_ZSIZE;

        // Find the even member of the pair on the full grid and the offset to its odd partner.
#if PACKED_AXIS == 0
        const int evenIndex = 2*px*YSIZE*ZSIZE+py*ZSIZE+pz;
        const int pairStride = YSIZE*ZSIZE;
#elif PACKED_AXIS == 1
        const int evenIndex = px*YSIZE*ZSIZE+2*py*ZSIZE+pz;
        const int pairStride = ZSIZE;
#else
        const int evenIndex = px*YSIZE*ZSIZE+py*ZSIZE+2*pz;
        const int pairStride = 1;
#endif
        real2 packed;
        packed.x = in[evenIndex];
        packed.y = in[evenIndex+pairStride];
        out[packedIndex] = packed;
    }
}
/**
* Split the transformed data back into a full sized, symmetric grid.
*
* in  - the transform of the packed grid (PACKED_XSIZE x PACKED_YSIZE x PACKED_ZSIZE complex values).
* out - the full grid, indexed with the unpacked strides YSIZE*ZSIZE and ZSIZE.
* w   - local scratch for the phase factors; it must hold at least as many elements as the
*       packed length of the packed axis.
*/
__kernel void unpackForwardData(__global const real2* restrict in, __global real2* restrict out, __local real2* restrict w) {
// Compute the phase factors.
// w[i] = (sin(2*pi*i/N), cos(2*pi*i/N)), where N is the full (unpacked) length of the packed axis.
// The work-items of a group share the job of filling the table.
#if PACKED_AXIS == 0
for (int i = get_local_id(0); i < PACKED_XSIZE; i += get_local_size(0))
w[i] = (real2) (sin(i*2*M_PI/XSIZE), cos(i*2*M_PI/XSIZE));
#elif PACKED_AXIS == 1
for (int i = get_local_id(0); i < PACKED_YSIZE; i += get_local_size(0))
w[i] = (real2) (sin(i*2*M_PI/YSIZE), cos(i*2*M_PI/YSIZE));
#else
for (int i = get_local_id(0); i < PACKED_ZSIZE; i += get_local_size(0))
w[i] = (real2) (sin(i*2*M_PI/ZSIZE), cos(i*2*M_PI/ZSIZE));
#endif
// Every work-item in the group must see the complete table before reading from it.
barrier(CLK_LOCAL_MEM_FENCE);
// Transform the data.
const int gridSize = PACKED_XSIZE*PACKED_YSIZE*PACKED_ZSIZE;
for (int index = get_global_id(0); index < gridSize; index += get_global_size(0)) {
// Split the flat index into coordinates on the packed grid.
int x = index/(PACKED_YSIZE*PACKED_ZSIZE);
int remainder = index-x*(PACKED_YSIZE*PACKED_ZSIZE);
int y = remainder/PACKED_ZSIZE;
int z = remainder-y*PACKED_ZSIZE;
// Coordinates of the mirrored point on the packed grid: (size - coordinate), with 0 mapping to 0.
int xp = (x == 0 ? 0 : PACKED_XSIZE-x);
int yp = (y == 0 ? 0 : PACKED_YSIZE-y);
int zp = (z == 0 ? 0 : PACKED_ZSIZE-z);
// z1 is the packed transform at this point, z2 the value at its mirror image.
real2 z1 = in[x*PACKED_YSIZE*PACKED_ZSIZE+y*PACKED_ZSIZE+z];
real2 z2 = in[xp*PACKED_YSIZE*PACKED_ZSIZE+yp*PACKED_ZSIZE+zp];
// Select the phase factor for this point's index along the packed axis.
#if PACKED_AXIS == 0
real2 wfac = w[x];
#elif PACKED_AXIS == 1
real2 wfac = w[y];
#else
real2 wfac = w[z];
#endif
// Combine the two packed values with the phase factor to get the full-grid value at (x, y, z).
real2 output = (real2) ((z1.x+z2.x - wfac.x*(z1.x-z2.x) + wfac.y*(z1.y+z2.y))/2, (z1.y-z2.y - wfac.y*(z1.x-z2.x) - wfac.x*(z1.y+z2.y))/2);
out[x*YSIZE*ZSIZE+y*ZSIZE+z] = output;
// From here on xp, yp, zp are the mirrored coordinates on the full grid.
xp = (x == 0 ? 0 : XSIZE-x);
yp = (y == 0 ? 0 : YSIZE-y);
zp = (z == 0 ? 0 : ZSIZE-z);
// When the packed-axis coordinate is 0, write the element whose packed-axis index is the halved
// length (PACKED_*SIZE) of that axis; no other iteration writes that element.
// The "else" after #endif pairs with whichever "if" the preprocessor keeps.
#if PACKED_AXIS == 0
if (x == 0)
out[PACKED_XSIZE*YSIZE*ZSIZE+yp*ZSIZE+zp] = (real2) ((z1.x-z1.y+z2.x-z2.y)/2, (-z1.x-z1.y+z2.x+z2.y)/2);
#elif PACKED_AXIS == 1
if (y == 0)
out[xp*YSIZE*ZSIZE+PACKED_YSIZE*ZSIZE+zp] = (real2) ((z1.x-z1.y+z2.x-z2.y)/2, (-z1.x-z1.y+z2.x+z2.y)/2);
#else
if (z == 0)
out[xp*YSIZE*ZSIZE+yp*ZSIZE+PACKED_ZSIZE] = (real2) ((z1.x-z1.y+z2.x-z2.y)/2, (-z1.x-z1.y+z2.x+z2.y)/2);
#endif
else
// Otherwise the mirrored point receives the complex conjugate of the value just written.
out[xp*YSIZE*ZSIZE+yp*ZSIZE+zp] = (real2) (output.x, -output.y);
}
}
/**
* Repack the symmetric complex grid into one half as large in preparation for doing an inverse complex-to-real transform.
*
* in  - the full grid, indexed with the unpacked strides YSIZE*ZSIZE and ZSIZE.
* out - the packed grid (PACKED_XSIZE x PACKED_YSIZE x PACKED_ZSIZE complex values).
* w   - local scratch for the phase factors; it must hold at least as many elements as the
*       packed length of the packed axis.
*/
__kernel void packBackwardData(__global const real2* restrict in, __global real2* restrict out, __local real2* restrict w) {
// Compute the phase factors.
// w[i] = (cos(2*pi*i/N), sin(2*pi*i/N)), where N is the full (unpacked) length of the packed axis.
// The work-items of a group share the job of filling the table.
#if PACKED_AXIS == 0
for (int i = get_local_id(0); i < PACKED_XSIZE; i += get_local_size(0))
w[i] = (real2) (cos(i*2*M_PI/XSIZE), sin(i*2*M_PI/XSIZE));
#elif PACKED_AXIS == 1
for (int i = get_local_id(0); i < PACKED_YSIZE; i += get_local_size(0))
w[i] = (real2) (cos(i*2*M_PI/YSIZE), sin(i*2*M_PI/YSIZE));
#else
for (int i = get_local_id(0); i < PACKED_ZSIZE; i += get_local_size(0))
w[i] = (real2) (cos(i*2*M_PI/ZSIZE), sin(i*2*M_PI/ZSIZE));
#endif
// Every work-item in the group must see the complete table before reading from it.
barrier(CLK_LOCAL_MEM_FENCE);
// Transform the data.
const int gridSize = PACKED_XSIZE*PACKED_YSIZE*PACKED_ZSIZE;
for (int index = get_global_id(0); index < gridSize; index += get_global_size(0)) {
// Split the flat index into coordinates on the packed grid.
int x = index/(PACKED_YSIZE*PACKED_ZSIZE);
int remainder = index-x*(PACKED_YSIZE*PACKED_ZSIZE);
int y = remainder/PACKED_ZSIZE;
int z = remainder-y*PACKED_ZSIZE;
// Mirrored coordinates (size - coordinate, with 0 mapping to 0). Only the two belonging to the
// unpacked axes are used below; the packed axis uses (PACKED_*SIZE - coordinate) directly.
int xp = (x == 0 ? 0 : PACKED_XSIZE-x);
int yp = (y == 0 ? 0 : PACKED_YSIZE-y);
int zp = (z == 0 ? 0 : PACKED_ZSIZE-z);
// z1 is the full-grid value at (x, y, z).
real2 z1 = in[x*YSIZE*ZSIZE+y*ZSIZE+z];
// z2 is the full-grid value at the mirrored point. Along the packed axis there is no special case
// for coordinate 0, so that case reads the element at index PACKED_*SIZE of the full grid.
#if PACKED_AXIS == 0
real2 wfac = w[x];
real2 z2 = in[(PACKED_XSIZE-x)*YSIZE*ZSIZE+yp*ZSIZE+zp];
#elif PACKED_AXIS == 1
real2 wfac = w[y];
real2 z2 = in[xp*YSIZE*ZSIZE+(PACKED_YSIZE-y)*ZSIZE+zp];
#else
real2 wfac = w[z];
real2 z2 = in[xp*YSIZE*ZSIZE+yp*ZSIZE+(PACKED_ZSIZE-z)];
#endif
// even = (z1 + conj(z2))/2 and odd = (z1 - conj(z2))/2.
real2 even = (real2) ((z1.x+z2.x)/2, (z1.y-z2.y)/2);
real2 odd = (real2) ((z1.x-z2.x)/2, (z1.y+z2.y)/2);
// Complex multiplication of odd by the phase factor.
odd = (real2) (odd.x*wfac.x-odd.y*wfac.y, odd.y*wfac.x+odd.x*wfac.y);
// The packed value is even + i*odd.
out[x*PACKED_YSIZE*PACKED_ZSIZE+y*PACKED_ZSIZE+z] = (real2) (even.x-odd.y, even.y+odd.x);
}
}
/**
 * Expand the half sized complex result of an inverse transform into the full sized real grid.
 * Every packed element supplies two real outputs that are neighbours along PACKED_AXIS:
 * .x goes to the even position and .y to the odd one, both multiplied by 2.
 */
__kernel void unpackBackwardData(__global const real2* restrict in, __global real* restrict out) {
    const int planeSize = PACKED_YSIZE*PACKED_ZSIZE;
    const int numPacked = PACKED_XSIZE*planeSize;
    for (int packedIndex = get_global_id(0); packedIndex < numPacked; packedIndex += get_global_size(0)) {
        // Split the flat index into coordinates on the packed grid.
        const int px = packedIndex/planeSize;
        const int inPlane = packedIndex-px*planeSize;
        const int py = inPlane/PACKED_ZSIZE;
        const int pz = inPlane-py*PACKED_ZSIZE;

        // Find the even member of the pair on the full grid and the offset to its odd partner.
#if PACKED_AXIS == 0
        const int evenIndex = 2*px*YSIZE*ZSIZE+py*ZSIZE+pz;
        const int pairStride = YSIZE*ZSIZE;
#elif PACKED_AXIS == 1
        const int evenIndex = px*YSIZE*ZSIZE+2*py*ZSIZE+pz;
        const int pairStride = ZSIZE;
#else
        const int evenIndex = px*YSIZE*ZSIZE+py*ZSIZE+2*pz;
        const int pairStride = 1;
#endif
        const real2 scaled = 2*in[packedIndex];
        out[evenIndex] = scaled.x;
        out[evenIndex+pairStride] = scaled.y;
    }
}
......@@ -51,7 +51,7 @@ using namespace std;
static OpenCLPlatform platform;
template <class Real2>
void testTransform(bool realToComplex) {
void testTransform(bool realToComplex, int xsize, int ysize, int zsize) {
System system;
system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false");
......@@ -59,7 +59,6 @@ void testTransform(bool realToComplex) {
context.initialize();
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
int xsize = 28, ysize = 25, zsize = 30;
vector<Real2> original(xsize*ysize*zsize);
vector<t_complex> reference(original.size());
for (int i = 0; i < (int) original.size(); i++) {
......@@ -109,12 +108,18 @@ int main(int argc, char* argv[]) {
if (argc > 1)
platform.setPropertyDefaultValue("OpenCLPrecision", string(argv[1]));
if (platform.getPropertyDefaultValue("OpenCLPrecision") == "double") {
testTransform<mm_double2>(false);
testTransform<mm_double2>(true);
testTransform<mm_double2>(false, 28, 25, 30);
testTransform<mm_double2>(true, 28, 25, 25);
testTransform<mm_double2>(true, 25, 28, 25);
testTransform<mm_double2>(true, 25, 25, 28);
testTransform<mm_double2>(true, 21, 25, 27);
}
else {
testTransform<mm_float2>(false);
testTransform<mm_float2>(true);
testTransform<mm_float2>(false, 28, 25, 30);
testTransform<mm_float2>(true, 28, 25, 25);
testTransform<mm_float2>(true, 25, 28, 25);
testTransform<mm_float2>(true, 25, 25, 28);
testTransform<mm_float2>(true, 21, 25, 27);
}
}
catch(const exception& e) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment