Unverified commit 01e99e77, authored by Peter Eastman, committed by GitHub
Browse files

Uniform interface for FFTs (#4911)

* Unified interface for FFTs

* AMOEBA uses unified interface for FFTs

* HIP implementation of common FFT interface
parent a3909c8e
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "openmm/common/ComputeForceInfo.h" #include "openmm/common/ComputeForceInfo.h"
#include "openmm/common/ComputeProgram.h" #include "openmm/common/ComputeProgram.h"
#include "openmm/common/ComputeVectorTypes.h" #include "openmm/common/ComputeVectorTypes.h"
#include "openmm/common/FFT3D.h"
#include "openmm/common/IntegrationUtilities.h" #include "openmm/common/IntegrationUtilities.h"
#include "openmm/common/NonbondedUtilities.h" #include "openmm/common/NonbondedUtilities.h"
#include "openmm/Vec3.h" #include "openmm/Vec3.h"
...@@ -474,6 +475,16 @@ public: ...@@ -474,6 +475,16 @@ public:
* when it is no longer needed. * when it is no longer needed.
*/ */
virtual NonbondedUtilities* createNonbondedUtilities() = 0; virtual NonbondedUtilities* createNonbondedUtilities() = 0;
/**
 * Create an object for performing 3D FFTs. The caller is responsible for deleting
 * the object when it is no longer needed.
 *
 * @param xsize the first dimension of the data sets on which FFTs will be performed
 * @param ysize the second dimension of the data sets on which FFTs will be performed
 * @param zsize the third dimension of the data sets on which FFTs will be performed
 * @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex.
 * @return a newly created FFT3D. Ownership passes to the caller, who must delete it.
 */
virtual FFT3D* createFFT(int xsize, int ysize, int zsize, bool realToComplex=false) = 0;
/** /**
* Get the smallest legal size for a dimension of the grid. * Get the smallest legal size for a dimension of the grid.
*/ */
......
#ifndef __OPENMM_FFT3D_H__
#define __OPENMM_FFT3D_H__
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "openmm/common/ArrayInterface.h"
namespace OpenMM {
/**
* This class defines a uniform API for three-dimensional Fast Fourier Transforms.
* Each platform provides its own implementation. Instances can be created by
* calling createFFT() on a ComputeContext.
*
* FFTs tend to be most efficient when the size of each dimension is a product of
* small prime factors. You can call findLegalFFTDimension() on the ComputeContext
* to determine the smallest size that satisfies this requirement and is greater
* than or equal to a specified minimum size.
*
* Note that this class performs an unnormalized transform. That means that if you perform
* a forward transform followed immediately by an inverse transform, the effect is to
* multiply every value of the original data set by the total number of data points.
*/
class OPENMM_EXPORT_COMMON FFT3D {
public:
    /**
     * The destructor is virtual so that an implementation can be deleted
     * through a pointer to FFT3D.
     */
    virtual ~FFT3D() = default;
    /**
     * Compute a Fourier transform of the data in one array and write the result
     * into a second array.
     *
     * Restrictions:
     *  - The transform cannot be done in-place: the input and output arrays must be different.
     *  - The input array is used as workspace, so its contents are destroyed.
     *  - Because of that, both arrays must be large enough to hold complex values,
     *    even when performing a real-to-complex transform.
     *
     * When performing a real-to-complex transform, the output data is of size
     * xsize*ysize*(zsize/2+1) and contains only the non-redundant elements.
     *
     * @param in       the data to transform, ordered such that in[x*ysize*zsize + y*zsize + z]
     *                 contains element (x, y, z)
     * @param out      on exit, this contains the transformed data
     * @param forward  true to perform a forward transform, false to perform an inverse transform
     */
    virtual void execFFT(ArrayInterface& in, ArrayInterface& out, bool forward=true) = 0;
};
} // namespace OpenMM
#endif // __OPENMM_FFT3D_H__
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2023 Stanford University and the Authors. * * Portions copyright (c) 2009-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -42,6 +42,7 @@ ...@@ -42,6 +42,7 @@
#include "CudaArray.h" #include "CudaArray.h"
#include "CudaBondedUtilities.h" #include "CudaBondedUtilities.h"
#include "CudaExpressionUtilities.h" #include "CudaExpressionUtilities.h"
#include "CudaFFT3D.h"
#include "CudaIntegrationUtilities.h" #include "CudaIntegrationUtilities.h"
#include "CudaNonbondedUtilities.h" #include "CudaNonbondedUtilities.h"
#include "CudaPlatform.h" #include "CudaPlatform.h"
...@@ -508,6 +509,16 @@ public: ...@@ -508,6 +509,16 @@ public:
CudaNonbondedUtilities* createNonbondedUtilities() { CudaNonbondedUtilities* createNonbondedUtilities() {
return new CudaNonbondedUtilities(*this); return new CudaNonbondedUtilities(*this);
} }
/**
 * Create an object for performing 3D FFTs. The caller is responsible for deleting
 * the object when it is no longer needed.
 *
 * @param xsize the first dimension of the data sets on which FFTs will be performed
 * @param ysize the second dimension of the data sets on which FFTs will be performed
 * @param zsize the third dimension of the data sets on which FFTs will be performed
 * @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex.
 * @return a newly created CudaFFT3D. Ownership passes to the caller, who must delete it.
 */
CudaFFT3D* createFFT(int xsize, int ysize, int zsize, bool realToComplex=false);
/** /**
* This should be called by the Integrator from its own initialize() method. * This should be called by the Integrator from its own initialize() method.
* It ensures all contexts are fully initialized. * It ensures all contexts are fully initialized.
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. * * Portions copyright (c) 2009-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -27,35 +27,33 @@ ...@@ -27,35 +27,33 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>. * * along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
#include "CudaArray.h" #include "openmm/common/FFT3D.h"
#include "openmm/common/ArrayInterface.h"
#include <cuda.h>
#include <cufft.h>
namespace OpenMM { namespace OpenMM {
class CudaContext;
/** /**
* This class performs three dimensional Fast Fourier Transforms. It is based on the * This class performs three dimensional Fast Fourier Transforms. It is implemented
* mixed radix algorithm described in * using cuFFT.
* <p> *
* Takahashi, D. and Kanada, Y., "High-Performance Radix-2, 3 and 5 Parallel 1-D Complex * FFTs tend to be most efficient when the size of each dimension is a product of
* FFT Algorithms for Distributed-Memory Parallel Computers." Journal of Supercomputing, * small prime factors. You can call findLegalFFTDimension() on the ComputeContext
* 15, 207–228 (2000). * to determine the smallest size that satisfies this requirement and is greater
* <p> * than or equal to a specified minimum size.
* This class places certain restrictions on the allowed dimensions of the grid. First, *
* the size of each dimension may have no prime factors other than 2, 3, 5, and 7. You
* can call findLegalDimension() to determine the smallest size that satisfies this
* requirement and is greater than or equal to a specified minimum size. Second, the size
* of each dimension must be small enough to compute each 1D transform entirely in local
* memory with one work unit per data point. This will vary between platforms, but is
* typically at least 512.
* <p>
* Note that this class performs an unnormalized transform. That means that if you perform * Note that this class performs an unnormalized transform. That means that if you perform
* a forward transform followed immediately by an inverse transform, the effect is to * a forward transform followed immediately by an inverse transform, the effect is to
* multiply every value of the original data set by the total number of data points. * multiply every value of the original data set by the total number of data points.
*/ */
class OPENMM_EXPORT_COMMON CudaFFT3D { class OPENMM_EXPORT_COMMON CudaFFT3D : public FFT3D {
public: public:
/** /**
* Create an CudaFFT3D object for performing transforms of a particular size. * Create a CudaFFT3D object for performing transforms of a particular size.
* *
* @param context the context in which to perform calculations * @param context the context in which to perform calculations
* @param xsize the first dimension of the data sets on which FFTs will be performed * @param xsize the first dimension of the data sets on which FFTs will be performed
...@@ -64,12 +62,17 @@ public: ...@@ -64,12 +62,17 @@ public:
* @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex. * @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex.
*/ */
CudaFFT3D(CudaContext& context, int xsize, int ysize, int zsize, bool realToComplex=false); CudaFFT3D(CudaContext& context, int xsize, int ysize, int zsize, bool realToComplex=false);
/**
 * Destroy the cuFFT plans owned by this object.
 */
~CudaFFT3D();
/**
 * Set the stream to perform the FFT on. The stream is applied to both the
 * forward and the backward cuFFT plans.
 *
 * @param stream the CUDA stream to associate with the plans
 */
void setStream(CUstream stream);
/** /**
* Perform a Fourier transform. The transform cannot be done in-place: the input and output * Perform a Fourier transform. The transform cannot be done in-place: the input and output
* arrays must be different. Also, the input array is used as workspace, so its contents * arrays must be different. Also, the input array is used as workspace, so its contents
* are destroyed. This also means that both arrays must be large enough to hold complex values, * are destroyed. This also means that both arrays must be large enough to hold complex values,
* even when performing a real-to-complex transform. * even when performing a real-to-complex transform.
* <p> *
* When performing a real-to-complex transform, the output data is of size xsize*ysize*(zsize/2+1) * When performing a real-to-complex transform, the output data is of size xsize*ysize*(zsize/2+1)
* and contains only the non-redundant elements. * and contains only the non-redundant elements.
* *
...@@ -77,23 +80,12 @@ public: ...@@ -77,23 +80,12 @@ public:
* @param out on exit, this contains the transformed data * @param out on exit, this contains the transformed data
* @param forward true to perform a forward transform, false to perform an inverse transform * @param forward true to perform a forward transform, false to perform an inverse transform
*/ */
void execFFT(CudaArray& in, CudaArray& out, bool forward = true); void execFFT(ArrayInterface& in, ArrayInterface& out, bool forward = true);
/**
* Get the smallest legal size for a dimension of the grid (that is, a size with no prime
* factors other than 2, 3, 5, and 7).
*
* @param minimum the minimum size the return value must be greater than or equal to
*/
static int findLegalDimension(int minimum);
private: private:
CUfunction createKernel(int xsize, int ysize, int zsize, int& threads, int axis, bool forward, bool inputIsReal);
int xsize, ysize, zsize;
int xthreads, ythreads, zthreads;
bool packRealAsComplex;
CudaContext& context; CudaContext& context;
CUfunction xkernel, ykernel, zkernel; cufftHandle fftForward;
CUfunction invxkernel, invykernel, invzkernel; cufftHandle fftBackward;
CUfunction packForwardKernel, unpackForwardKernel, packBackwardKernel, unpackBackwardKernel; bool realToComplex, hasInitialized;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -35,7 +35,6 @@ ...@@ -35,7 +35,6 @@
#include "openmm/kernels.h" #include "openmm/kernels.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/common/CommonKernels.h" #include "openmm/common/CommonKernels.h"
#include <cufft.h>
namespace OpenMM { namespace OpenMM {
...@@ -189,11 +188,7 @@ private: ...@@ -189,11 +188,7 @@ private:
CUstream pmeStream; CUstream pmeStream;
CUevent pmeSyncEvent, paramsSyncEvent; CUevent pmeSyncEvent, paramsSyncEvent;
CudaFFT3D* fft; CudaFFT3D* fft;
cufftHandle fftForward;
cufftHandle fftBackward;
CudaFFT3D* dispersionFft; CudaFFT3D* dispersionFft;
cufftHandle dispersionFftForward;
cufftHandle dispersionFftBackward;
CUfunction computeParamsKernel, computeExclusionParamsKernel, computePlasmaCorrectionKernel; CUfunction computeParamsKernel, computeExclusionParamsKernel, computePlasmaCorrectionKernel;
CUfunction ewaldSumsKernel; CUfunction ewaldSumsKernel;
CUfunction ewaldForcesKernel; CUfunction ewaldForcesKernel;
...@@ -217,7 +212,7 @@ private: ...@@ -217,7 +212,7 @@ private:
int interpolateForceThreads; int interpolateForceThreads;
int gridSizeX, gridSizeY, gridSizeZ; int gridSizeX, gridSizeY, gridSizeZ;
int dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ; int dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ;
bool hasCoulomb, hasLJ, useFixedPointChargeSpreading, usePmeStream, useCudaFFT, doLJPME, usePosqCharges, recomputeParams, hasOffsets; bool hasCoulomb, hasLJ, useFixedPointChargeSpreading, usePmeStream, doLJPME, usePosqCharges, recomputeParams, hasOffsets;
NonbondedMethod nonbondedMethod; NonbondedMethod nonbondedMethod;
static const int PmeOrder = 5; static const int PmeOrder = 5;
}; };
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2024 Stanford University and the Authors. * * Portions copyright (c) 2009-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -437,6 +437,10 @@ void CudaContext::initializeContexts() { ...@@ -437,6 +437,10 @@ void CudaContext::initializeContexts() {
getPlatformData().initializeContexts(system); getPlatformData().initializeContexts(system);
} }
CudaFFT3D* CudaContext::createFFT(int xsize, int ysize, int zsize, bool realToComplex) {
    // The new FFT object is tied to this context. It is allocated on the heap
    // and handed to the caller, who is responsible for deleting it.
    CudaFFT3D* fft = new CudaFFT3D(*this, xsize, ysize, zsize, realToComplex);
    return fft;
}
void CudaContext::setAsCurrent() { void CudaContext::setAsCurrent() {
if (contextIsValid) if (contextIsValid)
cuCtxSetCurrent(context); cuCtxSetCurrent(context);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. * * Portions copyright (c) 2009-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -26,326 +26,81 @@ ...@@ -26,326 +26,81 @@
#include "CudaFFT3D.h" #include "CudaFFT3D.h"
#include "CudaContext.h" #include "CudaContext.h"
#include "CudaKernelSources.h"
#include "SimTKOpenMMRealType.h"
#include <map>
#include <sstream>
#include <string>
using namespace OpenMM; using namespace OpenMM;
using namespace std;
CudaFFT3D::CudaFFT3D(CudaContext& context, int xsize, int ysize, int zsize, bool realToComplex) : CudaFFT3D::CudaFFT3D(CudaContext& context, int xsize, int ysize, int zsize, bool realToComplex) :
context(context), xsize(xsize), ysize(ysize), zsize(zsize) { context(context), realToComplex(realToComplex), hasInitialized(false) {
packRealAsComplex = false; cufftType type1, type2;
int packedXSize = xsize;
int packedYSize = ysize;
int packedZSize = zsize;
if (realToComplex) { if (realToComplex) {
// If any axis size is even, we can pack the real values into a complex grid that is only half as large. if (context.getUseDoublePrecision()) {
// Look for an appropriate axis. type1 = CUFFT_D2Z;
type2 = CUFFT_Z2D;
packRealAsComplex = true;
int packedAxis, bufferSize;
if (xsize%2 == 0) {
packedAxis = 0;
packedXSize /= 2;
bufferSize = packedXSize;
} }
else if (ysize%2 == 0) { else {
packedAxis = 1; type1 = CUFFT_R2C;
packedYSize /= 2; type2 = CUFFT_C2R;
bufferSize = packedYSize;
} }
else if (zsize%2 == 0) {
packedAxis = 2;
packedZSize /= 2;
bufferSize = packedZSize;
}
else
packRealAsComplex = false;
if (packRealAsComplex) {
// Build the kernels for packing and unpacking the data.
map<string, string> defines;
defines["XSIZE"] = context.intToString(xsize);
defines["YSIZE"] = context.intToString(ysize);
defines["ZSIZE"] = context.intToString(zsize);
defines["PACKED_AXIS"] = context.intToString(packedAxis);
defines["PACKED_XSIZE"] = context.intToString(packedXSize);
defines["PACKED_YSIZE"] = context.intToString(packedYSize);
defines["PACKED_ZSIZE"] = context.intToString(packedZSize);
defines["M_PI"] = context.doubleToString(M_PI);
CUmodule module = context.createModule(CudaKernelSources::vectorOps+CudaKernelSources::fftR2C, defines);
packForwardKernel = context.getKernel(module, "packForwardData");
unpackForwardKernel = context.getKernel(module, "unpackForwardData");
packBackwardKernel = context.getKernel(module, "packBackwardData");
unpackBackwardKernel = context.getKernel(module, "unpackBackwardData");
}
}
bool inputIsReal = (realToComplex && !packRealAsComplex);
zkernel = createKernel(packedXSize, packedYSize, packedZSize, zthreads, 0, true, inputIsReal);
xkernel = createKernel(packedYSize, packedZSize, packedXSize, xthreads, 1, true, inputIsReal);
ykernel = createKernel(packedZSize, packedXSize, packedYSize, ythreads, 2, true, inputIsReal);
invzkernel = createKernel(packedXSize, packedYSize, packedZSize, zthreads, 0, false, inputIsReal);
invxkernel = createKernel(packedYSize, packedZSize, packedXSize, xthreads, 1, false, inputIsReal);
invykernel = createKernel(packedZSize, packedXSize, packedYSize, ythreads, 2, false, inputIsReal);
}
void CudaFFT3D::execFFT(CudaArray& in, CudaArray& out, bool forward) {
CUfunction kernel1 = (forward ? zkernel : invzkernel);
CUfunction kernel2 = (forward ? xkernel : invxkernel);
CUfunction kernel3 = (forward ? ykernel : invykernel);
void* args1[] = {&in.getDevicePointer(), &out.getDevicePointer()};
void* args2[] = {&out.getDevicePointer(), &in.getDevicePointer()};
if (packRealAsComplex) {
CUfunction packKernel = (forward ? packForwardKernel : packBackwardKernel);
CUfunction unpackKernel = (forward ? unpackForwardKernel : unpackBackwardKernel);
int gridSize = xsize*ysize*zsize/2;
// Pack the data into a half sized grid.
context.executeKernel(packKernel, args1, gridSize, 128);
// Perform the FFT.
context.executeKernel(kernel1, args2, gridSize, zthreads);
context.executeKernel(kernel2, args1, gridSize, xthreads);
context.executeKernel(kernel3, args2, gridSize, ythreads);
// Unpack the data.
context.executeKernel(unpackKernel, args1, gridSize, 128);
} }
else { else {
context.executeKernel(kernel1, args1, xsize*ysize*zsize, zthreads); if (context.getUseDoublePrecision())
context.executeKernel(kernel2, args2, xsize*ysize*zsize, xthreads); type1 = type2 = CUFFT_Z2Z;
context.executeKernel(kernel3, args1, xsize*ysize*zsize, ythreads); else
type1 = type2 = CUFFT_C2C;
} }
cufftResult result = cufftPlan3d(&fftForward, xsize, ysize, zsize, type1);
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+context.intToString(result));
result = cufftPlan3d(&fftBackward, xsize, ysize, zsize, type2);
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+context.intToString(result));
hasInitialized = true;
} }
int CudaFFT3D::findLegalDimension(int minimum) { CudaFFT3D::~CudaFFT3D() {
if (minimum < 1) if (hasInitialized) {
return 1; cufftDestroy(fftForward);
while (true) { cufftDestroy(fftBackward);
// Attempt to factor the current value.
int unfactored = minimum;
for (int factor = 2; factor < 8; factor++) {
while (unfactored > 1 && unfactored%factor == 0)
unfactored /= factor;
}
if (unfactored == 1)
return minimum;
minimum++;
} }
} }
void CudaFFT3D::setStream(CUstream stream) {
static int getSmallestRadix(int size) { cufftSetStream(fftForward, stream);
int minRadix = 1; cufftSetStream(fftBackward, stream);
int unfactored = size;
while (unfactored%7 == 0) {
minRadix = 7;
unfactored /= 7;
}
while (unfactored%5 == 0) {
minRadix = 5;
unfactored /= 5;
}
while (unfactored%4 == 0) {
minRadix = 4;
unfactored /= 4;
}
while (unfactored%3 == 0) {
minRadix = 3;
unfactored /= 3;
}
while (unfactored%2 == 0) {
minRadix = 2;
unfactored /= 2;
}
return minRadix;
} }
CUfunction CudaFFT3D::createKernel(int xsize, int ysize, int zsize, int& threads, int axis, bool forward, bool inputIsReal) { void CudaFFT3D::execFFT(ArrayInterface& in, ArrayInterface& out, bool forward) {
int maxThreads = (context.getUseDoublePrecision() ? 128 : 256); CUdeviceptr in2 = context.unwrap(in).getDevicePointer();
// while (maxThreads > 128 && maxThreads-64 >= zsize) CUdeviceptr out2 = context.unwrap(out).getDevicePointer();
// maxThreads -= 64; cufftResult result;
int threadsPerBlock = zsize/getSmallestRadix(zsize); if (forward) {
stringstream source; if (realToComplex) {
int blocksPerGroup = max(1, maxThreads/threadsPerBlock); if (context.getUseDoublePrecision())
int stage = 0; result = cufftExecD2Z(fftForward, (double*) in2, (double2*) out2);
int L = zsize; else
int m = 1; result = cufftExecR2C(fftForward, (float*) in2, (float2*) out2);
// Factor zsize, generating an appropriate block of code for each factor.
while (L > 1) {
int input = stage%2;
int output = 1-input;
int radix;
if (L%7 == 0)
radix = 7;
else if (L%5 == 0)
radix = 5;
else if (L%4 == 0)
radix = 4;
else if (L%3 == 0)
radix = 3;
else if (L%2 == 0)
radix = 2;
else
throw OpenMMException("Illegal size for FFT: "+context.intToString(zsize));
source<<"{\n";
L = L/radix;
source<<"// Pass "<<(stage+1)<<" (radix "<<radix<<")\n";
if (L*m < threadsPerBlock)
source<<"if (threadIdx.x < "<<(blocksPerGroup*L*m)<<") {\n";
else
source<<"{\n";
source<<"int block = threadIdx.x/"<<(L*m)<<";\n";
source<<"int i = threadIdx.x-block*"<<(L*m)<<";\n";
source<<"int base = i+block*"<<zsize<<";\n";
source<<"int j = i/"<<m<<";\n";
if (radix == 7) {
source<<"real2 c0 = data"<<input<<"[base];\n";
source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n";
source<<"real2 c3 = data"<<input<<"[base+"<<(3*L*m)<<"];\n";
source<<"real2 c4 = data"<<input<<"[base+"<<(4*L*m)<<"];\n";
source<<"real2 c5 = data"<<input<<"[base+"<<(5*L*m)<<"];\n";
source<<"real2 c6 = data"<<input<<"[base+"<<(6*L*m)<<"];\n";
source<<"real2 d0 = c1+c6;\n";
source<<"real2 d1 = c1-c6;\n";
source<<"real2 d2 = c2+c5;\n";
source<<"real2 d3 = c2-c5;\n";
source<<"real2 d4 = c4+c3;\n";
source<<"real2 d5 = c4-c3;\n";
source<<"real2 d6 = d2+d0;\n";
source<<"real2 d7 = d5+d3;\n";
source<<"real2 b0 = c0+d6+d4;\n";
source<<"real2 b1 = "<<context.doubleToString((cos(2*M_PI/7)+cos(4*M_PI/7)+cos(6*M_PI/7))/3-1)<<"*(d6+d4);\n";
source<<"real2 b2 = "<<context.doubleToString((2*cos(2*M_PI/7)-cos(4*M_PI/7)-cos(6*M_PI/7))/3)<<"*(d0-d4);\n";
source<<"real2 b3 = "<<context.doubleToString((cos(2*M_PI/7)-2*cos(4*M_PI/7)+cos(6*M_PI/7))/3)<<"*(d4-d2);\n";
source<<"real2 b4 = "<<context.doubleToString((cos(2*M_PI/7)+cos(4*M_PI/7)-2*cos(6*M_PI/7))/3)<<"*(d2-d0);\n";
source<<"real2 b5 = -(SIGN)*"<<context.doubleToString((sin(2*M_PI/7)+sin(4*M_PI/7)-sin(6*M_PI/7))/3)<<"*(d7+d1);\n";
source<<"real2 b6 = -(SIGN)*"<<context.doubleToString((2*sin(2*M_PI/7)-sin(4*M_PI/7)+sin(6*M_PI/7))/3)<<"*(d1-d5);\n";
source<<"real2 b7 = -(SIGN)*"<<context.doubleToString((sin(2*M_PI/7)-2*sin(4*M_PI/7)-sin(6*M_PI/7))/3)<<"*(d5-d3);\n";
source<<"real2 b8 = -(SIGN)*"<<context.doubleToString((sin(2*M_PI/7)+sin(4*M_PI/7)+2*sin(6*M_PI/7))/3)<<"*(d3-d1);\n";
source<<"real2 t0 = b0+b1;\n";
source<<"real2 t1 = b2+b3;\n";
source<<"real2 t2 = b4-b3;\n";
source<<"real2 t3 = -b2-b4;\n";
source<<"real2 t4 = b6+b7;\n";
source<<"real2 t5 = b8-b7;\n";
source<<"real2 t6 = -b8-b6;\n";
source<<"real2 t7 = t0+t1;\n";
source<<"real2 t8 = t0+t2;\n";
source<<"real2 t9 = t0+t3;\n";
source<<"real2 t10 = make_real2(t4.y+b5.y, -(t4.x+b5.x));\n";
source<<"real2 t11 = make_real2(t5.y+b5.y, -(t5.x+b5.x));\n";
source<<"real2 t12 = make_real2(t6.y+b5.y, -(t6.x+b5.x));\n";
source<<"data"<<output<<"[base+6*j*"<<m<<"] = b0;\n";
source<<"data"<<output<<"[base+(6*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(7*L)<<"], t7-t10);\n";
source<<"data"<<output<<"[base+(6*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(7*L)<<"], t9-t12);\n";
source<<"data"<<output<<"[base+(6*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(7*L)<<"], t8+t11);\n";
source<<"data"<<output<<"[base+(6*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*zsize)<<"/"<<(7*L)<<"], t8-t11);\n";
source<<"data"<<output<<"[base+(6*j+5)*"<<m<<"] = multiplyComplex(w[j*"<<(5*zsize)<<"/"<<(7*L)<<"], t9+t12);\n";
source<<"data"<<output<<"[base+(6*j+6)*"<<m<<"] = multiplyComplex(w[j*"<<(6*zsize)<<"/"<<(7*L)<<"], t7+t10);\n";
}
else if (radix == 5) {
source<<"real2 c0 = data"<<input<<"[base];\n";
source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n";
source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n";
source<<"real2 c3 = data"<<input<<"[base+"<<(3*L*m)<<"];\n";
source<<"real2 c4 = data"<<input<<"[base+"<<(4*L*m)<<"];\n";
source<<"real2 d0 = c1+c4;\n";
source<<"real2 d1 = c2+c3;\n";
source<<"real2 d2 = "<<context.doubleToString(sin(0.4*M_PI))<<"*(c1-c4);\n";
source<<"real2 d3 = "<<context.doubleToString(sin(0.4*M_PI))<<"*(c2-c3);\n";
source<<"real2 d4 = d0+d1;\n";
source<<"real2 d5 = "<<context.doubleToString(0.25*sqrt(5.0))<<"*(d0-d1);\n";
source<<"real2 d6 = c0-0.25f*d4;\n";
source<<"real2 d7 = d6+d5;\n";
source<<"real2 d8 = d6-d5;\n";
string coeff = context.doubleToString(sin(0.2*M_PI)/sin(0.4*M_PI));
source<<"real2 d9 = (SIGN)*make_real2(d2.y+"<<coeff<<"*d3.y, -d2.x-"<<coeff<<"*d3.x);\n";
source<<"real2 d10 = (SIGN)*make_real2("<<coeff<<"*d2.y-d3.y, d3.x-"<<coeff<<"*d2.x);\n";
source<<"data"<<output<<"[base+4*j*"<<m<<"] = c0+d4;\n";
source<<"data"<<output<<"[base+(4*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(5*L)<<"], d7+d9);\n";
source<<"data"<<output<<"[base+(4*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(5*L)<<"], d8+d10);\n";
source<<"data"<<output<<"[base+(4*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(5*L)<<"], d8-d10);\n";
source<<"data"<<output<<"[base+(4*j+4)*"<<m<<"] = multiplyComplex(w[j*"<<(4*zsize)<<"/"<<(5*L)<<"], d7-d9);\n";
} }
else if (radix == 4) { else {
source<<"real2 c0 = data"<<input<<"[base];\n"; if (context.getUseDoublePrecision())
source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n"; result = cufftExecZ2Z(fftForward, (double2*) in2, (double2*) out2, CUFFT_FORWARD);
source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n"; else
source<<"real2 c3 = data"<<input<<"[base+"<<(3*L*m)<<"];\n"; result = cufftExecC2C(fftForward, (float2*) in2, (float2*) out2, CUFFT_FORWARD);
source<<"real2 d0 = c0+c2;\n";
source<<"real2 d1 = c0-c2;\n";
source<<"real2 d2 = c1+c3;\n";
source<<"real2 d3 = (SIGN)*make_real2(c1.y-c3.y, c3.x-c1.x);\n";
source<<"data"<<output<<"[base+3*j*"<<m<<"] = d0+d2;\n";
source<<"data"<<output<<"[base+(3*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(4*L)<<"], d1+d3);\n";
source<<"data"<<output<<"[base+(3*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(4*L)<<"], d0-d2);\n";
source<<"data"<<output<<"[base+(3*j+3)*"<<m<<"] = multiplyComplex(w[j*"<<(3*zsize)<<"/"<<(4*L)<<"], d1-d3);\n";
} }
else if (radix == 3) { }
source<<"real2 c0 = data"<<input<<"[base];\n"; else {
source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n"; if (realToComplex) {
source<<"real2 c2 = data"<<input<<"[base+"<<(2*L*m)<<"];\n"; if (context.getUseDoublePrecision())
source<<"real2 d0 = c1+c2;\n"; result = cufftExecZ2D(fftBackward, (double2*) in2, (double*) out2);
source<<"real2 d1 = c0-0.5f*d0;\n"; else
source<<"real2 d2 = (SIGN)*"<<context.doubleToString(sin(M_PI/3.0))<<"*make_real2(c1.y-c2.y, c2.x-c1.x);\n"; result = cufftExecC2R(fftBackward, (float2*) in2, (float*) out2);
source<<"data"<<output<<"[base+2*j*"<<m<<"] = c0+d0;\n";
source<<"data"<<output<<"[base+(2*j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(3*L)<<"], d1+d2);\n";
source<<"data"<<output<<"[base+(2*j+2)*"<<m<<"] = multiplyComplex(w[j*"<<(2*zsize)<<"/"<<(3*L)<<"], d1-d2);\n";
} }
else if (radix == 2) { else {
source<<"real2 c0 = data"<<input<<"[base];\n"; if (context.getUseDoublePrecision())
source<<"real2 c1 = data"<<input<<"[base+"<<(L*m)<<"];\n"; result = cufftExecZ2Z(fftBackward, (double2*) in2, (double2*) out2, CUFFT_INVERSE);
source<<"data"<<output<<"[base+j*"<<m<<"] = c0+c1;\n"; else
source<<"data"<<output<<"[base+(j+1)*"<<m<<"] = multiplyComplex(w[j*"<<zsize<<"/"<<(2*L)<<"], c0-c1);\n"; result = cufftExecC2C(fftBackward, (float2*) in2, (float2*) out2, CUFFT_INVERSE);
} }
source<<"}\n";
m = m*radix;
source<<"__syncthreads();\n";
source<<"}\n";
++stage;
} }
if (result != CUFFT_SUCCESS)
// Create the kernel. throw OpenMMException("Error executing FFT: "+context.intToString(result));
bool outputIsReal = (inputIsReal && axis == 2 && !forward);
bool outputIsPacked = (inputIsReal && axis == 2 && forward);
string outputSuffix = (outputIsReal ? ".x" : "");
if (outputIsPacked)
source<<"if (index < XSIZE*YSIZE && x < XSIZE/2+1)\n";
else
source<<"if (index < XSIZE*YSIZE)\n";
source<<"for (int i = threadIdx.x-block*THREADS_PER_BLOCK; i < ZSIZE; i += THREADS_PER_BLOCK)\n";
if (outputIsPacked)
source<<"out[y*(ZSIZE*(XSIZE/2+1))+i*(XSIZE/2+1)+x] = data"<<(stage%2)<<"[i+block*ZSIZE]"<<outputSuffix<<";\n";
else
source<<"out[y*(ZSIZE*XSIZE)+i*XSIZE+x] = data"<<(stage%2)<<"[i+block*ZSIZE]"<<outputSuffix<<";\n";
map<string, string> replacements;
replacements["XSIZE"] = context.intToString(xsize);
replacements["YSIZE"] = context.intToString(ysize);
replacements["ZSIZE"] = context.intToString(zsize);
replacements["BLOCKS_PER_GROUP"] = context.intToString(blocksPerGroup);
replacements["THREADS_PER_BLOCK"] = context.intToString(threadsPerBlock);
replacements["M_PI"] = context.doubleToString(M_PI);
replacements["COMPUTE_FFT"] = source.str();
replacements["SIGN"] = (forward ? "1" : "-1");
replacements["INPUT_TYPE"] = (inputIsReal && axis == 0 && forward ? "real" : "real2");
replacements["OUTPUT_TYPE"] = (outputIsReal ? "real" : "real2");
replacements["INPUT_IS_REAL"] = (inputIsReal && axis == 0 && forward ? "1" : "0");
replacements["INPUT_IS_PACKED"] = (inputIsReal && axis == 0 && !forward ? "1" : "0");
replacements["OUTPUT_IS_PACKED"] = (outputIsPacked ? "1" : "0");
CUmodule module = context.createModule(CudaKernelSources::vectorOps+context.replaceStrings(CudaKernelSources::fft, replacements));
CUfunction kernel = context.getKernel(module, "execFFT");
threads = blocksPerGroup*threadsPerBlock;
return kernel;
} }
...@@ -256,21 +256,8 @@ CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() { ...@@ -256,21 +256,8 @@ CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() {
delete dispersionFft; delete dispersionFft;
if (pmeio != NULL) if (pmeio != NULL)
delete pmeio; delete pmeio;
if (hasInitializedFFT) { if (hasInitializedFFT && usePmeStream)
if (useCudaFFT) { cuStreamDestroy(pmeStream);
cufftDestroy(fftForward);
cufftDestroy(fftBackward);
if (doLJPME) {
cufftDestroy(dispersionFftForward);
cufftDestroy(dispersionFftBackward);
}
}
if (usePmeStream) {
cuStreamDestroy(pmeStream);
cuEventDestroy(pmeSyncEvent);
cuEventDestroy(paramsSyncEvent);
}
}
} }
void CudaCalcNonbondedForceKernel::initialize(const System& system, const NonbondedForce& force) { void CudaCalcNonbondedForceKernel::initialize(const System& system, const NonbondedForce& force) {
...@@ -421,15 +408,15 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon ...@@ -421,15 +408,15 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
// Compute the PME parameters. // Compute the PME parameters.
NonbondedForceImpl::calcPMEParameters(system, force, alpha, gridSizeX, gridSizeY, gridSizeZ, false); NonbondedForceImpl::calcPMEParameters(system, force, alpha, gridSizeX, gridSizeY, gridSizeZ, false);
gridSizeX = CudaFFT3D::findLegalDimension(gridSizeX); gridSizeX = cu.findLegalFFTDimension(gridSizeX);
gridSizeY = CudaFFT3D::findLegalDimension(gridSizeY); gridSizeY = cu.findLegalFFTDimension(gridSizeY);
gridSizeZ = CudaFFT3D::findLegalDimension(gridSizeZ); gridSizeZ = cu.findLegalFFTDimension(gridSizeZ);
if (doLJPME) { if (doLJPME) {
NonbondedForceImpl::calcPMEParameters(system, force, dispersionAlpha, dispersionGridSizeX, NonbondedForceImpl::calcPMEParameters(system, force, dispersionAlpha, dispersionGridSizeX,
dispersionGridSizeY, dispersionGridSizeZ, true); dispersionGridSizeY, dispersionGridSizeZ, true);
dispersionGridSizeX = CudaFFT3D::findLegalDimension(dispersionGridSizeX); dispersionGridSizeX = cu.findLegalFFTDimension(dispersionGridSizeX);
dispersionGridSizeY = CudaFFT3D::findLegalDimension(dispersionGridSizeY); dispersionGridSizeY = cu.findLegalFFTDimension(dispersionGridSizeY);
dispersionGridSizeZ = CudaFFT3D::findLegalDimension(dispersionGridSizeZ); dispersionGridSizeZ = cu.findLegalFFTDimension(dispersionGridSizeZ);
} }
defines["EWALD_ALPHA"] = cu.doubleToString(alpha); defines["EWALD_ALPHA"] = cu.doubleToString(alpha);
...@@ -550,45 +537,17 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon ...@@ -550,45 +537,17 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
pmeEnergyBuffer.initialize(cu, cu.getNumThreadBlocks()*CudaContext::ThreadBlockSize, energyElementSize, "pmeEnergyBuffer"); pmeEnergyBuffer.initialize(cu, cu.getNumThreadBlocks()*CudaContext::ThreadBlockSize, energyElementSize, "pmeEnergyBuffer");
cu.clearBuffer(pmeEnergyBuffer); cu.clearBuffer(pmeEnergyBuffer);
sort = new CudaSort(cu, new SortTrait(), cu.getNumAtoms()); sort = new CudaSort(cu, new SortTrait(), cu.getNumAtoms());
int cufftVersion; fft = cu.createFFT(gridSizeX, gridSizeY, gridSizeZ, true);
cufftGetVersion(&cufftVersion); if (doLJPME)
useCudaFFT = (cufftVersion >= 7050); // There was a critical bug in version 7.0 dispersionFft = cu.createFFT(dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true);
if (useCudaFFT) {
cufftResult result = cufftPlan3d(&fftForward, gridSizeX, gridSizeY, gridSizeZ, cu.getUseDoublePrecision() ? CUFFT_D2Z : CUFFT_R2C);
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+cu.intToString(result));
result = cufftPlan3d(&fftBackward, gridSizeX, gridSizeY, gridSizeZ, cu.getUseDoublePrecision() ? CUFFT_Z2D : CUFFT_C2R);
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+cu.intToString(result));
if (doLJPME) {
result = cufftPlan3d(&dispersionFftForward, dispersionGridSizeX, dispersionGridSizeY,
dispersionGridSizeZ, cu.getUseDoublePrecision() ? CUFFT_D2Z : CUFFT_R2C);
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error initializing disperison FFT: "+cu.intToString(result));
result = cufftPlan3d(&dispersionFftBackward, dispersionGridSizeX, dispersionGridSizeY,
dispersionGridSizeZ, cu.getUseDoublePrecision() ? CUFFT_Z2D : CUFFT_C2R);
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error initializing disperison FFT: "+cu.intToString(result));
}
}
else {
fft = new CudaFFT3D(cu, gridSizeX, gridSizeY, gridSizeZ, true);
if (doLJPME)
dispersionFft = new CudaFFT3D(cu, dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true);
}
// Prepare for doing PME on its own stream. // Prepare for doing PME on its own stream.
if (usePmeStream) { if (usePmeStream) {
cuStreamCreate(&pmeStream, CU_STREAM_NON_BLOCKING); cuStreamCreate(&pmeStream, CU_STREAM_NON_BLOCKING);
if (useCudaFFT) { fft->setStream(pmeStream);
cufftSetStream(fftForward, pmeStream); if (doLJPME)
cufftSetStream(fftBackward, pmeStream); dispersionFft->setStream(pmeStream);
if (doLJPME) {
cufftSetStream(dispersionFftForward, pmeStream);
cufftSetStream(dispersionFftBackward, pmeStream);
}
}
CHECK_RESULT(cuEventCreate(&pmeSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce"); CHECK_RESULT(cuEventCreate(&pmeSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce");
CHECK_RESULT(cuEventCreate(&paramsSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce"); CHECK_RESULT(cuEventCreate(&paramsSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce");
int recipForceGroup = force.getReciprocalSpaceForceGroup(); int recipForceGroup = force.getReciprocalSpaceForceGroup();
...@@ -987,20 +946,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF ...@@ -987,20 +946,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
cu.executeKernel(pmeFinishSpreadChargeKernel, finishSpreadArgs, gridSizeX*gridSizeY*gridSizeZ, 256); cu.executeKernel(pmeFinishSpreadChargeKernel, finishSpreadArgs, gridSizeX*gridSizeY*gridSizeZ, 256);
} }
if (useCudaFFT) { fft->execFFT(pmeGrid1, pmeGrid2, true);
if (cu.getUseDoublePrecision()) {
cufftResult result = cufftExecD2Z(fftForward, (double*) pmeGrid1.getDevicePointer(), (double2*) pmeGrid2.getDevicePointer());
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
} else {
cufftResult result = cufftExecR2C(fftForward, (float*) pmeGrid1.getDevicePointer(), (float2*) pmeGrid2.getDevicePointer());
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
}
}
else {
fft->execFFT(pmeGrid1, pmeGrid2, true);
}
if (includeEnergy) { if (includeEnergy) {
void* computeEnergyArgs[] = {&pmeGrid2.getDevicePointer(), usePmeStream ? &pmeEnergyBuffer.getDevicePointer() : &cu.getEnergyBuffer().getDevicePointer(), void* computeEnergyArgs[] = {&pmeGrid2.getDevicePointer(), usePmeStream ? &pmeEnergyBuffer.getDevicePointer() : &cu.getEnergyBuffer().getDevicePointer(),
...@@ -1014,20 +960,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF ...@@ -1014,20 +960,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]}; recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]};
cu.executeKernel(pmeConvolutionKernel, convolutionArgs, gridSizeX*gridSizeY*gridSizeZ, 256); cu.executeKernel(pmeConvolutionKernel, convolutionArgs, gridSizeX*gridSizeY*gridSizeZ, 256);
if (useCudaFFT) { fft->execFFT(pmeGrid2, pmeGrid1, false);
if (cu.getUseDoublePrecision()) {
cufftResult result = cufftExecZ2D(fftBackward, (double2*) pmeGrid2.getDevicePointer(), (double*) pmeGrid1.getDevicePointer());
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
} else {
cufftResult result = cufftExecC2R(fftBackward, (float2*) pmeGrid2.getDevicePointer(), (float*) pmeGrid1.getDevicePointer());
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
}
}
else {
fft->execFFT(pmeGrid2, pmeGrid1, false);
}
void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &pmeGrid1.getDevicePointer(), cu.getPeriodicBoxSizePointer(), void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &pmeGrid1.getDevicePointer(), cu.getPeriodicBoxSizePointer(),
cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(), cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(),
...@@ -1059,20 +992,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF ...@@ -1059,20 +992,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
cu.executeKernel(pmeDispersionFinishSpreadChargeKernel, finishSpreadArgs, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256); cu.executeKernel(pmeDispersionFinishSpreadChargeKernel, finishSpreadArgs, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256);
} }
if (useCudaFFT) { dispersionFft->execFFT(pmeGrid1, pmeGrid2, true);
if (cu.getUseDoublePrecision()) {
cufftResult result = cufftExecD2Z(dispersionFftForward, (double*) pmeGrid1.getDevicePointer(), (double2*) pmeGrid2.getDevicePointer());
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
} else {
cufftResult result = cufftExecR2C(dispersionFftForward, (float*) pmeGrid1.getDevicePointer(), (float2*) pmeGrid2.getDevicePointer());
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
}
}
else {
dispersionFft->execFFT(pmeGrid1, pmeGrid2, true);
}
if (includeEnergy) { if (includeEnergy) {
void* computeEnergyArgs[] = {&pmeGrid2.getDevicePointer(), usePmeStream ? &pmeEnergyBuffer.getDevicePointer() : &cu.getEnergyBuffer().getDevicePointer(), void* computeEnergyArgs[] = {&pmeGrid2.getDevicePointer(), usePmeStream ? &pmeEnergyBuffer.getDevicePointer() : &cu.getEnergyBuffer().getDevicePointer(),
...@@ -1086,20 +1006,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF ...@@ -1086,20 +1006,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]}; recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]};
cu.executeKernel(pmeDispersionConvolutionKernel, convolutionArgs, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256); cu.executeKernel(pmeDispersionConvolutionKernel, convolutionArgs, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256);
if (useCudaFFT) { dispersionFft->execFFT(pmeGrid2, pmeGrid1, false);
if (cu.getUseDoublePrecision()) {
cufftResult result = cufftExecZ2D(dispersionFftBackward, (double2*) pmeGrid2.getDevicePointer(), (double*) pmeGrid1.getDevicePointer());
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
} else {
cufftResult result = cufftExecC2R(dispersionFftBackward, (float2*) pmeGrid2.getDevicePointer(), (float*) pmeGrid1.getDevicePointer());
if (result != CUFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
}
}
else {
dispersionFft->execFFT(pmeGrid2, pmeGrid1, false);
}
void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &pmeGrid1.getDevicePointer(), cu.getPeriodicBoxSizePointer(), void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &pmeGrid1.getDevicePointer(), cu.getPeriodicBoxSizePointer(),
cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(), cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(),
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2024 Stanford University and the Authors. * * Portions copyright (c) 2009-2025 Stanford University and the Authors. *
* Portions copyright (c) 2020-2023 Advanced Micro Devices, Inc. * * Portions copyright (c) 2020-2023 Advanced Micro Devices, Inc. *
* Authors: Peter Eastman, Nicholas Curtis * * Authors: Peter Eastman, Nicholas Curtis *
* Contributors: * * Contributors: *
...@@ -51,6 +51,7 @@ ...@@ -51,6 +51,7 @@
#include "HipArray.h" #include "HipArray.h"
#include "HipBondedUtilities.h" #include "HipBondedUtilities.h"
#include "HipExpressionUtilities.h" #include "HipExpressionUtilities.h"
#include "HipFFT3D.h"
#include "HipIntegrationUtilities.h" #include "HipIntegrationUtilities.h"
#include "HipNonbondedUtilities.h" #include "HipNonbondedUtilities.h"
#include "HipPlatform.h" #include "HipPlatform.h"
...@@ -182,9 +183,19 @@ public: ...@@ -182,9 +183,19 @@ public:
*/ */
ComputeEvent createEvent(); ComputeEvent createEvent();
/** /**
* Get the smallest legal size for a dimension of the grid supported by the FFT. * Create an object for performing 3D FFTs. The caller is responsible for deleting
* the object when it is no longer needed.
*
* @param xsize the first dimension of the data sets on which FFTs will be performed
* @param ysize the second dimension of the data sets on which FFTs will be performed
* @param zsize the third dimension of the data sets on which FFTs will be performed
* @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex.
*/
HipFFT3D* createFFT(int xsize, int ysize, int zsize, bool realToComplex=false);
/**
* Get the smallest legal size for a dimension of the grid.
*/ */
virtual int findLegalFFTDimension(int minimum); int findLegalFFTDimension(int minimum);
/** /**
* Compile source code to create a ComputeProgram. * Compile source code to create a ComputeProgram.
* *
......
...@@ -32,6 +32,8 @@ ...@@ -32,6 +32,8 @@
#define VKFFT_BACKEND 2 // HIP #define VKFFT_BACKEND 2 // HIP
#include "vkFFT.h" #include "vkFFT.h"
#include "openmm/common/FFT3D.h"
#include "openmm/common/ArrayInterface.h"
namespace OpenMM { namespace OpenMM {
...@@ -40,22 +42,22 @@ class HipContext; ...@@ -40,22 +42,22 @@ class HipContext;
/** /**
* This class performs three dimensional Fast Fourier Transforms using VkFFT by * This class performs three dimensional Fast Fourier Transforms using VkFFT by
* Dmitrii Tolmachev (https://github.com/DTolm/VkFFT). * Dmitrii Tolmachev (https://github.com/DTolm/VkFFT).
* <p> *
* Note that this class performs an unnormalized transform. That means that if you perform * Note that this class performs an unnormalized transform. That means that if you perform
* a forward transform followed immediately by an inverse transform, the effect is to * a forward transform followed immediately by an inverse transform, the effect is to
* multiply every value of the original data set by the total number of data points. * multiply every value of the original data set by the total number of data points.
*/ */
class OPENMM_EXPORT_COMMON HipFFT3D { class OPENMM_EXPORT_COMMON HipFFT3D : public FFT3D {
public: public:
/** /**
* Create an HipFFT3D object for performing transforms of a particular size. * Create an HipFFT3D object for performing transforms of a particular size.
* <p> *
* The transform cannot be done in-place: the input and output * The transform cannot be done in-place: the input and output
* arrays must be different. Also, the input array is used as workspace, so its contents * arrays must be different. Also, the input array is used as workspace, so its contents
* are destroyed. This also means that both arrays must be large enough to hold complex values, * are destroyed. This also means that both arrays must be large enough to hold complex values,
* even when performing a real-to-complex transform. * even when performing a real-to-complex transform.
* <p> *
* When performing a real-to-complex transform, the output data is of size xsize*ysize*(zsize/2+1) * When performing a real-to-complex transform, the output data is of size xsize*ysize*(zsize/2+1)
* and contains only the non-redundant elements. * and contains only the non-redundant elements.
* *
...@@ -64,18 +66,23 @@ public: ...@@ -64,18 +66,23 @@ public:
* @param ysize the second dimension of the data sets on which FFTs will be performed * @param ysize the second dimension of the data sets on which FFTs will be performed
* @param zsize the third dimension of the data sets on which FFTs will be performed * @param zsize the third dimension of the data sets on which FFTs will be performed
* @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex. * @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex.
* @param stream HIP stream
* @param in the data to transform, ordered such that in[x*ysize*zsize + y*zsize + z] contains element (x, y, z)
* @param out on exit, this contains the transformed data
*/ */
HipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex, hipStream_t stream, HipArray& in, HipArray& out); HipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex);
~HipFFT3D(); ~HipFFT3D();
/** /**
* Perform a Fourier transform. * Perform a Fourier transform. The transform cannot be done in-place: the input and output
* arrays must be different. Also, the input array is used as workspace, so its contents
* are destroyed. This also means that both arrays must be large enough to hold complex values,
* even when performing a real-to-complex transform.
*
* When performing a real-to-complex transform, the output data is of size xsize*ysize*(zsize/2+1)
* and contains only the non-redundant elements.
* *
* @param in the data to transform, ordered such that in[x*ysize*zsize + y*zsize + z] contains element (x, y, z)
* @param out on exit, this contains the transformed data
* @param forward true to perform a forward transform, false to perform an inverse transform * @param forward true to perform a forward transform, false to perform an inverse transform
*/ */
void execFFT(bool forward); void execFFT(ArrayInterface& in, ArrayInterface& out, bool forward=true);
/** /**
* Get the smallest legal size for a dimension of the grid (that is, a size with no prime * Get the smallest legal size for a dimension of the grid (that is, a size with no prime
* factors other than 2, 3, 5, 7, 11, 13). VkFFT supports arbitrary sizes but they may work * factors other than 2, 3, 5, 7, 11, 13). VkFFT supports arbitrary sizes but they may work
......
...@@ -696,6 +696,10 @@ ComputeEvent HipContext::createEvent() { ...@@ -696,6 +696,10 @@ ComputeEvent HipContext::createEvent() {
return shared_ptr<ComputeEventImpl>(new HipEvent(*this)); return shared_ptr<ComputeEventImpl>(new HipEvent(*this));
} }
HipFFT3D* HipContext::createFFT(int xsize, int ysize, int zsize, bool realToComplex) {
return new HipFFT3D(*this, xsize, ysize, zsize, realToComplex);
}
int HipContext::findLegalFFTDimension(int minimum) { int HipContext::findLegalFFTDimension(int minimum) {
return HipFFT3D::findLegalDimension(minimum); return HipFFT3D::findLegalDimension(minimum);
} }
......
...@@ -35,12 +35,8 @@ ...@@ -35,12 +35,8 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
HipFFT3D::HipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex, hipStream_t stream, HipArray& in, HipArray& out) : HipFFT3D::HipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex) : context(context) {
context(context), stream(stream) {
deviceIndex = context.getDeviceIndex(); deviceIndex = context.getDeviceIndex();
inputBuffer = in.getDevicePointer();
outputBuffer = out.getDevicePointer();
size_t valueSize = context.getUseDoublePrecision() ? sizeof(double) : sizeof(float); size_t valueSize = context.getUseDoublePrecision() ? sizeof(double) : sizeof(float);
inputBufferSize = zsize * ysize * xsize * valueSize; inputBufferSize = zsize * ysize * xsize * valueSize;
if (realToComplex) { if (realToComplex) {
...@@ -54,7 +50,7 @@ HipFFT3D::HipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool re ...@@ -54,7 +50,7 @@ HipFFT3D::HipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool re
configuration.performR2C = realToComplex; configuration.performR2C = realToComplex;
configuration.device = &deviceIndex; configuration.device = &deviceIndex;
configuration.num_streams = 1; configuration.num_streams = 1;
configuration.stream = &this->stream; configuration.stream = &stream;
configuration.doublePrecision = context.getUseDoublePrecision(); configuration.doublePrecision = context.getUseDoublePrecision();
configuration.FFTdim = 3; configuration.FFTdim = 3;
...@@ -133,7 +129,16 @@ HipFFT3D::~HipFFT3D() { ...@@ -133,7 +129,16 @@ HipFFT3D::~HipFFT3D() {
delete app; delete app;
} }
void HipFFT3D::execFFT(bool forward) { void HipFFT3D::execFFT(ArrayInterface& in, ArrayInterface& out, bool forward) {
if (forward) {
inputBuffer = context.unwrap(in).getDevicePointer();
outputBuffer = context.unwrap(out).getDevicePointer();
}
else {
inputBuffer = context.unwrap(out).getDevicePointer();
outputBuffer = context.unwrap(in).getDevicePointer();
}
stream = context.getCurrentStream();
VkFFTResult fftResult = VkFFTAppend(app, forward ? -1 : 1, NULL); VkFFTResult fftResult = VkFFTAppend(app, forward ? -1 : 1, NULL);
if (fftResult != VKFFT_SUCCESS) { if (fftResult != VKFFT_SUCCESS) {
throw OpenMMException("Error executing VkFFTAppend: "+context.intToString(fftResult)); throw OpenMMException("Error executing VkFFTAppend: "+context.intToString(fftResult));
......
...@@ -34,7 +34,6 @@ ...@@ -34,7 +34,6 @@
#include "CommonKernelSources.h" #include "CommonKernelSources.h"
#include "HipBondedUtilities.h" #include "HipBondedUtilities.h"
#include "HipExpressionUtilities.h" #include "HipExpressionUtilities.h"
#include "HipFFT3D.h"
#include "HipIntegrationUtilities.h" #include "HipIntegrationUtilities.h"
#include "HipNonbondedUtilities.h" #include "HipNonbondedUtilities.h"
#include "HipKernelSources.h" #include "HipKernelSources.h"
...@@ -554,9 +553,9 @@ void HipCalcNonbondedForceKernel::initialize(const System& system, const Nonbond ...@@ -554,9 +553,9 @@ void HipCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
} }
hipStream_t fftStream = usePmeStream ? pmeStream : cu.getCurrentStream(); hipStream_t fftStream = usePmeStream ? pmeStream : cu.getCurrentStream();
fft = new HipFFT3D(cu, gridSizeX, gridSizeY, gridSizeZ, true, fftStream, pmeGrid1, pmeGrid2); fft = cu.createFFT(gridSizeX, gridSizeY, gridSizeZ, true);
if (doLJPME) if (doLJPME)
dispersionFft = new HipFFT3D(cu, dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true, fftStream, pmeGrid1, pmeGrid2); dispersionFft = cu.createFFT(dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true);
hasInitializedFFT = true; hasInitializedFFT = true;
// Initialize the b-spline moduli. // Initialize the b-spline moduli.
...@@ -947,7 +946,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -947,7 +946,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
cu.executeKernelFlat(pmeFinishSpreadChargeKernel, finishSpreadArgs, gridSizeX*gridSizeY*gridSizeZ, 256); cu.executeKernelFlat(pmeFinishSpreadChargeKernel, finishSpreadArgs, gridSizeX*gridSizeY*gridSizeZ, 256);
} }
fft->execFFT(true); fft->execFFT(pmeGrid1, pmeGrid2, true);
if (includeEnergy) { if (includeEnergy) {
void* computeEnergyArgs[] = {&pmeGrid2.getDevicePointer(), usePmeStream ? &pmeEnergyBuffer.getDevicePointer() : &cu.getEnergyBuffer().getDevicePointer(), void* computeEnergyArgs[] = {&pmeGrid2.getDevicePointer(), usePmeStream ? &pmeEnergyBuffer.getDevicePointer() : &cu.getEnergyBuffer().getDevicePointer(),
...@@ -961,7 +960,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -961,7 +960,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]}; recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]};
cu.executeKernelFlat(pmeConvolutionKernel, convolutionArgs, gridSizeX*gridSizeY*gridSizeZ, 256); cu.executeKernelFlat(pmeConvolutionKernel, convolutionArgs, gridSizeX*gridSizeY*gridSizeZ, 256);
fft->execFFT(false); fft->execFFT(pmeGrid2, pmeGrid1, false);
void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &pmeGrid1.getDevicePointer(), cu.getPeriodicBoxSizePointer(), void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &pmeGrid1.getDevicePointer(), cu.getPeriodicBoxSizePointer(),
cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(), cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(),
...@@ -993,7 +992,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -993,7 +992,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
cu.executeKernelFlat(pmeDispersionFinishSpreadChargeKernel, finishSpreadArgs, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256); cu.executeKernelFlat(pmeDispersionFinishSpreadChargeKernel, finishSpreadArgs, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256);
} }
dispersionFft->execFFT(true); dispersionFft->execFFT(pmeGrid1, pmeGrid2, true);
if (includeEnergy) { if (includeEnergy) {
void* computeEnergyArgs[] = {&pmeGrid2.getDevicePointer(), usePmeStream ? &pmeEnergyBuffer.getDevicePointer() : &cu.getEnergyBuffer().getDevicePointer(), void* computeEnergyArgs[] = {&pmeGrid2.getDevicePointer(), usePmeStream ? &pmeEnergyBuffer.getDevicePointer() : &cu.getEnergyBuffer().getDevicePointer(),
...@@ -1007,7 +1006,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -1007,7 +1006,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]}; recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]};
cu.executeKernelFlat(pmeDispersionConvolutionKernel, convolutionArgs, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256); cu.executeKernelFlat(pmeDispersionConvolutionKernel, convolutionArgs, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256);
dispersionFft->execFFT(false); dispersionFft->execFFT(pmeGrid2, pmeGrid1, false);
void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &pmeGrid1.getDevicePointer(), cu.getPeriodicBoxSizePointer(), void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &pmeGrid1.getDevicePointer(), cu.getPeriodicBoxSizePointer(),
cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(), cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(),
......
...@@ -88,11 +88,11 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize, double e ...@@ -88,11 +88,11 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize, double e
HipArray grid1(context, original.size(), sizeof(Real2), "grid1"); HipArray grid1(context, original.size(), sizeof(Real2), "grid1");
HipArray grid2(context, original.size(), sizeof(Real2), "grid2"); HipArray grid2(context, original.size(), sizeof(Real2), "grid2");
grid1.upload(original); grid1.upload(original);
HipFFT3D fft(context, xsize, ysize, zsize, realToComplex, context.getCurrentStream(), grid1, grid2); HipFFT3D fft(context, xsize, ysize, zsize, realToComplex);
// Perform a forward FFT, then verify the result is correct. // Perform a forward FFT, then verify the result is correct.
fft.execFFT(true); fft.execFFT(grid1, grid2, true);
vector<Real2> result; vector<Real2> result;
grid2.download(result); grid2.download(result);
vector<size_t> shape = {(size_t) xsize, (size_t) ysize, (size_t) zsize}; vector<size_t> shape = {(size_t) xsize, (size_t) ysize, (size_t) zsize};
...@@ -113,7 +113,7 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize, double e ...@@ -113,7 +113,7 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize, double e
// Perform a backward transform and see if we get the original values. // Perform a backward transform and see if we get the original values.
fft.execFFT(false); fft.execFFT(grid2, grid1, false);
grid1.download(result); grid1.download(result);
double scale = 1.0/(xsize*ysize*zsize); double scale = 1.0/(xsize*ysize*zsize);
int valuesToCheck = (realToComplex ? original.size()/2 : original.size()); int valuesToCheck = (realToComplex ? original.size()/2 : original.size());
......
...@@ -53,6 +53,7 @@ ...@@ -53,6 +53,7 @@
#include "OpenCLArray.h" #include "OpenCLArray.h"
#include "OpenCLBondedUtilities.h" #include "OpenCLBondedUtilities.h"
#include "OpenCLExpressionUtilities.h" #include "OpenCLExpressionUtilities.h"
#include "OpenCLFFT3D.h"
#include "OpenCLIntegrationUtilities.h" #include "OpenCLIntegrationUtilities.h"
#include "OpenCLNonbondedUtilities.h" #include "OpenCLNonbondedUtilities.h"
#include "OpenCLPlatform.h" #include "OpenCLPlatform.h"
...@@ -632,6 +633,20 @@ public: ...@@ -632,6 +633,20 @@ public:
OpenCLNonbondedUtilities* createNonbondedUtilities() { OpenCLNonbondedUtilities* createNonbondedUtilities() {
return new OpenCLNonbondedUtilities(*this); return new OpenCLNonbondedUtilities(*this);
} }
/**
* Create an object for performing 3D FFTs. The caller is responsible for deleting
* the object when it is no longer needed.
*
* @param xsize the first dimension of the data sets on which FFTs will be performed
* @param ysize the second dimension of the data sets on which FFTs will be performed
* @param zsize the third dimension of the data sets on which FFTs will be performed
* @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex.
* @return a newly allocated OpenCLFFT3D constructed on this context; the caller owns it
*/
OpenCLFFT3D* createFFT(int xsize, int ysize, int zsize, bool realToComplex=false);
/**
* Get the smallest legal size for a dimension of the grid.
*
* @param minimum the minimum size the return value must be greater than or equal to
* @return the size reported by OpenCLFFT3D::findLegalDimension() for the given minimum
*/
int findLegalFFTDimension(int minimum);
/** /**
* This should be called by the Integrator from its own initialize() method. * This should be called by the Integrator from its own initialize() method.
* It ensures all contexts are fully initialized. * It ensures all contexts are fully initialized.
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2023 Stanford University and the Authors. * * Portions copyright (c) 2009-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -34,20 +34,23 @@ ...@@ -34,20 +34,23 @@
#pragma clang diagnostic ignored "-Wdeprecated-declarations" #pragma clang diagnostic ignored "-Wdeprecated-declarations"
#include "vkFFT.h" #include "vkFFT.h"
#endif #endif
#include "OpenCLArray.h" #include "openmm/common/FFT3D.h"
#include "openmm/common/ArrayInterface.h"
namespace OpenMM { namespace OpenMM {
class OpenCLContext;
#ifdef USE_VKFFT #ifdef USE_VKFFT
/** /**
* This class performs three dimensional Fast Fourier Transforms. It uses the * This class performs three dimensional Fast Fourier Transforms. It uses the
* VkFFT library (https://github.com/DTolm/VkFFT). * VkFFT library (https://github.com/DTolm/VkFFT).
* <p> *
* This class is most efficient when the size of each dimension is a product of * This class is most efficient when the size of each dimension is a product of
* small prime factors: 2, 3, 5, 7, 11, and 13. You can call findLegalDimension() * small prime factors: 2, 3, 5, 7, 11, and 13. You can call findLegalDimension()
* to determine the smallest size that satisfies this requirement and is greater * to determine the smallest size that satisfies this requirement and is greater
* than or equal to a specified minimum size. * than or equal to a specified minimum size.
* <p> *
* Note that this class performs an unnormalized transform. That means that if you perform * Note that this class performs an unnormalized transform. That means that if you perform
* a forward transform followed immediately by an inverse transform, the effect is to * a forward transform followed immediately by an inverse transform, the effect is to
* multiply every value of the original data set by the total number of data points. * multiply every value of the original data set by the total number of data points.
...@@ -56,11 +59,11 @@ namespace OpenMM { ...@@ -56,11 +59,11 @@ namespace OpenMM {
/** /**
* This class performs three dimensional Fast Fourier Transforms. It is based on the * This class performs three dimensional Fast Fourier Transforms. It is based on the
* mixed radix algorithm described in * mixed radix algorithm described in
* <p> *
* Takahashi, D. and Kanada, Y., "High-Performance Radix-2, 3 and 5 Parallel 1-D Complex * Takahashi, D. and Kanada, Y., "High-Performance Radix-2, 3 and 5 Parallel 1-D Complex
* FFT Algorithms for Distributed-Memory Parallel Computers." Journal of Supercomputing, * FFT Algorithms for Distributed-Memory Parallel Computers." Journal of Supercomputing,
* 15, 207–228 (2000). * 15, 207–228 (2000).
* <p> *
* This class places certain restrictions on the allowed dimensions of the grid. First, * This class places certain restrictions on the allowed dimensions of the grid. First,
* the size of each dimension may have no prime factors other than 2, 3, 5, and 7. You * the size of each dimension may have no prime factors other than 2, 3, 5, and 7. You
* can call findLegalDimension() to determine the smallest size that satisfies this * can call findLegalDimension() to determine the smallest size that satisfies this
...@@ -68,14 +71,14 @@ namespace OpenMM { ...@@ -68,14 +71,14 @@ namespace OpenMM {
* of each dimension must be small enough to compute each 1D transform entirely in local * of each dimension must be small enough to compute each 1D transform entirely in local
* memory with one work unit per data point. This will vary between platforms, but is * memory with one work unit per data point. This will vary between platforms, but is
* typically at least 512. * typically at least 512.
* <p> *
* Note that this class performs an unnormalized transform. That means that if you perform * Note that this class performs an unnormalized transform. That means that if you perform
* a forward transform followed immediately by an inverse transform, the effect is to * a forward transform followed immediately by an inverse transform, the effect is to
* multiply every value of the original data set by the total number of data points. * multiply every value of the original data set by the total number of data points.
*/ */
#endif #endif
class OPENMM_EXPORT_COMMON OpenCLFFT3D { class OPENMM_EXPORT_COMMON OpenCLFFT3D : public FFT3D {
public: public:
/** /**
* Create an OpenCLFFT3D object for performing transforms of a particular size. * Create an OpenCLFFT3D object for performing transforms of a particular size.
...@@ -95,7 +98,7 @@ public: ...@@ -95,7 +98,7 @@ public:
* arrays must be different. Also, the input array is used as workspace, so its contents * arrays must be different. Also, the input array is used as workspace, so its contents
* are destroyed. This also means that both arrays must be large enough to hold complex values, * are destroyed. This also means that both arrays must be large enough to hold complex values,
* even when performing a real-to-complex transform. * even when performing a real-to-complex transform.
* <p> *
* When performing a real-to-complex transform, the output data is of size xsize*ysize*(zsize/2+1) * When performing a real-to-complex transform, the output data is of size xsize*ysize*(zsize/2+1)
* and contains only the non-redundant elements. * and contains only the non-redundant elements.
* *
...@@ -103,10 +106,10 @@ public: ...@@ -103,10 +106,10 @@ public:
* @param out on exit, this contains the transformed data * @param out on exit, this contains the transformed data
* @param forward true to perform a forward transform, false to perform an inverse transform * @param forward true to perform a forward transform, false to perform an inverse transform
*/ */
void execFFT(OpenCLArray& in, OpenCLArray& out, bool forward = true); void execFFT(ArrayInterface& in, ArrayInterface& out, bool forward=true);
/** /**
* Get the smallest legal size for a dimension of the grid (that is, a size with no prime * Get the smallest legal size for a dimension of the grid (that is, a size with no unsupported
* factors other than 2, 3, 5, and 7). * prime factors).
* *
* @param minimum the minimum size the return value must be greater than or equal to * @param minimum the minimum size the return value must be greater than or equal to
*/ */
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2024 Stanford University and the Authors. * * Portions copyright (c) 2009-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -576,6 +576,14 @@ void OpenCLContext::initializeContexts() { ...@@ -576,6 +576,14 @@ void OpenCLContext::initializeContexts() {
getPlatformData().initializeContexts(system); getPlatformData().initializeContexts(system);
} }
/**
 * Build a 3D FFT object of the requested grid size that operates on this context.
 * The object is heap allocated; the caller is responsible for deleting it.
 */
OpenCLFFT3D* OpenCLContext::createFFT(int xsize, int ysize, int zsize, bool realToComplex) {
    OpenCLFFT3D* transform = new OpenCLFFT3D(*this, xsize, ysize, zsize, realToComplex);
    return transform;
}
/**
 * Report the grid dimension that OpenCLFFT3D considers legal for the given
 * minimum, by delegating to OpenCLFFT3D::findLegalDimension().
 */
int OpenCLContext::findLegalFFTDimension(int minimum) {
    const int legalSize = OpenCLFFT3D::findLegalDimension(minimum);
    return legalSize;
}
void OpenCLContext::addForce(ComputeForceInfo* force) { void OpenCLContext::addForce(ComputeForceInfo* force) {
ComputeContext::addForce(force); ComputeContext::addForce(force);
OpenCLForceInfo* clinfo = dynamic_cast<OpenCLForceInfo*>(force); OpenCLForceInfo* clinfo = dynamic_cast<OpenCLForceInfo*>(force);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2023 Stanford University and the Authors. * * Portions copyright (c) 2009-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -71,15 +71,15 @@ OpenCLFFT3D::~OpenCLFFT3D() { ...@@ -71,15 +71,15 @@ OpenCLFFT3D::~OpenCLFFT3D() {
deleteVkFFT(&app); deleteVkFFT(&app);
} }
void OpenCLFFT3D::execFFT(OpenCLArray& in, OpenCLArray& out, bool forward) { void OpenCLFFT3D::execFFT(ArrayInterface& in, ArrayInterface& out, bool forward) {
VkFFTLaunchParams params = {}; VkFFTLaunchParams params = {};
if (forward) { if (forward) {
params.inputBuffer = &in.getDeviceBuffer()(); params.inputBuffer = &context.unwrap(in).getDeviceBuffer()();
params.buffer = &out.getDeviceBuffer()(); params.buffer = &context.unwrap(out).getDeviceBuffer()();
} }
else { else {
params.inputBuffer = &out.getDeviceBuffer()(); params.inputBuffer = &context.unwrap(out).getDeviceBuffer()();
params.buffer = &in.getDeviceBuffer()(); params.buffer = &context.unwrap(in).getDeviceBuffer()();
} }
params.commandQueue = &context.getQueue()(); params.commandQueue = &context.getQueue()();
VkFFTResult result = VkFFTAppend(&app, forward ? -1 : 1, &params); VkFFTResult result = VkFFTAppend(&app, forward ? -1 : 1, &params);
...@@ -148,7 +148,9 @@ OpenCLFFT3D::OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize ...@@ -148,7 +148,9 @@ OpenCLFFT3D::OpenCLFFT3D(OpenCLContext& context, int xsize, int ysize, int zsize
invykernel = createKernel(packedZSize, packedXSize, packedYSize, ythreads, 2, false, inputIsReal); invykernel = createKernel(packedZSize, packedXSize, packedYSize, ythreads, 2, false, inputIsReal);
} }
void OpenCLFFT3D::execFFT(OpenCLArray& in, OpenCLArray& out, bool forward) { void OpenCLFFT3D::execFFT(ArrayInterface& in, ArrayInterface& out, bool forward) {
OpenCLArray& in2 = context.unwrap(in);
OpenCLArray& out2 = context.unwrap(out);
cl::Kernel kernel1 = (forward ? zkernel : invzkernel); cl::Kernel kernel1 = (forward ? zkernel : invzkernel);
cl::Kernel kernel2 = (forward ? xkernel : invxkernel); cl::Kernel kernel2 = (forward ? xkernel : invxkernel);
cl::Kernel kernel3 = (forward ? ykernel : invykernel); cl::Kernel kernel3 = (forward ? ykernel : invykernel);
...@@ -159,37 +161,37 @@ void OpenCLFFT3D::execFFT(OpenCLArray& in, OpenCLArray& out, bool forward) { ...@@ -159,37 +161,37 @@ void OpenCLFFT3D::execFFT(OpenCLArray& in, OpenCLArray& out, bool forward) {
// Pack the data into a half sized grid. // Pack the data into a half sized grid.
packKernel.setArg<cl::Buffer>(0, in.getDeviceBuffer()); packKernel.setArg<cl::Buffer>(0, in2.getDeviceBuffer());
packKernel.setArg<cl::Buffer>(1, out.getDeviceBuffer()); packKernel.setArg<cl::Buffer>(1, out2.getDeviceBuffer());
context.executeKernel(packKernel, gridSize); context.executeKernel(packKernel, gridSize);
// Perform the FFT. // Perform the FFT.
kernel1.setArg<cl::Buffer>(0, out.getDeviceBuffer()); kernel1.setArg<cl::Buffer>(0, out2.getDeviceBuffer());
kernel1.setArg<cl::Buffer>(1, in.getDeviceBuffer()); kernel1.setArg<cl::Buffer>(1, in2.getDeviceBuffer());
context.executeKernel(kernel1, gridSize, zthreads); context.executeKernel(kernel1, gridSize, zthreads);
kernel2.setArg<cl::Buffer>(0, in.getDeviceBuffer()); kernel2.setArg<cl::Buffer>(0, in2.getDeviceBuffer());
kernel2.setArg<cl::Buffer>(1, out.getDeviceBuffer()); kernel2.setArg<cl::Buffer>(1, out2.getDeviceBuffer());
context.executeKernel(kernel2, gridSize, xthreads); context.executeKernel(kernel2, gridSize, xthreads);
kernel3.setArg<cl::Buffer>(0, out.getDeviceBuffer()); kernel3.setArg<cl::Buffer>(0, out2.getDeviceBuffer());
kernel3.setArg<cl::Buffer>(1, in.getDeviceBuffer()); kernel3.setArg<cl::Buffer>(1, in2.getDeviceBuffer());
context.executeKernel(kernel3, gridSize, ythreads); context.executeKernel(kernel3, gridSize, ythreads);
// Unpack the data. // Unpack the data.
unpackKernel.setArg<cl::Buffer>(0, in.getDeviceBuffer()); unpackKernel.setArg<cl::Buffer>(0, in2.getDeviceBuffer());
unpackKernel.setArg<cl::Buffer>(1, out.getDeviceBuffer()); unpackKernel.setArg<cl::Buffer>(1, out2.getDeviceBuffer());
context.executeKernel(unpackKernel, gridSize); context.executeKernel(unpackKernel, gridSize);
} }
else { else {
kernel1.setArg<cl::Buffer>(0, in.getDeviceBuffer()); kernel1.setArg<cl::Buffer>(0, in2.getDeviceBuffer());
kernel1.setArg<cl::Buffer>(1, out.getDeviceBuffer()); kernel1.setArg<cl::Buffer>(1, out2.getDeviceBuffer());
context.executeKernel(kernel1, xsize*ysize*zsize, zthreads); context.executeKernel(kernel1, xsize*ysize*zsize, zthreads);
kernel2.setArg<cl::Buffer>(0, out.getDeviceBuffer()); kernel2.setArg<cl::Buffer>(0, out2.getDeviceBuffer());
kernel2.setArg<cl::Buffer>(1, in.getDeviceBuffer()); kernel2.setArg<cl::Buffer>(1, in2.getDeviceBuffer());
context.executeKernel(kernel2, xsize*ysize*zsize, xthreads); context.executeKernel(kernel2, xsize*ysize*zsize, xthreads);
kernel3.setArg<cl::Buffer>(0, in.getDeviceBuffer()); kernel3.setArg<cl::Buffer>(0, in2.getDeviceBuffer());
kernel3.setArg<cl::Buffer>(1, out.getDeviceBuffer()); kernel3.setArg<cl::Buffer>(1, out2.getDeviceBuffer());
context.executeKernel(kernel3, xsize*ysize*zsize, ythreads); context.executeKernel(kernel3, xsize*ysize*zsize, ythreads);
} }
} }
......
...@@ -515,9 +515,9 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb ...@@ -515,9 +515,9 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
pmeEnergyBuffer.initialize(cl, cl.getNumThreadBlocks()*OpenCLContext::ThreadBlockSize, energyElementSize, "pmeEnergyBuffer"); pmeEnergyBuffer.initialize(cl, cl.getNumThreadBlocks()*OpenCLContext::ThreadBlockSize, energyElementSize, "pmeEnergyBuffer");
cl.clearBuffer(pmeEnergyBuffer); cl.clearBuffer(pmeEnergyBuffer);
sort = new OpenCLSort(cl, new SortTrait(), cl.getNumAtoms()); sort = new OpenCLSort(cl, new SortTrait(), cl.getNumAtoms());
fft = new OpenCLFFT3D(cl, gridSizeX, gridSizeY, gridSizeZ, true); fft = cl.createFFT(gridSizeX, gridSizeY, gridSizeZ, true);
if (doLJPME) if (doLJPME)
dispersionFft = new OpenCLFFT3D(cl, dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true); dispersionFft = cl.createFFT(dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true);
string vendor = cl.getDevice().getInfo<CL_DEVICE_VENDOR>(); string vendor = cl.getDevice().getInfo<CL_DEVICE_VENDOR>();
bool isNvidia = (vendor.size() >= 6 && vendor.substr(0, 6) == "NVIDIA"); bool isNvidia = (vendor.size() >= 6 && vendor.substr(0, 6) == "NVIDIA");
usePmeQueue = (!cl.getPlatformData().disablePmeStream && !cl.getPlatformData().useCpuPme && isNvidia); usePmeQueue = (!cl.getPlatformData().disablePmeStream && !cl.getPlatformData().useCpuPme && isNvidia);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. * * Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman, Mark Friedrichs * * Authors: Peter Eastman, Mark Friedrichs *
* Contributors: * * Contributors: *
* * * *
...@@ -210,10 +210,12 @@ private: ...@@ -210,10 +210,12 @@ private:
CommonCalcAmoebaMultipoleForceKernel::CommonCalcAmoebaMultipoleForceKernel(const std::string& name, const Platform& platform, ComputeContext& cc, const System& system) : CommonCalcAmoebaMultipoleForceKernel::CommonCalcAmoebaMultipoleForceKernel(const std::string& name, const Platform& platform, ComputeContext& cc, const System& system) :
CalcAmoebaMultipoleForceKernel(name, platform), cc(cc), system(system), hasInitializedScaleFactors(false), multipolesAreValid(false), hasCreatedEvent(false), CalcAmoebaMultipoleForceKernel(name, platform), cc(cc), system(system), hasInitializedScaleFactors(false), multipolesAreValid(false), hasCreatedEvent(false),
gkKernel(NULL) { fft(NULL), gkKernel(NULL) {
} }
CommonCalcAmoebaMultipoleForceKernel::~CommonCalcAmoebaMultipoleForceKernel() { CommonCalcAmoebaMultipoleForceKernel::~CommonCalcAmoebaMultipoleForceKernel() {
if (fft != NULL)
delete fft;
} }
void CommonCalcAmoebaMultipoleForceKernel::initialize(const System& system, const AmoebaMultipoleForce& force) { void CommonCalcAmoebaMultipoleForceKernel::initialize(const System& system, const AmoebaMultipoleForce& force) {
...@@ -718,6 +720,7 @@ void CommonCalcAmoebaMultipoleForceKernel::initialize(const System& system, cons ...@@ -718,6 +720,7 @@ void CommonCalcAmoebaMultipoleForceKernel::initialize(const System& system, cons
pmePhip.initialize(cc, 10*numMultipoles, elementSize, "pmePhip"); pmePhip.initialize(cc, 10*numMultipoles, elementSize, "pmePhip");
pmePhidp.initialize(cc, 20*numMultipoles, elementSize, "pmePhidp"); pmePhidp.initialize(cc, 20*numMultipoles, elementSize, "pmePhidp");
pmeCphi.initialize(cc, 10*numMultipoles, elementSize, "pmeCphi"); pmeCphi.initialize(cc, 10*numMultipoles, elementSize, "pmeCphi");
fft = cc.createFFT(gridSizeX, gridSizeY, gridSizeZ, false);
// Create the PME kernels. // Create the PME kernels.
...@@ -1162,9 +1165,9 @@ double CommonCalcAmoebaMultipoleForceKernel::execute(ContextImpl& context, bool ...@@ -1162,9 +1165,9 @@ double CommonCalcAmoebaMultipoleForceKernel::execute(ContextImpl& context, bool
pmeSpreadFixedMultipolesKernel->execute(cc.getNumAtoms()); pmeSpreadFixedMultipolesKernel->execute(cc.getNumAtoms());
if (useFixedPointChargeSpreading()) if (useFixedPointChargeSpreading())
pmeFinishSpreadChargeKernel->execute(pmeGrid1.getSize()); pmeFinishSpreadChargeKernel->execute(pmeGrid1.getSize());
computeFFT(true); fft->execFFT(pmeGrid1, pmeGrid2, true);
pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ, 256); pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ, 256);
computeFFT(false); fft->execFFT(pmeGrid2, pmeGrid1, false);
pmeFixedPotentialKernel->execute(cc.getNumAtoms()); pmeFixedPotentialKernel->execute(cc.getNumAtoms());
pmeTransformPotentialKernel->setArg(0, pmePhi); pmeTransformPotentialKernel->setArg(0, pmePhi);
pmeTransformPotentialKernel->execute(cc.getNumAtoms()); pmeTransformPotentialKernel->execute(cc.getNumAtoms());
...@@ -1186,9 +1189,9 @@ double CommonCalcAmoebaMultipoleForceKernel::execute(ContextImpl& context, bool ...@@ -1186,9 +1189,9 @@ double CommonCalcAmoebaMultipoleForceKernel::execute(ContextImpl& context, bool
pmeSpreadInducedDipolesKernel->execute(cc.getNumAtoms()); pmeSpreadInducedDipolesKernel->execute(cc.getNumAtoms());
if (useFixedPointChargeSpreading()) if (useFixedPointChargeSpreading())
pmeFinishSpreadChargeKernel->execute(pmeGrid1.getSize()); pmeFinishSpreadChargeKernel->execute(pmeGrid1.getSize());
computeFFT(true); fft->execFFT(pmeGrid1, pmeGrid2, true);
pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ, 256); pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ, 256);
computeFFT(false); fft->execFFT(pmeGrid2, pmeGrid1, false);
pmeInducedPotentialKernel->execute(cc.getNumAtoms()); pmeInducedPotentialKernel->execute(cc.getNumAtoms());
// Iterate until the dipoles converge. // Iterate until the dipoles converge.
...@@ -1262,9 +1265,9 @@ void CommonCalcAmoebaMultipoleForceKernel::computeInducedField() { ...@@ -1262,9 +1265,9 @@ void CommonCalcAmoebaMultipoleForceKernel::computeInducedField() {
pmeSpreadInducedDipolesKernel->execute(cc.getNumAtoms()); pmeSpreadInducedDipolesKernel->execute(cc.getNumAtoms());
if (useFixedPointChargeSpreading()) if (useFixedPointChargeSpreading())
pmeFinishSpreadChargeKernel->execute(pmeGrid1.getSize()); pmeFinishSpreadChargeKernel->execute(pmeGrid1.getSize());
computeFFT(true); fft->execFFT(pmeGrid1, pmeGrid2, true);
pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ, 256); pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ, 256);
computeFFT(false); fft->execFFT(pmeGrid2, pmeGrid1, false);
pmeInducedPotentialKernel->execute(cc.getNumAtoms()); pmeInducedPotentialKernel->execute(cc.getNumAtoms());
if (polarizationType == AmoebaMultipoleForce::Extrapolated) { if (polarizationType == AmoebaMultipoleForce::Extrapolated) {
pmeRecordInducedFieldDipolesKernel->execute(cc.getNumAtoms()); pmeRecordInducedFieldDipolesKernel->execute(cc.getNumAtoms());
...@@ -2399,7 +2402,14 @@ private: ...@@ -2399,7 +2402,14 @@ private:
}; };
CommonCalcHippoNonbondedForceKernel::CommonCalcHippoNonbondedForceKernel(const std::string& name, const Platform& platform, ComputeContext& cc, const System& system) : CommonCalcHippoNonbondedForceKernel::CommonCalcHippoNonbondedForceKernel(const std::string& name, const Platform& platform, ComputeContext& cc, const System& system) :
CalcHippoNonbondedForceKernel(name, platform), cc(cc), system(system), hasInitializedKernels(false), multipolesAreValid(false) { CalcHippoNonbondedForceKernel(name, platform), cc(cc), system(system), hasInitializedKernels(false), multipolesAreValid(false), fft(NULL), dfft(NULL) {
}
CommonCalcHippoNonbondedForceKernel::~CommonCalcHippoNonbondedForceKernel() {
if (fft != NULL)
delete fft;
if (dfft != NULL)
delete dfft;
} }
void CommonCalcHippoNonbondedForceKernel::initialize(const System& system, const HippoNonbondedForce& force) { void CommonCalcHippoNonbondedForceKernel::initialize(const System& system, const HippoNonbondedForce& force) {
...@@ -2665,6 +2675,8 @@ void CommonCalcHippoNonbondedForceKernel::initialize(const System& system, const ...@@ -2665,6 +2675,8 @@ void CommonCalcHippoNonbondedForceKernel::initialize(const System& system, const
pmePhidp.initialize(cc, 20*numParticles, elementSize, "pmePhidp"); pmePhidp.initialize(cc, 20*numParticles, elementSize, "pmePhidp");
pmeCphi.initialize(cc, 10*numParticles, elementSize, "pmeCphi"); pmeCphi.initialize(cc, 10*numParticles, elementSize, "pmeCphi");
pmeAtomGridIndex.initialize<mm_int2>(cc, numParticles, "pmeAtomGridIndex"); pmeAtomGridIndex.initialize<mm_int2>(cc, numParticles, "pmeAtomGridIndex");
fft = cc.createFFT(gridSizeX, gridSizeY, gridSizeZ, true);
dfft = cc.createFFT(dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true);
// Create the PME kernels. // Create the PME kernels.
...@@ -3298,9 +3310,9 @@ double CommonCalcHippoNonbondedForceKernel::execute(ContextImpl& context, bool i ...@@ -3298,9 +3310,9 @@ double CommonCalcHippoNonbondedForceKernel::execute(ContextImpl& context, bool i
pmeSpreadFixedMultipolesKernel->execute(cc.getNumAtoms()); pmeSpreadFixedMultipolesKernel->execute(cc.getNumAtoms());
if (useFixedPointChargeSpreading()) if (useFixedPointChargeSpreading())
pmeFinishSpreadChargeKernel->execute(pmeGrid1.getSize()); pmeFinishSpreadChargeKernel->execute(pmeGrid1.getSize());
computeFFT(true, false); fft->execFFT(pmeGrid1, pmeGrid2, true);
pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ, 256); pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ, 256);
computeFFT(false, false); fft->execFFT(pmeGrid2, pmeGrid1, false);
pmeFixedPotentialKernel->execute(cc.getNumAtoms()); pmeFixedPotentialKernel->execute(cc.getNumAtoms());
pmeTransformPotentialKernel->setArg(0, pmePhi); pmeTransformPotentialKernel->setArg(0, pmePhi);
pmeTransformPotentialKernel->execute(cc.getNumAtoms()); pmeTransformPotentialKernel->execute(cc.getNumAtoms());
...@@ -3317,11 +3329,11 @@ double CommonCalcHippoNonbondedForceKernel::execute(ContextImpl& context, bool i ...@@ -3317,11 +3329,11 @@ double CommonCalcHippoNonbondedForceKernel::execute(ContextImpl& context, bool i
dpmeSpreadChargeKernel->execute(PmeOrder*cc.getNumAtoms(), 128); dpmeSpreadChargeKernel->execute(PmeOrder*cc.getNumAtoms(), 128);
if (useFixedPointChargeSpreading()) if (useFixedPointChargeSpreading())
dpmeFinishSpreadChargeKernel->execute(dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256); dpmeFinishSpreadChargeKernel->execute(dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256);
computeFFT(true, true); dfft->execFFT(pmeGrid1, pmeGrid2, true);
if (includeEnergy) if (includeEnergy)
dpmeEvalEnergyKernel->execute(dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ); dpmeEvalEnergyKernel->execute(dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ);
dpmeConvolutionKernel->execute(dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256); dpmeConvolutionKernel->execute(dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256);
computeFFT(false, true); dfft->execFFT(pmeGrid2, pmeGrid1, false);
dpmeInterpolateForceKernel->execute(cc.getNumAtoms(), 128); dpmeInterpolateForceKernel->execute(cc.getNumAtoms(), 128);
} }
...@@ -3388,9 +3400,9 @@ void CommonCalcHippoNonbondedForceKernel::computeInducedField(int optOrder) { ...@@ -3388,9 +3400,9 @@ void CommonCalcHippoNonbondedForceKernel::computeInducedField(int optOrder) {
pmeSpreadInducedDipolesKernel->execute(cc.getNumAtoms()); pmeSpreadInducedDipolesKernel->execute(cc.getNumAtoms());
if (useFixedPointChargeSpreading()) if (useFixedPointChargeSpreading())
pmeFinishSpreadChargeKernel->execute(pmeGrid1.getSize()); pmeFinishSpreadChargeKernel->execute(pmeGrid1.getSize());
computeFFT(true, false); fft->execFFT(pmeGrid1, pmeGrid2, true);
pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ, 256); pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ, 256);
computeFFT(false, false); fft->execFFT(pmeGrid2, pmeGrid1, false);
pmeInducedPotentialKernel->setArg(2, optOrder); pmeInducedPotentialKernel->setArg(2, optOrder);
pmeInducedPotentialKernel->execute(cc.getNumAtoms()); pmeInducedPotentialKernel->execute(cc.getNumAtoms());
pmeRecordInducedFieldDipolesKernel->execute(cc.getNumAtoms()); pmeRecordInducedFieldDipolesKernel->execute(cc.getNumAtoms());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment