Commit c08d8a53 authored by Peter Eastman's avatar Peter Eastman
Browse files

Optimizations to FFT

parent e85c741b
...@@ -36,24 +36,27 @@ using namespace OpenMM; ...@@ -36,24 +36,27 @@ using namespace OpenMM;
using namespace std; using namespace std;
OpenCLFFT3D::OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize) : context(context), xsize(xsize), ysize(ysize), zsize(zsize) { OpenCLFFT3D::OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize) : context(context), xsize(xsize), ysize(ysize), zsize(zsize) {
xkernel = createKernel(xsize, ysize, zsize, ysize*zsize, zsize, 1); zkernel = createKernel(xsize, ysize, zsize);
ykernel = createKernel(ysize, zsize, xsize, zsize, 1, ysize*zsize); xkernel = createKernel(ysize, zsize, xsize);
zkernel = createKernel(zsize, xsize, ysize, 1, ysize*zsize, zsize); ykernel = createKernel(zsize, xsize, ysize);
} }
void OpenCLFFT3D::execFFT(OpenCLArray<mm_float2>& data, bool forward) { void OpenCLFFT3D::execFFT(OpenCLArray<mm_float2>& in, OpenCLArray<mm_float2>& out, bool forward) {
int maxSize = xkernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice()); int maxSize = xkernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(context.getDevice());
if (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU) if (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU)
maxSize = 1; maxSize = 1;
xkernel.setArg<cl::Buffer>(0, data.getDeviceBuffer()); zkernel.setArg<cl::Buffer>(0, in.getDeviceBuffer());
xkernel.setArg<cl_float>(1, forward ? 1.0f : -1.0f); zkernel.setArg<cl::Buffer>(1, out.getDeviceBuffer());
zkernel.setArg<cl_float>(2, forward ? 1.0f : -1.0f);
context.executeKernel(zkernel, xsize*ysize*zsize, min(zsize, (int) maxSize));
xkernel.setArg<cl::Buffer>(0, out.getDeviceBuffer());
xkernel.setArg<cl::Buffer>(1, in.getDeviceBuffer());
xkernel.setArg<cl_float>(2, forward ? 1.0f : -1.0f);
context.executeKernel(xkernel, xsize*ysize*zsize, min(xsize, (int) maxSize)); context.executeKernel(xkernel, xsize*ysize*zsize, min(xsize, (int) maxSize));
ykernel.setArg<cl::Buffer>(0, data.getDeviceBuffer()); ykernel.setArg<cl::Buffer>(0, in.getDeviceBuffer());
ykernel.setArg<cl_float>(1, forward ? 1.0f : -1.0f); ykernel.setArg<cl::Buffer>(1, out.getDeviceBuffer());
ykernel.setArg<cl_float>(2, forward ? 1.0f : -1.0f);
context.executeKernel(ykernel, xsize*ysize*zsize, min(ysize, (int) maxSize)); context.executeKernel(ykernel, xsize*ysize*zsize, min(ysize, (int) maxSize));
zkernel.setArg<cl::Buffer>(0, data.getDeviceBuffer());
zkernel.setArg<cl_float>(1, forward ? 1.0f : -1.0f);
context.executeKernel(zkernel, xsize*ysize*zsize, min(zsize, (int) maxSize));
} }
int OpenCLFFT3D::findLegalDimension(int minimum) { int OpenCLFFT3D::findLegalDimension(int minimum) {
...@@ -73,23 +76,28 @@ int OpenCLFFT3D::findLegalDimension(int minimum) { ...@@ -73,23 +76,28 @@ int OpenCLFFT3D::findLegalDimension(int minimum) {
} }
} }
cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int xmult, int ymult, int zmult) { cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize) {
bool loopRequired = (context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
stringstream source; stringstream source;
int unfactored = xsize;
int stage = 0; int stage = 0;
int L = xsize; int L = zsize;
int m = 1; int m = 1;
// Factor xsize, generating an appropriate block of code for each factor. // Factor zsize, generating an appropriate block of code for each factor.
while (unfactored > 1) { while (L > 1) {
int input = stage%2; int input = stage%2;
int output = 1-input; int output = 1-input;
source<<"{\n"; source<<"{\n";
if (unfactored%5 == 0) { if (L%5 == 0) {
L = L/5; L = L/5;
source<<"// Pass "<<(stage+1)<<" (radix 5)\n"; source<<"// Pass "<<(stage+1)<<" (radix 5)\n";
source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n"; if (loopRequired)
source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n";
else {
source<<"if (get_local_id(0) < "<<(L*m)<<") {\n";
source<<"int i = get_local_id(0);\n";
}
source<<"int j = i/"<<m<<";\n"; source<<"int j = i/"<<m<<";\n";
source<<"float2 c0 = data"<<input<<"[i];\n"; source<<"float2 c0 = data"<<input<<"[i];\n";
source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -109,18 +117,22 @@ cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int xmult, ...@@ -109,18 +117,22 @@ cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int xmult,
source<<"float2 d9 = sign*(float2) (d2.y+"<<coeff<<"*d3.y, -d2.x-"<<coeff<<"*d3.x);\n"; source<<"float2 d9 = sign*(float2) (d2.y+"<<coeff<<"*d3.y, -d2.x-"<<coeff<<"*d3.x);\n";
source<<"float2 d10 = sign*(float2) ("<<coeff<<"*d2.y-d3.y, d3.x-"<<coeff<<"*d2.x);\n"; source<<"float2 d10 = sign*(float2) ("<<coeff<<"*d2.y-d3.y, d3.x-"<<coeff<<"*d2.x);\n";
source<<"data"<<output<<"[i+4*j*"<<m<<"] = c0+d4;\n"; source<<"data"<<output<<"[i+4*j*"<<m<<"] = c0+d4;\n";
source<<"data"<<output<<"[i+(4*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<xsize<<"/"<<(5*L)<<"], d7+d9);\n"; source<<"data"<<output<<"[i+(4*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(5*L)<<"], d7+d9);\n";
source<<"data"<<output<<"[i+(4*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*xsize)<<"/"<<(5*L)<<"], d8+d10);\n"; source<<"data"<<output<<"[i+(4*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(5*L)<<"], d8+d10);\n";
source<<"data"<<output<<"[i+(4*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*xsize)<<"/"<<(5*L)<<"], d8-d10);\n"; source<<"data"<<output<<"[i+(4*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(5*L)<<"], d8-d10);\n";
source<<"data"<<output<<"[i+(4*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*xsize)<<"/"<<(5*L)<<"], d7-d9);\n"; source<<"data"<<output<<"[i+(4*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*zsize)<<"/"<<(5*L)<<"], d7-d9);\n";
source<<"}\n"; source<<"}\n";
m = m*5; m = m*5;
unfactored /= 5;
} }
else if (unfactored%4 == 0) { else if (L%4 == 0) {
L = L/4; L = L/4;
source<<"// Pass "<<(stage+1)<<" (radix 4)\n"; source<<"// Pass "<<(stage+1)<<" (radix 4)\n";
source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n"; if (loopRequired)
source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n";
else {
source<<"if (get_local_id(0) < "<<(L*m)<<") {\n";
source<<"int i = get_local_id(0);\n";
}
source<<"int j = i/"<<m<<";\n"; source<<"int j = i/"<<m<<";\n";
source<<"float2 c0 = data"<<input<<"[i];\n"; source<<"float2 c0 = data"<<input<<"[i];\n";
source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -131,17 +143,21 @@ cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int xmult, ...@@ -131,17 +143,21 @@ cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int xmult,
source<<"float2 d2 = c1+c3;\n"; source<<"float2 d2 = c1+c3;\n";
source<<"float2 d3 = sign*(float2) (c1.y-c3.y, c3.x-c1.x);\n"; source<<"float2 d3 = sign*(float2) (c1.y-c3.y, c3.x-c1.x);\n";
source<<"data"<<output<<"[i+3*j*"<<m<<"] = d0+d2;\n"; source<<"data"<<output<<"[i+3*j*"<<m<<"] = d0+d2;\n";
source<<"data"<<output<<"[i+(3*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<xsize<<"/"<<(4*L)<<"], d1+d3);\n"; source<<"data"<<output<<"[i+(3*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(4*L)<<"], d1+d3);\n";
source<<"data"<<output<<"[i+(3*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*xsize)<<"/"<<(4*L)<<"], d0-d2);\n"; source<<"data"<<output<<"[i+(3*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(4*L)<<"], d0-d2);\n";
source<<"data"<<output<<"[i+(3*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*xsize)<<"/"<<(4*L)<<"], d1-d3);\n"; source<<"data"<<output<<"[i+(3*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(4*L)<<"], d1-d3);\n";
source<<"}\n"; source<<"}\n";
m = m*4; m = m*4;
unfactored /= 4;
} }
else if (unfactored%3 == 0) { else if (L%3 == 0) {
L = L/3; L = L/3;
source<<"// Pass "<<(stage+1)<<" (radix 3)\n"; source<<"// Pass "<<(stage+1)<<" (radix 3)\n";
source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n"; if (loopRequired)
source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n";
else {
source<<"if (get_local_id(0) < "<<(L*m)<<") {\n";
source<<"int i = get_local_id(0);\n";
}
source<<"int j = i/"<<m<<";\n"; source<<"int j = i/"<<m<<";\n";
source<<"float2 c0 = data"<<input<<"[i];\n"; source<<"float2 c0 = data"<<input<<"[i];\n";
source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
...@@ -150,27 +166,30 @@ cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int xmult, ...@@ -150,27 +166,30 @@ cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int xmult,
source<<"float2 d1 = c0-0.5f*d0;\n"; source<<"float2 d1 = c0-0.5f*d0;\n";
source<<"float2 d2 = sign*"<<OpenCLExpressionUtilities::doubleToString(sin(M_PI/3.0))<<"*(float2) (c1.y-c2.y, c2.x-c1.x);\n"; source<<"float2 d2 = sign*"<<OpenCLExpressionUtilities::doubleToString(sin(M_PI/3.0))<<"*(float2) (c1.y-c2.y, c2.x-c1.x);\n";
source<<"data"<<output<<"[i+2*j*"<<m<<"] = c0+d0;\n"; source<<"data"<<output<<"[i+2*j*"<<m<<"] = c0+d0;\n";
source<<"data"<<output<<"[i+(2*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<xsize<<"/"<<(3*L)<<"], d1+d2);\n"; source<<"data"<<output<<"[i+(2*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(3*L)<<"], d1+d2);\n";
source<<"data"<<output<<"[i+(2*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*xsize)<<"/"<<(3*L)<<"], d1-d2);\n"; source<<"data"<<output<<"[i+(2*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(3*L)<<"], d1-d2);\n";
source<<"}\n"; source<<"}\n";
m = m*3; m = m*3;
unfactored /= 3;
} }
else if (unfactored%2 == 0) { else if (L%2 == 0) {
L = L/2; L = L/2;
source<<"// Pass "<<(stage+1)<<" (radix 2)\n"; source<<"// Pass "<<(stage+1)<<" (radix 2)\n";
source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n"; if (loopRequired)
source<<"for (int i = get_local_id(0); i < "<<(L*m)<<"; i += get_local_size(0)) {\n";
else {
source<<"if (get_local_id(0) < "<<(L*m)<<") {\n";
source<<"int i = get_local_id(0);\n";
}
source<<"int j = i/"<<m<<";\n"; source<<"int j = i/"<<m<<";\n";
source<<"float2 c0 = data"<<input<<"[i];\n"; source<<"float2 c0 = data"<<input<<"[i];\n";
source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n"; source<<"float2 c1 = data"<<input<<"[i+"<<(L*m)<<"];\n";
source<<"data"<<output<<"[i+j*"<<m<<"] = c0+c1;\n"; source<<"data"<<output<<"[i+j*"<<m<<"] = c0+c1;\n";
source<<"data"<<output<<"[i+(j+1)*"<<m<<"] = multiplyComplex(w[j*"<<xsize<<"/"<<(2*L)<<"], c0-c1);\n"; source<<"data"<<output<<"[i+(j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(2*L)<<"], c0-c1);\n";
source<<"}\n"; source<<"}\n";
m = m*2; m = m*2;
unfactored /= 2;
} }
else else
throw OpenMMException("Illegal size for FFT: "+OpenCLExpressionUtilities::intToString(xsize)); throw OpenMMException("Illegal size for FFT: "+OpenCLExpressionUtilities::intToString(zsize));
source<<"barrier(CLK_LOCAL_MEM_FENCE);\n"; source<<"barrier(CLK_LOCAL_MEM_FENCE);\n";
source<<"}\n"; source<<"}\n";
++stage; ++stage;
...@@ -178,22 +197,25 @@ cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int xmult, ...@@ -178,22 +197,25 @@ cl::Kernel OpenCLFFT3D::createKernel(int xsize, int ysize, int zsize, int xmult,
// Create the kernel. // Create the kernel.
source<<"for (int i = get_local_id(0); i < XSIZE; i += get_local_size(0))\n"; if (loopRequired) {
source<<"matrix[i*XMULT+y*YMULT+z*ZMULT] = data"<<(stage%2)<<"[i];\n"; source<<"for (int z = get_local_id(0); z < ZSIZE; z += get_local_size(0))\n";
source<<"out[y*(ZSIZE*XSIZE)+z*XSIZE+x] = data"<<(stage%2)<<"[z];\n";
}
else
source<<"out[y*(ZSIZE*XSIZE)+get_local_id(0)*XSIZE+x] = data"<<(stage%2)<<"[get_local_id(0)];\n";
source<<"barrier(CLK_GLOBAL_MEM_FENCE);"; source<<"barrier(CLK_GLOBAL_MEM_FENCE);";
map<string, string> replacements; map<string, string> replacements;
replacements["XSIZE"] = OpenCLExpressionUtilities::intToString(xsize); replacements["XSIZE"] = OpenCLExpressionUtilities::intToString(xsize);
replacements["YSIZE"] = OpenCLExpressionUtilities::intToString(ysize); replacements["YSIZE"] = OpenCLExpressionUtilities::intToString(ysize);
replacements["ZSIZE"] = OpenCLExpressionUtilities::intToString(zsize); replacements["ZSIZE"] = OpenCLExpressionUtilities::intToString(zsize);
replacements["XMULT"] = OpenCLExpressionUtilities::intToString(xmult);
replacements["YMULT"] = OpenCLExpressionUtilities::intToString(ymult);
replacements["ZMULT"] = OpenCLExpressionUtilities::intToString(zmult);
replacements["M_PI"] = OpenCLExpressionUtilities::doubleToString(M_PI); replacements["M_PI"] = OpenCLExpressionUtilities::doubleToString(M_PI);
replacements["COMPUTE_FFT"] = source.str(); replacements["COMPUTE_FFT"] = source.str();
if (loopRequired)
replacements["LOOP_REQUIRED"] = "1";
cl::Program program = context.createProgram(context.replaceStrings(OpenCLKernelSources::fft, replacements)); cl::Program program = context.createProgram(context.replaceStrings(OpenCLKernelSources::fft, replacements));
cl::Kernel kernel(program, "execFFT"); cl::Kernel kernel(program, "execFFT");
kernel.setArg(2, xsize*sizeof(mm_float2), NULL); kernel.setArg(3, zsize*sizeof(mm_float2), NULL);
kernel.setArg(3, xsize*sizeof(mm_float2), NULL); kernel.setArg(4, zsize*sizeof(mm_float2), NULL);
kernel.setArg(4, xsize*sizeof(mm_float2), NULL); kernel.setArg(5, zsize*sizeof(mm_float2), NULL);
return kernel; return kernel;
} }
...@@ -64,12 +64,15 @@ public: ...@@ -64,12 +64,15 @@ public:
*/ */
OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize); OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize);
/** /**
* Perform an in-place Fourier transform. * Perform a Fourier transform. The transform cannot be done in-place: the input and output
* arrays must be different. Also, the input array is used as workspace, so its contents
* are destroyed.
* *
* @param data the data to transform, ordered such that data[x*ysize*zsize + y*zsize + z] contains element (x, y, z) * @param in the data to transform, ordered such that in[x*ysize*zsize + y*zsize + z] contains element (x, y, z)
* @param out on exit, this contains the transformed data
* @param forward true to perform a forward transform, false to perform an inverse transform * @param forward true to perform a forward transform, false to perform an inverse transform
*/ */
void execFFT(OpenCLArray<mm_float2>& data, bool forward = true); void execFFT(OpenCLArray<mm_float2>& in, OpenCLArray<mm_float2>& out, bool forward = true);
/** /**
* Get the smallest legal size for a dimension of the grid (that is, a size with no prime * Get the smallest legal size for a dimension of the grid (that is, a size with no prime
* factors other than 2, 3, and 5). * factors other than 2, 3, and 5).
...@@ -78,7 +81,7 @@ public: ...@@ -78,7 +81,7 @@ public:
*/ */
static int findLegalDimension(int minimum); static int findLegalDimension(int minimum);
private: private:
cl::Kernel createKernel(int xsize, int ysize, int zsize, int xmult, int ymult, int zmult); cl::Kernel createKernel(int xsize, int ysize, int zsize);
int xsize, ysize, zsize; int xsize, ysize, zsize;
OpenCLContext& context; OpenCLContext& context;
cl::Kernel xkernel, ykernel, zkernel; cl::Kernel xkernel, ykernel, zkernel;
......
...@@ -1139,6 +1139,8 @@ OpenCLCalcNonbondedForceKernel::~OpenCLCalcNonbondedForceKernel() { ...@@ -1139,6 +1139,8 @@ OpenCLCalcNonbondedForceKernel::~OpenCLCalcNonbondedForceKernel() {
delete cosSinSums; delete cosSinSums;
if (pmeGrid != NULL) if (pmeGrid != NULL)
delete pmeGrid; delete pmeGrid;
if (pmeGrid2 != NULL)
delete pmeGrid2;
if (pmeBsplineModuliX != NULL) if (pmeBsplineModuliX != NULL)
delete pmeBsplineModuliX; delete pmeBsplineModuliX;
if (pmeBsplineModuliY != NULL) if (pmeBsplineModuliY != NULL)
...@@ -1266,6 +1268,7 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb ...@@ -1266,6 +1268,7 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
// Create required data structures. // Create required data structures.
pmeGrid = new OpenCLArray<mm_float2>(cl, gridSizeX*gridSizeY*gridSizeZ, "pmeGrid"); pmeGrid = new OpenCLArray<mm_float2>(cl, gridSizeX*gridSizeY*gridSizeZ, "pmeGrid");
pmeGrid2 = new OpenCLArray<mm_float2>(cl, gridSizeX*gridSizeY*gridSizeZ, "pmeGrid2");
pmeBsplineModuliX = new OpenCLArray<cl_float>(cl, gridSizeX, "pmeBsplineModuliX"); pmeBsplineModuliX = new OpenCLArray<cl_float>(cl, gridSizeX, "pmeBsplineModuliX");
pmeBsplineModuliY = new OpenCLArray<cl_float>(cl, gridSizeY, "pmeBsplineModuliY"); pmeBsplineModuliY = new OpenCLArray<cl_float>(cl, gridSizeY, "pmeBsplineModuliY");
pmeBsplineModuliZ = new OpenCLArray<cl_float>(cl, gridSizeZ, "pmeBsplineModuliZ"); pmeBsplineModuliZ = new OpenCLArray<cl_float>(cl, gridSizeZ, "pmeBsplineModuliZ");
...@@ -1419,7 +1422,7 @@ double OpenCLCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -1419,7 +1422,7 @@ double OpenCLCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeSpreadChargeKernel.setArg<cl::Buffer>(2, pmeAtomRange->getDeviceBuffer()); pmeSpreadChargeKernel.setArg<cl::Buffer>(2, pmeAtomRange->getDeviceBuffer());
pmeSpreadChargeKernel.setArg<cl::Buffer>(3, pmeGrid->getDeviceBuffer()); pmeSpreadChargeKernel.setArg<cl::Buffer>(3, pmeGrid->getDeviceBuffer());
pmeSpreadChargeKernel.setArg<cl::Buffer>(4, pmeBsplineTheta->getDeviceBuffer()); pmeSpreadChargeKernel.setArg<cl::Buffer>(4, pmeBsplineTheta->getDeviceBuffer());
pmeConvolutionKernel.setArg<cl::Buffer>(0, pmeGrid->getDeviceBuffer()); pmeConvolutionKernel.setArg<cl::Buffer>(0, pmeGrid2->getDeviceBuffer());
pmeConvolutionKernel.setArg<cl::Buffer>(1, cl.getEnergyBuffer().getDeviceBuffer()); pmeConvolutionKernel.setArg<cl::Buffer>(1, cl.getEnergyBuffer().getDeviceBuffer());
pmeConvolutionKernel.setArg<cl::Buffer>(2, pmeBsplineModuliX->getDeviceBuffer()); pmeConvolutionKernel.setArg<cl::Buffer>(2, pmeBsplineModuliX->getDeviceBuffer());
pmeConvolutionKernel.setArg<cl::Buffer>(3, pmeBsplineModuliY->getDeviceBuffer()); pmeConvolutionKernel.setArg<cl::Buffer>(3, pmeBsplineModuliY->getDeviceBuffer());
...@@ -1474,11 +1477,11 @@ double OpenCLCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ ...@@ -1474,11 +1477,11 @@ double OpenCLCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
else else
cl.executeKernel(pmeSpreadChargeKernel, cl.getNumAtoms()); cl.executeKernel(pmeSpreadChargeKernel, cl.getNumAtoms());
} }
fft->execFFT(*pmeGrid, true); fft->execFFT(*pmeGrid, *pmeGrid2, true);
pmeConvolutionKernel.setArg<mm_float4>(5, invBoxSize); pmeConvolutionKernel.setArg<mm_float4>(5, invBoxSize);
pmeConvolutionKernel.setArg<cl_float>(6, (float) (1.0/(M_PI*boxSize.x*boxSize.y*boxSize.z))); pmeConvolutionKernel.setArg<cl_float>(6, (float) (1.0/(M_PI*boxSize.x*boxSize.y*boxSize.z)));
cl.executeKernel(pmeConvolutionKernel, cl.getNumAtoms()); cl.executeKernel(pmeConvolutionKernel, cl.getNumAtoms());
fft->execFFT(*pmeGrid, false); fft->execFFT(*pmeGrid2, *pmeGrid, false);
pmeInterpolateForceKernel.setArg<mm_float4>(5, boxSize); pmeInterpolateForceKernel.setArg<mm_float4>(5, boxSize);
pmeInterpolateForceKernel.setArg<mm_float4>(6, invBoxSize); pmeInterpolateForceKernel.setArg<mm_float4>(6, invBoxSize);
cl.executeKernel(pmeInterpolateForceKernel, cl.getNumAtoms()); cl.executeKernel(pmeInterpolateForceKernel, cl.getNumAtoms());
......
...@@ -475,7 +475,7 @@ private: ...@@ -475,7 +475,7 @@ private:
class OpenCLCalcNonbondedForceKernel : public CalcNonbondedForceKernel { class OpenCLCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public: public:
OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcNonbondedForceKernel(name, platform), OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, System& system) : CalcNonbondedForceKernel(name, platform),
hasInitializedKernel(false), cl(cl), sigmaEpsilon(NULL), exceptionParams(NULL), exceptionIndices(NULL), cosSinSums(NULL), pmeGrid(NULL), hasInitializedKernel(false), cl(cl), sigmaEpsilon(NULL), exceptionParams(NULL), exceptionIndices(NULL), cosSinSums(NULL), pmeGrid(NULL), pmeGrid2(NULL),
pmeBsplineModuliX(NULL), pmeBsplineModuliY(NULL), pmeBsplineModuliZ(NULL), pmeBsplineTheta(NULL), pmeBsplineDtheta(NULL), pmeAtomRange(NULL), pmeBsplineModuliX(NULL), pmeBsplineModuliY(NULL), pmeBsplineModuliZ(NULL), pmeBsplineTheta(NULL), pmeBsplineDtheta(NULL), pmeAtomRange(NULL),
pmeAtomGridIndex(NULL), sort(NULL), fft(NULL) { pmeAtomGridIndex(NULL), sort(NULL), fft(NULL) {
} }
...@@ -504,6 +504,7 @@ private: ...@@ -504,6 +504,7 @@ private:
OpenCLArray<mm_int4>* exceptionIndices; OpenCLArray<mm_int4>* exceptionIndices;
OpenCLArray<mm_float2>* cosSinSums; OpenCLArray<mm_float2>* cosSinSums;
OpenCLArray<mm_float2>* pmeGrid; OpenCLArray<mm_float2>* pmeGrid;
OpenCLArray<mm_float2>* pmeGrid2;
OpenCLArray<cl_float>* pmeBsplineModuliX; OpenCLArray<cl_float>* pmeBsplineModuliX;
OpenCLArray<cl_float>* pmeBsplineModuliY; OpenCLArray<cl_float>* pmeBsplineModuliY;
OpenCLArray<cl_float>* pmeBsplineModuliZ; OpenCLArray<cl_float>* pmeBsplineModuliZ;
......
...@@ -6,15 +6,19 @@ float2 multiplyComplex(float2 c1, float2 c2) { ...@@ -6,15 +6,19 @@ float2 multiplyComplex(float2 c1, float2 c2) {
* Perform a 1D FFT on each row along one axis. * Perform a 1D FFT on each row along one axis.
*/ */
__kernel void execFFT(__global float2* matrix, float sign, __local float2* w, __local float2* data0, __local float2* data1) { __kernel void execFFT(__global float2* in, __global float2* out, float sign, __local float2* w, __local float2* data0, __local float2* data1) {
for (int i = get_local_id(0); i < XSIZE; i += get_local_size(0)) for (int i = get_local_id(0); i < ZSIZE; i += get_local_size(0))
w[i] = (float2) (cos(-sign*i*2*M_PI/XSIZE), sin(-sign*i*2*M_PI/XSIZE)); w[i] = (float2) (cos(-sign*i*2*M_PI/ZSIZE), sin(-sign*i*2*M_PI/ZSIZE));
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
for (int index = get_group_id(0); index < YSIZE*ZSIZE; index += get_num_groups(0)) { for (int index = get_group_id(0); index < XSIZE*YSIZE; index += get_num_groups(0)) {
int z = index/YSIZE; int x = index/YSIZE;
int y = index-z*YSIZE; int y = index-x*YSIZE;
for (int i = get_local_id(0); i < XSIZE; i += get_local_size(0)) #ifdef LOOP_REQUIRED
data0[i] = matrix[i*XMULT+y*YMULT+z*ZMULT]; for (int z = get_local_id(0); z < ZSIZE; z += get_local_size(0))
data0[z] = in[x*(YSIZE*ZSIZE)+y*ZSIZE+z];
#else
data0[get_local_id(0)] = in[x*(YSIZE*ZSIZE)+y*ZSIZE+get_local_id(0)];
#endif
barrier(CLK_LOCAL_MEM_FENCE); barrier(CLK_LOCAL_MEM_FENCE);
COMPUTE_FFT COMPUTE_FFT
} }
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2011 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
/**
* This tests the OpenCL implementation of sorting.
*/
#include "../../../tests/AssertionUtilities.h"
#include "../src/OpenCLArray.h"
#include "../src/OpenCLContext.h"
#include "../src/OpenCLFFT3D.h"
#include "../src/OpenCLSort.h"
#include "../src/SimTKReference/fftpack.h"
#include "sfmt/SFMT.h"
#include "openmm/System.h"
#include <iostream>
#include <cmath>
#include <set>
using namespace OpenMM;
using namespace std;
void testTransform() {
System system;
system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(1, "");
OpenCLContext& context = *platformData.contexts[0];
context.initialize(system);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
int xsize = 32, ysize = 25, zsize = 30;
vector<mm_float2> original(xsize*ysize*zsize);
vector<t_complex> reference(original.size());
for (int i = 0; i < (int) original.size(); i++) {
mm_float2 value = mm_float2(genrand_real2(sfmt), genrand_real2(sfmt));
original[i] = value;
reference[i] = t_complex(value.x, value.y);
}
OpenCLArray<mm_float2> grid1(context, original.size(), "grid1");
OpenCLArray<mm_float2> grid2(context, original.size(), "grid2");
grid1.upload(original);
OpenCLFFT3D fft(context, xsize, ysize, zsize);
// Perform a forward FFT, then verify the result is correct.
fft.execFFT(grid1, grid2, true);
vector<mm_float2> result;
grid2.download(result);
fftpack_t plan;
fftpack_init_3d(&plan, xsize, ysize, zsize);
fftpack_exec_3d(plan, FFTPACK_FORWARD, &reference[0], &reference[0]);
for (int i = 0; i < (int) result.size(); ++i) {
ASSERT_EQUAL_TOL(reference[i].re, result[i].x, 1e-4);
ASSERT_EQUAL_TOL(reference[i].im, result[i].y, 1e-4);
}
fftpack_destroy(plan);
// Perform a backward transform and see if we get the original values.
fft.execFFT(grid2, grid1, false);
grid1.download(result);
double scale = 1.0/(xsize*ysize*zsize);
for (int i = 0; i < (int) result.size(); ++i) {
ASSERT_EQUAL_TOL(original[i].x, scale*result[i].x, 1e-4);
ASSERT_EQUAL_TOL(original[i].y, scale*result[i].y, 1e-4);
}
}
int main() {
try {
testTransform();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment