Commit d3e91b15 authored by peastman's avatar peastman
Browse files

Ported OpenCL FFT to CUDA

parent 1b1ea94f
#ifndef __OPENMM_CUDAFFT3D_H__
#define __OPENMM_CUDAFFT3D_H__
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "CudaArray.h"
namespace OpenMM {
/**
* This class performs three dimensional Fast Fourier Transforms. It is based on the
* mixed radix algorithm described in
* <p>
* Takahashi, D. and Kanada, Y., "High-Performance Radix-2, 3 and 5 Parallel 1-D Complex
* FFT Algorithms for Distributed-Memory Parallel Computers." Journal of Supercomputing,
* 15, 207–228 (2000).
* <p>
* This class places certain restrictions on the allowed dimensions of the grid. First,
* the size of each dimension may have no prime factors other than 2, 3, 5, and 7. You
* can call findLegalDimension() to determine the smallest size that satisfies this
* requirement and is greater than or equal to a specified minimum size. Second, the size
* of each dimension must be small enough to compute each 1D transform entirely in local
* memory with one work unit per data point. This will vary between platforms, but is
* typically at least 512.
* <p>
* Note that this class performs an unnormalized transform. That means that if you perform
* a forward transform followed immediately by an inverse transform, the effect is to
* multiply every value of the original data set by the total number of data points.
*/
class OPENMM_EXPORT_CUDA CudaFFT3D {
public:
/**
* Create an CudaFFT3D object for performing transforms of a particular size.
*
* @param context the context in which to perform calculations
* @param xsize the first dimension of the data sets on which FFTs will be performed
* @param ysize the second dimension of the data sets on which FFTs will be performed
* @param zsize the third dimension of the data sets on which FFTs will be performed
* @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex.
*/
CudaFFT3D(CudaContext& context, int xsize, int ysize, int zsize, bool realToComplex=false);
/**
* Perform a Fourier transform. The transform cannot be done in-place: the input and output
* arrays must be different. Also, the input array is used as workspace, so its contents
* are destroyed. This also means that both arrays must be large enough to hold complex values,
* even when performing a real-to-complex transform.
* <p>
* When performing a real-to-complex transform, the output data is of size xsize*ysize*(zsize/2+1)
* and contains only the non-redundant elements.
*
* @param in the data to transform, ordered such that in[x*ysize*zsize + y*zsize + z] contains element (x, y, z)
* @param out on exit, this contains the transformed data
* @param forward true to perform a forward transform, false to perform an inverse transform
*/
void execFFT(CudaArray& in, CudaArray& out, bool forward = true);
/**
* Get the smallest legal size for a dimension of the grid (that is, a size with no prime
* factors other than 2, 3, 5, and 7).
*
* @param minimum the minimum size the return value must be greater than or equal to
*/
static int findLegalDimension(int minimum);
private:
CUfunction createKernel(int xsize, int ysize, int zsize, int& threads, int axis, bool forward, bool inputIsReal);
int xsize, ysize, zsize;
int xthreads, ythreads, zthreads;
bool packRealAsComplex;
CudaContext& context;
CUfunction xkernel, ykernel, zkernel;
CUfunction invxkernel, invykernel, invzkernel;
CUfunction packForwardKernel, unpackForwardKernel, packBackwardKernel, unpackBackwardKernel;
};
} // namespace OpenMM
#endif // __OPENMM_CUDAFFT3D_H__
......@@ -30,6 +30,7 @@
#include "CudaPlatform.h"
#include "CudaArray.h"
#include "CudaContext.h"
#include "CudaFFT3D.h"
#include "CudaParameterSet.h"
#include "CudaSort.h"
#include "openmm/kernels.h"
......@@ -588,7 +589,7 @@ class CudaCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public:
CudaCalcNonbondedForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcNonbondedForceKernel(name, platform),
cu(cu), hasInitializedFFT(false), sigmaEpsilon(NULL), exceptionParams(NULL), cosSinSums(NULL), directPmeGrid(NULL), reciprocalPmeGrid(NULL),
pmeBsplineModuliX(NULL), pmeBsplineModuliY(NULL), pmeBsplineModuliZ(NULL), pmeAtomRange(NULL), pmeAtomGridIndex(NULL), sort(NULL), pmeio(NULL) {
pmeBsplineModuliX(NULL), pmeBsplineModuliY(NULL), pmeBsplineModuliZ(NULL), pmeAtomRange(NULL), pmeAtomGridIndex(NULL), sort(NULL), fft(NULL), pmeio(NULL) {
}
~CudaCalcNonbondedForceKernel();
/**
......@@ -649,6 +650,7 @@ private:
PmeIO* pmeio;
CUstream pmeStream;
CUevent pmeSyncEvent;
CudaFFT3D* fft;
cufftHandle fftForward;
cufftHandle fftBackward;
CUfunction ewaldSumsKernel;
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "CudaFFT3D.h"
#include "CudaContext.h"
#include "CudaKernelSources.h"
#include "SimTKOpenMMRealType.h"
#include <map>
#include <sstream>
#include <string>
using namespace OpenMM;
using namespace std;
CudaFFT3D::CudaFFT3D(CudaContext& context, int xsize, int ysize, int zsize, bool realToComplex) :
context(context), xsize(xsize), ysize(ysize), zsize(zsize) {
packRealAsComplex = false;
int packedXSize = xsize;
int packedYSize = ysize;
int packedZSize = zsize;
if (realToComplex) {
// If any axis size is even, we can pack the real values into a complex grid that is only half as large.
// Look for an appropriate axis.
packRealAsComplex = true;
int packedAxis, bufferSize;
if (xsize%2 == 0) {
packedAxis = 0;
packedXSize /= 2;
bufferSize = packedXSize;
}
else if (ysize%2 == 0) {
packedAxis = 1;
packedYSize /= 2;
bufferSize = packedYSize;
}
else if (zsize%2 == 0) {
packedAxis = 2;
packedZSize /= 2;
bufferSize = packedZSize;
}
else
packRealAsComplex = false;
if (packRealAsComplex) {
// Build the kernels for packing and unpacking the data.
map<string, string> defines;
defines["XSIZE"] = context.intToString(xsize);
defines["YSIZE"] = context.intToString(ysize);
defines["ZSIZE"] = context.intToString(zsize);
defines["PACKED_AXIS"] = context.intToString(packedAxis);
defines["PACKED_XSIZE"] = context.intToString(packedXSize);
defines["PACKED_YSIZE"] = context.intToString(packedYSize);
defines["PACKED_ZSIZE"] = context.intToString(packedZSize);
CUmodule module = context.createModule(CudaKernelSources::vectorOps+CudaKernelSources::fftR2C, defines);
packForwardKernel = context.getKernel(module, "packForwardData");
unpackForwardKernel = context.getKernel(module, "unpackForwardData");
packBackwardKernel = context.getKernel(module, "packBackwardData");
unpackBackwardKernel = context.getKernel(module, "unpackBackwardData");
}
}
bool inputIsReal = (realToComplex && !packRealAsComplex);
zkernel = createKernel(packedXSize, packedYSize, packedZSize, zthreads, 0, true, inputIsReal);
xkernel = createKernel(packedYSize, packedZSize, packedXSize, xthreads, 1, true, inputIsReal);
ykernel = createKernel(packedZSize, packedXSize, packedYSize, ythreads, 2, true, inputIsReal);
invzkernel = createKernel(packedXSize, packedYSize, packedZSize, zthreads, 0, false, inputIsReal);
invxkernel = createKernel(packedYSize, packedZSize, packedXSize, xthreads, 1, false, inputIsReal);
invykernel = createKernel(packedZSize, packedXSize, packedYSize, ythreads, 2, false, inputIsReal);
}
void CudaFFT3D::execFFT(CudaArray& in, CudaArray& out, bool forward) {
CUfunction kernel1 = (forward ? zkernel : invzkernel);
CUfunction kernel2 = (forward ? xkernel : invxkernel);
CUfunction kernel3 = (forward ? ykernel : invykernel);
void* args1[] = {&in.getDevicePointer(), &out.getDevicePointer()};
void* args2[] = {&out.getDevicePointer(), &in.getDevicePointer()};
if (packRealAsComplex) {
CUfunction packKernel = (forward ? packForwardKernel : packBackwardKernel);
CUfunction unpackKernel = (forward ? unpackForwardKernel : unpackBackwardKernel);
int gridSize = xsize*ysize*zsize/2;
// Pack the data into a half sized grid.
context.executeKernel(packKernel, args1, gridSize, 128);
// Perform the FFT.
context.executeKernel(kernel1, args2, gridSize, zthreads);
context.executeKernel(kernel2, args1, gridSize, xthreads);
context.executeKernel(kernel3, args2, gridSize, ythreads);
// Unpack the data.
context.executeKernel(unpackKernel, args1, gridSize, 128);
}
else {
context.executeKernel(kernel1, args1, xsize*ysize*zsize, zthreads);
context.executeKernel(kernel2, args2, xsize*ysize*zsize, xthreads);
context.executeKernel(kernel3, args1, xsize*ysize*zsize, ythreads);
}
}
int CudaFFT3D::findLegalDimension(int minimum) {
if (minimum < 1)
return 1;
while (true) {
// Attempt to factor the current value.
int unfactored = minimum;
for (int factor = 2; factor < 8; factor++) {
while (unfactored > 1 && unfactored%factor == 0)
unfactored /= factor;
}
if (unfactored == 1)
return minimum;
minimum++;
}
}
static int getSmallestRadix(int size) {
int minRadix = 1;
int unfactored = size;
while (unfactored%7 == 0) {
minRadix = 7;
unfactored /= 7;
}
while (unfactored%5 == 0) {
minRadix = 5;
unfactored /= 5;
}
while (unfactored%4 == 0) {
minRadix = 4;
unfactored /= 4;
}
while (unfactored%3 == 0) {
minRadix = 3;
unfactored /= 3;
}
while (unfactored%2 == 0) {
minRadix = 2;
unfactored /= 2;
}
return minRadix;
}
CUfunction CudaFFT3D::createKernel(int xsize, int ysize, int zsize, int& threads, int axis, bool forward, bool inputIsReal) {
int maxThreads = 256;//std::min(256, (int) context.getDevice().getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>());
// while (maxThreads > 128 && maxThreads-64 >= zsize)
// maxThreads -= 64;
int threadsPerBlock = zsize/getSmallestRadix(zsize);
bool isCPU = false;//context.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU;
bool loopRequired = (threadsPerBlock > maxThreads || isCPU);
stringstream source;
int blocksPerGroup = (loopRequired ? 1 : max(1, maxThreads/threadsPerBlock));
int stage = 0;
int L = zsize;
int m = 1;
// Factor zsize, generating an appropriate block of code for each factor.
while (L > 1) {
int input = stage%2;
int output = 1-input;
int radix;
if (L%7 == 0)
radix = 7;
else if (L%5 == 0)
radix = 5;
else if (L%4 == 0)
radix = 4;
else if (L%3 == 0)
radix = 3;
else if (L%2 == 0)
radix = 2;
else
throw OpenMMException("Illegal size for FFT: "+context.intToString(zsize));
source<<"{\n";
L = L/radix;
source<<"// Pass "<<(stage+1)<<" (radix "<<radix<<")\n";
if (loopRequired) {
source<<"for (int i = threadIdx.x; i < "<<(L*m)<<"; i += blockDim.x) {\n";
source<<"int base = i;\n";
}
else {
if (L*m < threadsPerBlock)
source<<"if (threadIdx.x < "<<(blocksPerGroup*L*m)<<") {\n";
else
source<<"{\n";
source<<"int block = threadIdx.x/"<<(L*m)<<";\n";
source<<"int i = threadIdx.x-block*"<<(L*m)<<";\n";
source<<"int base = i+block*"<<zsize<<";\n";
}
source<<"int j = i/"<<m<<";\n";
if (radix == 7) {
source<<"real2 c0 = data"<<input<<"[base];\n";
source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n";
source<<"real2 c3 = data"<<input<<"[base+"<<(3*L*m)<<"];\n";
source<<"real2 c4 = data"<<input<<"[base+"<<(4*L*m)<<"];\n";
source<<"real2 c5 = data"<<input<<"[base+"<<(5*L*m)<<"];\n";
source<<"real2 c6 = data"<<input<<"[base+"<<(6*L*m)<<"];\n";
source<<"real2 d0 = c1+c6;\n";
source<<"real2 d1 = c1-c6;\n";
source<<"real2 d2 = c2+c5;\n";
source<<"real2 d3 = c2-c5;\n";
source<<"real2 d4 = c4+c3;\n";
source<<"real2 d5 = c4-c3;\n";
source<<"real2 d6 = d2+d0;\n";
source<<"real2 d7 = d5+d3;\n";
source<<"real2 b0 = c0+d6+d4;\n";
source<<"real2 b1 = "<<context.doubleToString((cos(2*M_PI/7)+cos(4*M_PI/7)+cos(6*M_PI/7))/3-1)<<"*(d6+d4);\n";
source<<"real2 b2 = "<<context.doubleToString((2*cos(2*M_PI/7)-cos(4*M_PI/7)-cos(6*M_PI/7))/3)<<"*(d0-d4);\n";
source<<"real2 b3 = "<<context.doubleToString((cos(2*M_PI/7)-2*cos(4*M_PI/7)+cos(6*M_PI/7))/3)<<"*(d4-d2);\n";
source<<"real2 b4 = "<<context.doubleToString((cos(2*M_PI/7)+cos(4*M_PI/7)-2*cos(6*M_PI/7))/3)<<"*(d2-d0);\n";
source<<"real2 b5 = -(SIGN)*"<<context.doubleToString((sin(2*M_PI/7)+sin(4*M_PI/7)-sin(6*M_PI/7))/3)<<"*(d7+d1);\n";
source<<"real2 b6 = -(SIGN)*"<<context.doubleToString((2*sin(2*M_PI/7)-sin(4*M_PI/7)+sin(6*M_PI/7))/3)<<"*(d1-d5);\n";
source<<"real2 b7 = -(SIGN)*"<<context.doubleToString((sin(2*M_PI/7)-2*sin(4*M_PI/7)-sin(6*M_PI/7))/3)<<"*(d5-d3);\n";
source<<"real2 b8 = -(SIGN)*"<<context.doubleToString((sin(2*M_PI/7)+sin(4*M_PI/7)+2*sin(6*M_PI/7))/3)<<"*(d3-d1);\n";
source<<"real2 t0 = b0+b1;\n";
source<<"real2 t1 = b2+b3;\n";
source<<"real2 t2 = b4-b3;\n";
source<<"real2 t3 = -b2-b4;\n";
source<<"real2 t4 = b6+b7;\n";
source<<"real2 t5 = b8-b7;\n";
source<<"real2 t6 = -b8-b6;\n";
source<<"real2 t7 = t0+t1;\n";
source<<"real2 t8 = t0+t2;\n";
source<<"real2 t9 = t0+t3;\n";
source<<"real2 t10 = make_real2(t4.y+b5.y, -(t4.x+b5.x));\n";
source<<"real2 t11 = make_real2(t5.y+b5.y, -(t5.x+b5.x));\n";
source<<"real2 t12 = make_real2(t6.y+b5.y, -(t6.x+b5.x));\n";
source<<"data"<<output<<"[base+6*j*"<<m<<"] = b0;\n";
source<<"data"<<output<<"[base+(6*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(7*L)<<"], t7-t10);\n";
source<<"data"<<output<<"[base+(6*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(7*L)<<"], t9-t12);\n";
source<<"data"<<output<<"[base+(6*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(7*L)<<"], t8+t11);\n";
source<<"data"<<output<<"[base+(6*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*zsize)<<"/"<<(7*L)<<"], t8-t11);\n";
source<<"data"<<output<<"[base+(6*j+5)*"<<m<<"] = multiplyComplex(w[j*"<<(5*zsize)<<"/"<<(7*L)<<"], t9+t12);\n";
source<<"data"<<output<<"[base+(6*j+6)*"<<m<<"] = multiplyComplex(w[j*"<<(6*zsize)<<"/"<<(7*L)<<"], t7+t10);\n";
}
else if (radix == 5) {
source<<"real2 c0 = data"<<input<<"[base];\n";
source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n";
source<<"real2 c3 = data"<<input<<"[base+"<<(3*L*m)<<"];\n";
source<<"real2 c4 = data"<<input<<"[base+"<<(4*L*m)<<"];\n";
source<<"real2 d0 = c1+c4;\n";
source<<"real2 d1 = c2+c3;\n";
source<<"real2 d2 = "<<context.doubleToString(sin(0.4*M_PI))<<"*(c1-c4);\n";
source<<"real2 d3 = "<<context.doubleToString(sin(0.4*M_PI))<<"*(c2-c3);\n";
source<<"real2 d4 = d0+d1;\n";
source<<"real2 d5 = "<<context.doubleToString(0.25*sqrt(5.0))<<"*(d0-d1);\n";
source<<"real2 d6 = c0-0.25f*d4;\n";
source<<"real2 d7 = d6+d5;\n";
source<<"real2 d8 = d6-d5;\n";
string coeff = context.doubleToString(sin(0.2*M_PI)/sin(0.4*M_PI));
source<<"real2 d9 = (SIGN)*make_real2(d2.y+"<<coeff<<"*d3.y, -d2.x-"<<coeff<<"*d3.x);\n";
source<<"real2 d10 = (SIGN)*make_real2("<<coeff<<"*d2.y-d3.y, d3.x-"<<coeff<<"*d2.x);\n";
source<<"data"<<output<<"[base+4*j*"<<m<<"] = c0+d4;\n";
source<<"data"<<output<<"[base+(4*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(5*L)<<"], d7+d9);\n";
source<<"data"<<output<<"[base+(4*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(5*L)<<"], d8+d10);\n";
source<<"data"<<output<<"[base+(4*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(5*L)<<"], d8-d10);\n";
source<<"data"<<output<<"[base+(4*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*zsize)<<"/"<<(5*L)<<"], d7-d9);\n";
}
else if (radix == 4) {
source<<"real2 c0 = data"<<input<<"[base];\n";
source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n";
source<<"real2 c3 = data"<<input<<"[base+"<<(3*L*m)<<"];\n";
source<<"real2 d0 = c0+c2;\n";
source<<"real2 d1 = c0-c2;\n";
source<<"real2 d2 = c1+c3;\n";
source<<"real2 d3 = (SIGN)*make_real2(c1.y-c3.y, c3.x-c1.x);\n";
source<<"data"<<output<<"[base+3*j*"<<m<<"] = d0+d2;\n";
source<<"data"<<output<<"[base+(3*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(4*L)<<"], d1+d3);\n";
source<<"data"<<output<<"[base+(3*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(4*L)<<"], d0-d2);\n";
source<<"data"<<output<<"[base+(3*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(4*L)<<"], d1-d3);\n";
}
else if (radix == 3) {
source<<"real2 c0 = data"<<input<<"[base];\n";
source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n";
source<<"real2 d0 = c1+c2;\n";
source<<"real2 d1 = c0-0.5f*d0;\n";
source<<"real2 d2 = (SIGN)*"<<context.doubleToString(sin(M_PI/3.0))<<"*make_real2(c1.y-c2.y, c2.x-c1.x);\n";
source<<"data"<<output<<"[base+2*j*"<<m<<"] = c0+d0;\n";
source<<"data"<<output<<"[base+(2*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(3*L)<<"], d1+d2);\n";
source<<"data"<<output<<"[base+(2*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(3*L)<<"], d1-d2);\n";
}
else if (radix == 2) {
source<<"real2 c0 = data"<<input<<"[base];\n";
source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
source<<"data"<<output<<"[base+j*"<<m<<"] = c0+c1;\n";
source<<"data"<<output<<"[base+(j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(2*L)<<"], c0-c1);\n";
}
source<<"}\n";
m = m*radix;
source<<"__syncthreads();\n";
source<<"}\n";
++stage;
}
// Create the kernel.
bool outputIsReal = (inputIsReal && axis == 2 && !forward);
bool outputIsPacked = (inputIsReal && axis == 2 && forward);
string outputSuffix = (outputIsReal ? ".x" : "");
if (loopRequired || true) {
if (outputIsPacked)
source<<"if (index < XSIZE*YSIZE && x < XSIZE/2+1)\n";
else
source<<"if (index < XSIZE*YSIZE)\n";
source<<"for (int i = threadIdx.x-block*THREADS_PER_BLOCK; i < ZSIZE; i += THREADS_PER_BLOCK)\n";
if (outputIsPacked)
source<<"out[y*(ZSIZE*(XSIZE/2+1))+i*(XSIZE/2+1)+x] = data"<<(stage%2)<<"[i+block*ZSIZE]"<<outputSuffix<<";\n";
else
source<<"out[y*(ZSIZE*XSIZE)+i*XSIZE+x] = data"<<(stage%2)<<"[i+block*ZSIZE]"<<outputSuffix<<";\n";
}
else {
if (outputIsPacked) {
source<<"if (index < XSIZE*YSIZE && x < XSIZE/2+1)\n";
source<<"out[y*(ZSIZE*(XSIZE/2+1))+(threadIdx.x%ZSIZE)*(XSIZE/2+1)+x] = data"<<(stage%2)<<"[threadIdx.x]"<<outputSuffix<<";\n";
}
else {
source<<"if (index < XSIZE*YSIZE)\n";
source<<"out[y*(ZSIZE*XSIZE)+(threadIdx.x%ZSIZE)*XSIZE+x] = data"<<(stage%2)<<"[threadIdx.x]"<<outputSuffix<<";\n";
}
}
map<string, string> replacements;
replacements["XSIZE"] = context.intToString(xsize);
replacements["YSIZE"] = context.intToString(ysize);
replacements["ZSIZE"] = context.intToString(zsize);
replacements["BLOCKS_PER_GROUP"] = context.intToString(blocksPerGroup);
replacements["THREADS_PER_BLOCK"] = context.intToString(threadsPerBlock);
replacements["M_PI"] = context.doubleToString(M_PI);
replacements["COMPUTE_FFT"] = source.str();
replacements["LOOP_REQUIRED"] = (loopRequired ? "1" : "0");
replacements["SIGN"] = (forward ? "1" : "-1");
replacements["INPUT_TYPE"] = (inputIsReal && axis == 0 && forward ? "real" : "real2");
replacements["OUTPUT_TYPE"] = (outputIsReal ? "real" : "real2");
replacements["INPUT_IS_REAL"] = (inputIsReal && axis == 0 && forward ? "1" : "0");
replacements["INPUT_IS_PACKED"] = (inputIsReal && axis == 0 && !forward ? "1" : "0");
replacements["OUTPUT_IS_PACKED"] = (outputIsPacked ? "1" : "0");
CUmodule module = context.createModule(CudaKernelSources::vectorOps+context.replaceStrings(CudaKernelSources::fft, replacements));
CUfunction kernel = context.getKernel(module, "execFFT");
threads = (isCPU ? 1 : blocksPerGroup*threadsPerBlock);
return kernel;
}
......@@ -1497,6 +1497,8 @@ CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() {
delete pmeAtomGridIndex;
if (sort != NULL)
delete sort;
if (fft != NULL)
delete fft;
if (pmeio != NULL)
delete pmeio;
if (hasInitializedFFT) {
......@@ -1643,9 +1645,9 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
int gridSizeX, gridSizeY, gridSizeZ;
NonbondedForceImpl::calcPMEParameters(system, force, alpha, gridSizeX, gridSizeY, gridSizeZ);
gridSizeX = findFFTDimension(gridSizeX);
gridSizeY = findFFTDimension(gridSizeY);
gridSizeZ = findFFTDimension(gridSizeZ);
gridSizeX = CudaFFT3D::findLegalDimension(gridSizeX);
gridSizeY = CudaFFT3D::findLegalDimension(gridSizeY);
gridSizeZ = CudaFFT3D::findLegalDimension(gridSizeZ);
defines["EWALD_ALPHA"] = cu.doubleToString(alpha);
defines["TWO_OVER_SQRT_PI"] = cu.doubleToString(2.0/sqrt(M_PI));
......@@ -1704,6 +1706,7 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
pmeAtomRange = CudaArray::create<int>(cu, gridSizeX*gridSizeY*gridSizeZ+1, "pmeAtomRange");
pmeAtomGridIndex = CudaArray::create<int2>(cu, numParticles, "pmeAtomGridIndex");
sort = new CudaSort(cu, new SortTrait(), cu.getNumAtoms());
fft = new CudaFFT3D(cu, gridSizeX, gridSizeY, gridSizeZ, true);
cufftResult result = cufftPlan3d(&fftForward, gridSizeX, gridSizeY, gridSizeZ, cu.getUseDoublePrecision() ? CUFFT_D2Z : CUFFT_R2C);
if (result != CUFFT_SUCCESS)
......@@ -1719,7 +1722,7 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
int cufftVersion;
cufftGetVersion(&cufftVersion);
usePmeStream = (cu.getComputeCapability() < 5.0 && numParticles < 130000 && cufftVersion >= 6000 && cufftVersion != 7000); // Workarounds for various CUDA bugs
usePmeStream = true;//(cu.getComputeCapability() < 5.0 && numParticles < 130000 && cufftVersion >= 6000 && cufftVersion != 7000); // Workarounds for various CUDA bugs
if (usePmeStream) {
cuStreamCreate(&pmeStream, CU_STREAM_NON_BLOCKING);
cufftSetStream(fftForward, pmeStream);
......@@ -1893,10 +1896,11 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
cu.executeKernel(pmeFinishSpreadChargeKernel, finishSpreadArgs, directPmeGrid->getSize());
}
if (cu.getUseDoublePrecision())
cufftExecD2Z(fftForward, (double*) directPmeGrid->getDevicePointer(), (double2*) reciprocalPmeGrid->getDevicePointer());
else
cufftExecR2C(fftForward, (float*) directPmeGrid->getDevicePointer(), (float2*) reciprocalPmeGrid->getDevicePointer());
// if (cu.getUseDoublePrecision())
// cufftExecD2Z(fftForward, (double*) directPmeGrid->getDevicePointer(), (double2*) reciprocalPmeGrid->getDevicePointer());
// else
// cufftExecR2C(fftForward, (float*) directPmeGrid->getDevicePointer(), (float2*) reciprocalPmeGrid->getDevicePointer());
fft->execFFT(*directPmeGrid, *reciprocalPmeGrid, true);
if (includeEnergy) {
void* computeEnergyArgs[] = {&reciprocalPmeGrid->getDevicePointer(), &cu.getEnergyBuffer().getDevicePointer(),
......@@ -1910,10 +1914,11 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
cu.getPeriodicBoxSizePointer(), recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]};
cu.executeKernel(pmeConvolutionKernel, convolutionArgs, cu.getNumAtoms());
if (cu.getUseDoublePrecision())
cufftExecZ2D(fftBackward, (double2*) reciprocalPmeGrid->getDevicePointer(), (double*) directPmeGrid->getDevicePointer());
else
cufftExecC2R(fftBackward, (float2*) reciprocalPmeGrid->getDevicePointer(), (float*) directPmeGrid->getDevicePointer());
// if (cu.getUseDoublePrecision())
// cufftExecZ2D(fftBackward, (double2*) reciprocalPmeGrid->getDevicePointer(), (double*) directPmeGrid->getDevicePointer());
// else
// cufftExecC2R(fftBackward, (float2*) reciprocalPmeGrid->getDevicePointer(), (float*) directPmeGrid->getDevicePointer());
fft->execFFT(*reciprocalPmeGrid, *directPmeGrid, false);
void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &directPmeGrid->getDevicePointer(),
......
static __inline__ __device__ real2 multiplyComplex(real2 c1, real2 c2) {
return make_real2(c1.x*c2.x-c1.y*c2.y, c1.x*c2.y+c1.y*c2.x);
}
/**
* Load a value from the half-complex grid produces by a real-to-complex transform.
*/
static __inline__ __device__ real2 loadComplexValue(const real2* __restrict__ in, int x, int y, int z) {
const int inputZSize = ZSIZE/2+1;
if (z < inputZSize)
return in[x*YSIZE*inputZSize+y*inputZSize+z];
int xp = (x == 0 ? 0 : XSIZE-x);
int yp = (y == 0 ? 0 : YSIZE-y);
real2 value = in[xp*YSIZE*inputZSize+yp*inputZSize+(ZSIZE-z)];
return make_real2(value.x, -value.y);
}
/**
* Perform a 1D FFT on each row along one axis.
*/
extern "C" __global__ void execFFT(const INPUT_TYPE* __restrict__ in, OUTPUT_TYPE* __restrict__ out) {
__shared__ real2 w[ZSIZE];
__shared__ real2 data0[BLOCKS_PER_GROUP*ZSIZE];
__shared__ real2 data1[BLOCKS_PER_GROUP*ZSIZE];
for (int i = threadIdx.x; i < ZSIZE; i += blockDim.x)
w[i] = make_real2(cos(-(SIGN)*i*2*M_PI/ZSIZE), sin(-(SIGN)*i*2*M_PI/ZSIZE));
__syncthreads();
const int block = threadIdx.x/THREADS_PER_BLOCK;
for (int baseIndex = blockIdx.x*BLOCKS_PER_GROUP; baseIndex < XSIZE*YSIZE; baseIndex += gridDim.x*BLOCKS_PER_GROUP) {
int index = baseIndex+block;
int x = index/YSIZE;
int y = index-x*YSIZE;
#if OUTPUT_IS_PACKED
if (x < XSIZE/2+1) {
#endif
//#if LOOP_REQUIRED
if (index < XSIZE*YSIZE)
for (int i = threadIdx.x-block*THREADS_PER_BLOCK; i < ZSIZE; i += THREADS_PER_BLOCK)
#if INPUT_IS_REAL
data0[i+block*ZSIZE] = make_real2(in[x*(YSIZE*ZSIZE)+y*ZSIZE+i], 0);
#elif INPUT_IS_PACKED
data0[i+block*ZSIZE] = loadComplexValue(in, x, y, i);
#else
data0[i+block*ZSIZE] = in[x*(YSIZE*ZSIZE)+y*ZSIZE+i];
#endif
//#else
// if (index < XSIZE*YSIZE && (threadIdx.x%BLOCK_SIZE) < ZSIZE)
// #if INPUT_IS_REAL
// data0[threadIdx.x] = make_real2(in[x*(YSIZE*ZSIZE)+y*ZSIZE+threadIdx.x%BLOCK_SIZE], 0);
// #elif INPUT_IS_PACKED
// data0[threadIdx.x] = loadComplexValue(in, x, y, threadIdx.x%BLOCK_SIZE);
// #else
// data0[threadIdx.x] = in[x*(YSIZE*ZSIZE)+y*ZSIZE+threadIdx.x%BLOCK_SIZE];
// #endif
//#endif
#if OUTPUT_IS_PACKED
}
#endif
__syncthreads();
COMPUTE_FFT
}
}
/**
* Combine the two halves of a real grid into a complex grid that is half as large.
*/
extern "C" __global__ void packForwardData(const real* __restrict__ in, real2* __restrict__ out) {
const int gridSize = PACKED_XSIZE*PACKED_YSIZE*PACKED_ZSIZE;
for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < gridSize; index += blockDim.x*gridDim.x) {
int x = index/(PACKED_YSIZE*PACKED_ZSIZE);
int remainder = index-x*(PACKED_YSIZE*PACKED_ZSIZE);
int y = remainder/PACKED_ZSIZE;
int z = remainder-y*PACKED_ZSIZE;
#if PACKED_AXIS == 0
real2 value = make_real2(in[2*x*YSIZE*ZSIZE+y*ZSIZE+z], in[(2*x+1)*YSIZE*ZSIZE+y*ZSIZE+z]);
#elif PACKED_AXIS == 1
real2 value = make_real2(in[x*YSIZE*ZSIZE+2*y*ZSIZE+z], in[x*YSIZE*ZSIZE+(2*y+1)*ZSIZE+z]);
#else
real2 value = make_real2(in[x*YSIZE*ZSIZE+y*ZSIZE+2*z], in[x*YSIZE*ZSIZE+y*ZSIZE+(2*z+1)]);
#endif
out[index] = value;
}
}
/**
* Split the transformed data back into a full sized, symmetric grid.
*/
extern "C" __global__ void unpackForwardData(const real2* __restrict__ in, real2* __restrict__ out) {
// Compute the phase factors.
#if PACKED_AXIS == 0
__shared__ real2 w[PACKED_XSIZE];
for (int i = threadIdx.x; i < PACKED_XSIZE; i += blockDim.x)
w[i] = make_real2(sin(i*2*M_PI/XSIZE), cos(i*2*M_PI/XSIZE));
#elif PACKED_AXIS == 1
__shared__ real2 w[PACKED_YSIZE];
for (int i = threadIdx.x; i < PACKED_YSIZE; i += blockDim.x)
w[i] = make_real2(sin(i*2*M_PI/YSIZE), cos(i*2*M_PI/YSIZE));
#else
__shared__ real2 w[PACKED_ZSIZE];
for (int i = threadIdx.x; i < PACKED_ZSIZE; i += blockDim.x)
w[i] = make_real2(sin(i*2*M_PI/ZSIZE), cos(i*2*M_PI/ZSIZE));
#endif
__syncthreads();
// Transform the data.
const int gridSize = PACKED_XSIZE*PACKED_YSIZE*PACKED_ZSIZE;
const int outputZSize = ZSIZE/2+1;
for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < gridSize; index += blockDim.x*gridDim.x) {
int x = index/(PACKED_YSIZE*PACKED_ZSIZE);
int remainder = index-x*(PACKED_YSIZE*PACKED_ZSIZE);
int y = remainder/PACKED_ZSIZE;
int z = remainder-y*PACKED_ZSIZE;
int xp = (x == 0 ? 0 : PACKED_XSIZE-x);
int yp = (y == 0 ? 0 : PACKED_YSIZE-y);
int zp = (z == 0 ? 0 : PACKED_ZSIZE-z);
real2 z1 = in[x*PACKED_YSIZE*PACKED_ZSIZE+y*PACKED_ZSIZE+z];
real2 z2 = in[xp*PACKED_YSIZE*PACKED_ZSIZE+yp*PACKED_ZSIZE+zp];
#if PACKED_AXIS == 0
real2 wfac = w[x];
#elif PACKED_AXIS == 1
real2 wfac = w[y];
#else
real2 wfac = w[z];
#endif
real2 output = make_real2((z1.x+z2.x - wfac.x*(z1.x-z2.x) + wfac.y*(z1.y+z2.y))/2, (z1.y-z2.y - wfac.y*(z1.x-z2.x) - wfac.x*(z1.y+z2.y))/2);
if (z < outputZSize)
out[x*YSIZE*outputZSize+y*outputZSize+z] = output;
xp = (x == 0 ? 0 : XSIZE-x);
yp = (y == 0 ? 0 : YSIZE-y);
zp = (z == 0 ? 0 : ZSIZE-z);
if (zp < outputZSize) {
#if PACKED_AXIS == 0
if (x == 0)
out[PACKED_XSIZE*YSIZE*outputZSize+yp*outputZSize+zp] = make_real2((z1.x-z1.y+z2.x-z2.y)/2, (-z1.x-z1.y+z2.x+z2.y)/2);
#elif PACKED_AXIS == 1
if (y == 0)
out[xp*YSIZE*outputZSize+PACKED_YSIZE*outputZSize+zp] = make_real2((z1.x-z1.y+z2.x-z2.y)/2, (-z1.x-z1.y+z2.x+z2.y)/2);
#else
if (z == 0)
out[xp*YSIZE*outputZSize+yp*outputZSize+PACKED_ZSIZE] = make_real2((z1.x-z1.y+z2.x-z2.y)/2, (-z1.x-z1.y+z2.x+z2.y)/2);
#endif
else
out[xp*YSIZE*outputZSize+yp*outputZSize+zp] = make_real2(output.x, -output.y);
}
}
}
/**
* Load a value from the half-complex grid produced by a real-to-complex transform.
*/
static __inline__ __device__ real2 loadComplexValue(const real2* __restrict__ in, int x, int y, int z) {
const int inputZSize = ZSIZE/2+1;
if (z < inputZSize)
return in[x*YSIZE*inputZSize+y*inputZSize+z];
int xp = (x == 0 ? 0 : XSIZE-x);
int yp = (y == 0 ? 0 : YSIZE-y);
real2 value = in[xp*YSIZE*inputZSize+yp*inputZSize+(ZSIZE-z)];
return make_real2(value.x, -value.y);
}
/**
* Repack the symmetric complex grid into one half as large in preparation for doing an inverse complex-to-real transform.
*/
extern "C" __global__ void packBackwardData(const real2* __restrict__ in, real2* __restrict__ out) {
// Compute the phase factors.
#if PACKED_AXIS == 0
__shared__ real2 w[PACKED_XSIZE];
for (int i = threadIdx.x; i < PACKED_XSIZE; i += blockDim.x)
w[i] = make_real2(cos(i*2*M_PI/XSIZE), sin(i*2*M_PI/XSIZE));
#elif PACKED_AXIS == 1
__shared__ real2 w[PACKED_YSIZE];
for (int i = threadIdx.x; i < PACKED_YSIZE; i += blockDim.x)
w[i] = make_real2(cos(i*2*M_PI/YSIZE), sin(i*2*M_PI/YSIZE));
#else
__shared__ real2 w[PACKED_ZSIZE];
for (int i = threadIdx.x; i < PACKED_ZSIZE; i += blockDim.x)
w[i] = make_real2(cos(i*2*M_PI/ZSIZE), sin(i*2*M_PI/ZSIZE));
#endif
__syncthreads();
// Transform the data.
const int gridSize = PACKED_XSIZE*PACKED_YSIZE*PACKED_ZSIZE;
for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < gridSize; index += blockDim.x*gridDim.x) {
int x = index/(PACKED_YSIZE*PACKED_ZSIZE);
int remainder = index-x*(PACKED_YSIZE*PACKED_ZSIZE);
int y = remainder/PACKED_ZSIZE;
int z = remainder-y*PACKED_ZSIZE;
int xp = (x == 0 ? 0 : PACKED_XSIZE-x);
int yp = (y == 0 ? 0 : PACKED_YSIZE-y);
int zp = (z == 0 ? 0 : PACKED_ZSIZE-z);
real2 z1 = loadComplexValue(in, x, y, z);
#if PACKED_AXIS == 0
real2 wfac = w[x];
real2 z2 = loadComplexValue(in, PACKED_XSIZE-x, yp, zp);
#elif PACKED_AXIS == 1
real2 wfac = w[y];
real2 z2 = loadComplexValue(in, xp, PACKED_YSIZE-y, zp);
#else
real2 wfac = w[z];
real2 z2 = loadComplexValue(in, xp, yp, PACKED_ZSIZE-z);
#endif
real2 even = make_real2((z1.x+z2.x)/2, (z1.y-z2.y)/2);
real2 odd = make_real2((z1.x-z2.x)/2, (z1.y+z2.y)/2);
odd = make_real2(odd.x*wfac.x-odd.y*wfac.y, odd.y*wfac.x+odd.x*wfac.y);
out[x*PACKED_YSIZE*PACKED_ZSIZE+y*PACKED_ZSIZE+z] = make_real2(even.x-odd.y, even.y+odd.x);
}
}
/**
* Split the data back into a full sized, real grid after an inverse transform.
*/
extern "C" __global__ void unpackBackwardData(const real2* __restrict__ in, real* __restrict__ out) {
const int gridSize = PACKED_XSIZE*PACKED_YSIZE*PACKED_ZSIZE;
for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < gridSize; index += blockDim.x*gridDim.x) {
int x = index/(PACKED_YSIZE*PACKED_ZSIZE);
int remainder = index-x*(PACKED_YSIZE*PACKED_ZSIZE);
int y = remainder/PACKED_ZSIZE;
int z = remainder-y*PACKED_ZSIZE;
real2 value = 2*in[index];
#if PACKED_AXIS == 0
out[2*x*YSIZE*ZSIZE+y*ZSIZE+z] = value.x;
out[(2*x+1)*YSIZE*ZSIZE+y*ZSIZE+z] = value.y;
#elif PACKED_AXIS == 1
out[x*YSIZE*ZSIZE+2*y*ZSIZE+z] = value.x;
out[x*YSIZE*ZSIZE+(2*y+1)*ZSIZE+z] = value.y;
#else
out[x*YSIZE*ZSIZE+y*ZSIZE+2*z] = value.x;
out[x*YSIZE*ZSIZE+y*ZSIZE+(2*z+1)] = value.y;
#endif
}
}
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2011-2015 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
/**
* This tests the CUDA implementation of sorting.
*/
#include "openmm/internal/AssertionUtilities.h"
#include "CudaArray.h"
#include "CudaContext.h"
#include "CudaFFT3D.h"
#include "CudaSort.h"
#include "fftpack.h"
#include "sfmt/SFMT.h"
#include "openmm/System.h"
#include <iostream>
#include <cmath>
#include <set>
using namespace OpenMM;
using namespace std;
static CudaPlatform platform;
template <class Real2>
void testTransform(bool realToComplex, int xsize, int ysize, int zsize) {
System system;
system.addParticle(0.0);
CudaPlatform::PlatformData platformData(NULL, system, "", "true", platform.getPropertyDefaultValue("CudaPrecision"), "false",
platform.getPropertyDefaultValue(CudaPlatform::CudaCompiler()), platform.getPropertyDefaultValue(CudaPlatform::CudaTempDirectory()),
platform.getPropertyDefaultValue(CudaPlatform::CudaHostCompiler()));
CudaContext& context = *platformData.contexts[0];
context.initialize();
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
vector<Real2> original(xsize*ysize*zsize);
vector<t_complex> reference(original.size());
for (int i = 0; i < (int) original.size(); i++) {
Real2 value;
value.x = (float) genrand_real2(sfmt);
value.y = (float) genrand_real2(sfmt);
original[i] = value;
reference[i] = t_complex(value.x, value.y);
}
for (int i = 0; i < (int) reference.size(); i++) {
if (realToComplex)
reference[i] = t_complex(i%2 == 0 ? original[i/2].x : original[i/2].y, 0);
else
reference[i] = t_complex(original[i].x, original[i].y);
}
CudaArray grid1(context, original.size(), sizeof(Real2), "grid1");
CudaArray grid2(context, original.size(), sizeof(Real2), "grid2");
grid1.upload(original);
CudaFFT3D fft(context, xsize, ysize, zsize, realToComplex);
// Perform a forward FFT, then verify the result is correct.
fft.execFFT(grid1, grid2, true);
vector<Real2> result;
grid2.download(result);
fftpack_t plan;
fftpack_init_3d(&plan, xsize, ysize, zsize);
fftpack_exec_3d(plan, FFTPACK_FORWARD, &reference[0], &reference[0]);
int outputZSize = (realToComplex ? zsize/2+1 : zsize);
for (int x = 0; x < xsize; x++)
for (int y = 0; y < ysize; y++)
for (int z = 0; z < outputZSize; z++) {
int index1 = x*ysize*zsize + y*zsize + z;
int index2 = x*ysize*outputZSize + y*outputZSize + z;
ASSERT_EQUAL_TOL(reference[index1].re, result[index2].x, 1e-3);
ASSERT_EQUAL_TOL(reference[index1].im, result[index2].y, 1e-3);
}
fftpack_destroy(plan);
// Perform a backward transform and see if we get the original values.
fft.execFFT(grid2, grid1, false);
grid1.download(result);
double scale = 1.0/(xsize*ysize*zsize);
int valuesToCheck = (realToComplex ? original.size()/2 : original.size());
for (int i = 0; i < valuesToCheck; ++i) {
ASSERT_EQUAL_TOL(original[i].x, scale*result[i].x, 1e-4);
ASSERT_EQUAL_TOL(original[i].y, scale*result[i].y, 1e-4);
}
}
int main(int argc, char* argv[]) {
try {
if (argc > 1)
platform.setPropertyDefaultValue("CudaPrecision", string(argv[1]));
if (platform.getPropertyDefaultValue("CudaPrecision") == "double") {
testTransform<double2>(false, 28, 25, 30);
testTransform<double2>(true, 28, 25, 25);
testTransform<double2>(true, 25, 28, 25);
testTransform<double2>(true, 25, 25, 28);
testTransform<double2>(true, 21, 25, 27);
}
else {
testTransform<float2>(false, 28, 25, 30);
testTransform<float2>(true, 28, 25, 25);
testTransform<float2>(true, 25, 28, 25);
testTransform<float2>(true, 25, 25, 28);
testTransform<float2>(true, 21, 25, 27);
}
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment