Unverified Commit f717ed89 authored by Anton Gorenko's avatar Anton Gorenko
Browse files

Use VkFFT in HipFFT3D, remove hipFFT and the builtin FFT

* VkFFT-based 3D FFT;
* Caching of compiled VkFFT kernels;
* Extend FFT tests with more sizes.
parent b9c45d45
......@@ -177,3 +177,31 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
9. VkFFT
OpenMM uses the VkFFT library by Dmitrii Tolmachev. It may be used under the
terms of the MIT License:
MIT License
Copyright (c) 2020 - present Dmitrii Tolmachev
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
......@@ -13,7 +13,6 @@
#----------------------------------------------------
FIND_PACKAGE(HIPRTC CONFIG)
FIND_PACKAGE(HIPFFT CONFIG QUIET)
SET(OPENMM_BUILD_HIP_TESTS TRUE CACHE BOOL "Whether to build HIP test cases")
IF(BUILD_TESTING AND OPENMM_BUILD_HIP_TESTS)
......@@ -104,10 +103,6 @@ IF(OPENMM_BUILD_SHARED_LIB)
TARGET_LINK_LIBRARIES(${SHARED_TARGET} PUBLIC ${OPENMM_LIBRARY_NAME} ${PTHREADS_LIB} hip::host hiprtc::hiprtc)
SET_TARGET_PROPERTIES(${SHARED_TARGET} PROPERTIES COMPILE_FLAGS "${EXTRA_COMPILE_FLAGS} -DOPENMM_COMMON_BUILDING_SHARED_LIBRARY")
SET_TARGET_PROPERTIES(${SHARED_TARGET} PROPERTIES LINK_FLAGS "${EXTRA_LINK_FLAGS}")
IF(HIPFFT_FOUND)
TARGET_LINK_LIBRARIES(${SHARED_TARGET} PUBLIC hip::hipfft)
TARGET_COMPILE_OPTIONS(${SHARED_TARGET} PUBLIC "-DOPENMM_HIP_WITH_HIPFFT")
ENDIF(HIPFFT_FOUND)
INSTALL_TARGETS(/lib/plugins RUNTIME_DIRECTORY /lib/plugins ${SHARED_TARGET})
ENDIF(OPENMM_BUILD_SHARED_LIB)
......@@ -121,10 +116,6 @@ IF(OPENMM_BUILD_STATIC_LIB)
TARGET_LINK_LIBRARIES(${STATIC_TARGET} ${OPENMM_LIBRARY_NAME} ${PTHREADS_LIB_STATIC} hip::host hiprtc::hiprtc)
SET_TARGET_PROPERTIES(${STATIC_TARGET} PROPERTIES COMPILE_FLAGS "${EXTRA_COMPILE_FLAGS} -DOPENMM_COMMON_BUILDING_STATIC_LIBRARY")
SET_TARGET_PROPERTIES(${STATIC_TARGET} PROPERTIES LINK_FLAGS "${EXTRA_LINK_FLAGS}")
IF(HIPFFT_FOUND)
TARGET_LINK_LIBRARIES(${STATIC_TARGET} PUBLIC hip::hipfft)
TARGET_COMPILE_OPTIONS(${STATIC_TARGET} PUBLIC "-DOPENMM_HIP_WITH_HIPFFT")
ENDIF(HIPFFT_FOUND)
INSTALL_TARGETS(/lib/plugins RUNTIME_DIRECTORY /lib/plugins ${STATIC_TARGET})
ENDIF(OPENMM_BUILD_STATIC_LIB)
......
......@@ -55,6 +55,7 @@
#include "HipIntegrationUtilities.h"
#include "HipNonbondedUtilities.h"
#include "HipPlatform.h"
#include "HipFFT3D.h"
#include "openmm/OpenMMException.h"
#include "openmm/common/ComputeContext.h"
#include "openmm/Kernel.h"
......@@ -161,6 +162,22 @@ public:
* Construct a ComputeEvent object of the appropriate class for this platform.
*/
ComputeEvent createEvent();
/**
* Create a new HipFFT3D.
*
* @param xsize the first dimension of the data sets on which FFTs will be performed
* @param ysize the second dimension of the data sets on which FFTs will be performed
* @param zsize the third dimension of the data sets on which FFTs will be performed
* @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex.
* @param stream HIP stream
* @param in the data to transform, ordered such that in[x*ysize*zsize + y*zsize + z] contains element (x, y, z)
* @param out on exit, this contains the transformed data
*/
HipFFT3D* createFFT(int xsize, int ysize, int zsize, bool realToComplex, hipStream_t stream, HipArray& in, HipArray& out);
/**
* Get the smallest legal size for a dimension of the grid supported by the FFT.
*/
virtual int findLegalFFTDimension(int minimum);
/**
* Compile source code to create a ComputeProgram.
*
......
......@@ -10,8 +10,8 @@
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Portions copyright (c) 2020 Advanced Micro Devices, Inc. *
* Authors: Peter Eastman, Nicholas Curtis *
* Portions copyright (c) 2021 Advanced Micro Devices, Inc. *
* Authors: *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
......@@ -30,23 +30,16 @@
#include "HipArray.h"
#define VKFFT_BACKEND 2 // HIP
#include "vkFFT.h"
namespace OpenMM {
class HipContext;
/**
* This class performs three dimensional Fast Fourier Transforms. It is based on the
* mixed radix algorithm described in
* <p>
* Takahashi, D. and Kanada, Y., "High-Performance Radix-2, 3 and 5 Parallel 1-D Complex
* FFT Algorithms for Distributed-Memory Parallel Computers." Journal of Supercomputing,
* 15, 207–228 (2000).
* <p>
* This class places certain restrictions on the allowed dimensions of the grid. First,
* the size of each dimension may have no prime factors other than 2, 3, 5, and 7. You
* can call findLegalDimension() to determine the smallest size that satisfies this
* requirement and is greater than or equal to a specified minimum size. Second, the size
* of each dimension must be small enough to compute each 1D transform entirely in local
* memory with one work unit per data point. This will vary between platforms, but is
* typically at least 512.
* This class performs three dimensional Fast Fourier Transforms using VkFFT by
* Dmitrii Tolmachev (https://github.com/DTolm/VkFFT).
* <p>
* Note that this class performs an unnormalized transform. That means that if you perform
* a forward transform followed immediately by an inverse transform, the effect is to
......@@ -57,44 +50,49 @@ class OPENMM_EXPORT_COMMON HipFFT3D {
public:
/**
* Create an HipFFT3D object for performing transforms of a particular size.
* <p>
* The transform cannot be done in-place: the input and output
* arrays must be different. Also, the input array is used as workspace, so its contents
* are destroyed. This also means that both arrays must be large enough to hold complex values,
* even when performing a real-to-complex transform.
* <p>
* When performing a real-to-complex transform, the output data is of size xsize*ysize*(zsize/2+1)
* and contains only the non-redundant elements.
*
* @param context the context in which to perform calculations
* @param xsize the first dimension of the data sets on which FFTs will be performed
* @param ysize the second dimension of the data sets on which FFTs will be performed
* @param zsize the third dimension of the data sets on which FFTs will be performed
* @param realToComplex if true, a real-to-complex transform will be done. Otherwise, it is complex-to-complex.
* @param stream HIP stream
* @param in the data to transform, ordered such that in[x*ysize*zsize + y*zsize + z] contains element (x, y, z)
* @param out on exit, this contains the transformed data
*/
HipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex=false);
HipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex, hipStream_t stream, HipArray& in, HipArray& out);
~HipFFT3D();
/**
* Perform a Fourier transform. The transform cannot be done in-place: the input and output
* arrays must be different. Also, the input array is used as workspace, so its contents
* are destroyed. This also means that both arrays must be large enough to hold complex values,
* even when performing a real-to-complex transform.
* <p>
* When performing a real-to-complex transform, the output data is of size xsize*ysize*(zsize/2+1)
* and contains only the non-redundant elements.
* Perform a Fourier transform.
*
* @param in the data to transform, ordered such that in[x*ysize*zsize + y*zsize + z] contains element (x, y, z)
* @param out on exit, this contains the transformed data
* @param forward true to perform a forward transform, false to perform an inverse transform
*/
void execFFT(HipArray& in, HipArray& out, bool forward = true);
void execFFT(bool forward);
/**
* Get the smallest legal size for a dimension of the grid (that is, a size with no prime
* factors other than 2, 3, 5, and 7).
* factors other than 2, 3, 5, 7, 11, 13). VkFFT supports arbitrary sizes but they may work
* slower.
*
* @param minimum the minimum size the return value must be greater than or equal to
*/
static int findLegalDimension(int minimum);
private:
hipFunction_t createKernel(int xsize, int ysize, int zsize, int& threads, int axis, bool forward, bool inputIsReal);
int xsize, ysize, zsize;
int xthreads, ythreads, zthreads;
bool packRealAsComplex;
hipStream_t stream;
HipContext& context;
hipFunction_t xkernel, ykernel, zkernel;
hipFunction_t invxkernel, invykernel, invzkernel;
hipFunction_t packForwardKernel, unpackForwardKernel, packBackwardKernel, unpackBackwardKernel;
int deviceIndex;
void* inputBuffer;
void* outputBuffer;
uint64_t inputBufferSize;
uint64_t outputBufferSize;
VkFFTApplication* app;
};
} // namespace OpenMM
......
......@@ -36,7 +36,6 @@
#include "openmm/kernels.h"
#include "openmm/System.h"
#include "openmm/common/CommonKernels.h"
#include <hipfft.h>
namespace OpenMM {
......@@ -279,11 +278,7 @@ private:
hipStream_t pmeStream;
hipEvent_t pmeSyncEvent, paramsSyncEvent;
HipFFT3D* fft;
hipfftHandle fftForward;
hipfftHandle fftBackward;
HipFFT3D* dispersionFft;
hipfftHandle dispersionFftForward;
hipfftHandle dispersionFftBackward;
hipFunction_t computeParamsKernel, computeExclusionParamsKernel;
hipFunction_t ewaldSumsKernel;
hipFunction_t ewaldForcesKernel;
......@@ -306,7 +301,7 @@ private:
int interpolateForceThreads;
int gridSizeX, gridSizeY, gridSizeZ;
int dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ;
bool hasCoulomb, hasLJ, usePmeStream, useHipFFT, doLJPME, usePosqCharges, recomputeParams, hasOffsets;
bool hasCoulomb, hasLJ, usePmeStream, doLJPME, usePosqCharges, recomputeParams, hasOffsets;
NonbondedMethod nonbondedMethod;
static const int PmeOrder = 5;
};
......
......@@ -37,6 +37,7 @@
#include "HipKernelSources.h"
#include "HipNonbondedUtilities.h"
#include "HipProgram.h"
#include "HipFFT3D.h"
#include "openmm/common/ComputeArray.h"
#include "SHA1.h"
#include "openmm/Platform.h"
......@@ -85,7 +86,7 @@ bool HipContext::hasInitializedHip = false;
HipContext::HipContext(const System& system, int deviceIndex, bool useBlockingSync, const string& precision, const string& tempDir, HipPlatform::PlatformData& platformData,
HipContext* originalContext) : ComputeContext(system), currentStream(0), platformData(platformData), contextIsValid(false), hasAssignedPosqCharges(false),
pinnedBuffer(NULL), integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL),
supportsHardwareFloatGlobalAtomicAdd(false) {
useBlockingSync(useBlockingSync), supportsHardwareFloatGlobalAtomicAdd(false) {
if (!hasInitializedHip) {
CHECK_RESULT2(hipInit(0), "Error initializing HIP");
hasInitializedHip = true;
......@@ -649,6 +650,14 @@ ComputeEvent HipContext::createEvent() {
return shared_ptr<ComputeEventImpl>(new HipEvent(*this));
}
HipFFT3D* HipContext::createFFT(int xsize, int ysize, int zsize, bool realToComplex, hipStream_t stream, HipArray& in, HipArray& out) {
return new HipFFT3D(*this, xsize, ysize, zsize, realToComplex, stream, in, out);
}
int HipContext::findLegalFFTDimension(int minimum) {
return HipFFT3D::findLegalDimension(minimum);
}
ComputeProgram HipContext::compileProgram(const std::string source, const std::map<std::string, std::string>& defines) {
hipModule_t module = createModule(HipKernelSources::vectorOps+source, defines);
return shared_ptr<ComputeProgramImpl>(new HipProgram(*this, module));
......
This diff is collapsed.
......@@ -542,14 +542,6 @@ HipCalcNonbondedForceKernel::~HipCalcNonbondedForceKernel() {
if (pmeio != NULL)
delete pmeio;
if (hasInitializedFFT) {
if (useHipFFT) {
hipfftDestroy(fftForward);
hipfftDestroy(fftBackward);
if (doLJPME) {
hipfftDestroy(dispersionFftForward);
hipfftDestroy(dispersionFftBackward);
}
}
if (usePmeStream) {
hipStreamDestroy(pmeStream);
hipEventDestroy(pmeSyncEvent);
......@@ -696,15 +688,15 @@ void HipCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
// Compute the PME parameters.
NonbondedForceImpl::calcPMEParameters(system, force, alpha, gridSizeX, gridSizeY, gridSizeZ, false);
gridSizeX = HipFFT3D::findLegalDimension(gridSizeX);
gridSizeY = HipFFT3D::findLegalDimension(gridSizeY);
gridSizeZ = HipFFT3D::findLegalDimension(gridSizeZ);
gridSizeX = cu.findLegalFFTDimension(gridSizeX);
gridSizeY = cu.findLegalFFTDimension(gridSizeY);
gridSizeZ = cu.findLegalFFTDimension(gridSizeZ);
if (doLJPME) {
NonbondedForceImpl::calcPMEParameters(system, force, dispersionAlpha, dispersionGridSizeX,
dispersionGridSizeY, dispersionGridSizeZ, true);
dispersionGridSizeX = HipFFT3D::findLegalDimension(dispersionGridSizeX);
dispersionGridSizeY = HipFFT3D::findLegalDimension(dispersionGridSizeY);
dispersionGridSizeZ = HipFFT3D::findLegalDimension(dispersionGridSizeZ);
dispersionGridSizeX = cu.findLegalFFTDimension(dispersionGridSizeX);
dispersionGridSizeY = cu.findLegalFFTDimension(dispersionGridSizeY);
dispersionGridSizeZ = cu.findLegalFFTDimension(dispersionGridSizeZ);
}
defines["EWALD_ALPHA"] = cu.doubleToString(alpha);
......@@ -724,9 +716,7 @@ void HipCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
for (int i = 0; i < numParticles; i++)
ewaldSelfEnergy += baseParticleParamVec[i].z*pow(baseParticleParamVec[i].y*dispersionAlpha, 6)/3.0;
}
char deviceName[100];
hipDeviceGetName(deviceName, 100, cu.getDevice());
usePmeStream = (!cu.getPlatformData().disablePmeStream && !cu.getPlatformData().useCpuPme && string(deviceName) != "GeForce GTX 980"); // Using a separate stream is slower on GTX 980
usePmeStream = (!cu.getPlatformData().disablePmeStream && !cu.getPlatformData().useCpuPme);
map<string, string> pmeDefines;
pmeDefines["PME_ORDER"] = cu.intToString(PmeOrder);
pmeDefines["NUM_ATOMS"] = cu.intToString(numParticles);
......@@ -818,45 +808,11 @@ void HipCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
pmeEnergyBuffer.initialize(cu, cu.getNumThreadBlocks()*HipContext::ThreadBlockSize, energyElementSize, "pmeEnergyBuffer");
cu.clearBuffer(pmeEnergyBuffer);
sort = new HipSort(cu, new SortTrait(), cu.getNumAtoms());
int cufftVersion;
hipfftGetVersion(&cufftVersion);
useHipFFT = (cufftVersion >= 7050); // There was a critical bug in version 7.0
if (useHipFFT) {
hipfftResult result = hipfftPlan3d(&fftForward, gridSizeX, gridSizeY, gridSizeZ, cu.getUseDoublePrecision() ? HIPFFT_D2Z : HIPFFT_R2C);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+cu.intToString(result));
result = hipfftPlan3d(&fftBackward, gridSizeX, gridSizeY, gridSizeZ, cu.getUseDoublePrecision() ? HIPFFT_Z2D : HIPFFT_C2R);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+cu.intToString(result));
if (doLJPME) {
result = hipfftPlan3d(&dispersionFftForward, dispersionGridSizeX, dispersionGridSizeY,
dispersionGridSizeZ, cu.getUseDoublePrecision() ? HIPFFT_D2Z : HIPFFT_R2C);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error initializing disperison FFT: "+cu.intToString(result));
result = hipfftPlan3d(&dispersionFftBackward, dispersionGridSizeX, dispersionGridSizeY,
dispersionGridSizeZ, cu.getUseDoublePrecision() ? HIPFFT_Z2D : HIPFFT_C2R);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error initializing disperison FFT: "+cu.intToString(result));
}
}
else {
fft = new HipFFT3D(cu, gridSizeX, gridSizeY, gridSizeZ, true);
if (doLJPME)
dispersionFft = new HipFFT3D(cu, dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true);
}
// Prepare for doing PME on its own stream.
if (usePmeStream) {
hipStreamCreateWithFlags(&pmeStream, hipStreamNonBlocking);
if (useHipFFT) {
hipfftSetStream(fftForward, pmeStream);
hipfftSetStream(fftBackward, pmeStream);
if (doLJPME) {
hipfftSetStream(dispersionFftForward, pmeStream);
hipfftSetStream(dispersionFftBackward, pmeStream);
}
}
CHECK_RESULT(hipStreamCreateWithFlags(&pmeStream, hipStreamNonBlocking), "Error creating stream for NonbondedForce");
CHECK_RESULT(hipEventCreateWithFlags(&pmeSyncEvent, hipEventDisableTiming), "Error creating event for NonbondedForce");
CHECK_RESULT(hipEventCreateWithFlags(&paramsSyncEvent, hipEventDisableTiming), "Error creating event for NonbondedForce");
int recipForceGroup = force.getReciprocalSpaceForceGroup();
......@@ -865,6 +821,11 @@ void HipCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
cu.addPreComputation(new SyncStreamPreComputation(cu, pmeStream, pmeSyncEvent, recipForceGroup));
cu.addPostComputation(new SyncStreamPostComputation(cu, pmeSyncEvent, cu.getKernel(module, "addEnergy"), pmeEnergyBuffer, recipForceGroup));
}
hipStream_t fftStream = usePmeStream ? pmeStream : cu.getCurrentStream();
fft = cu.createFFT(gridSizeX, gridSizeY, gridSizeZ, true, fftStream, pmeGrid1, pmeGrid2);
if (doLJPME)
dispersionFft = cu.createFFT(dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true, fftStream, pmeGrid1, pmeGrid2);
hasInitializedFFT = true;
// Initialize the b-spline moduli.
......@@ -1215,20 +1176,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
void* finishSpreadArgs[] = {&pmeGrid2.getDevicePointer(), &pmeGrid1.getDevicePointer()};
cu.executeKernelFlat(pmeFinishSpreadChargeKernel, finishSpreadArgs, gridSizeX*gridSizeY*gridSizeZ, 256);
if (useHipFFT) {
if (cu.getUseDoublePrecision()) {
hipfftResult result = hipfftExecD2Z(fftForward, (double*) pmeGrid1.getDevicePointer(), (double2*) pmeGrid2.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
} else {
hipfftResult result = hipfftExecR2C(fftForward, (float*) pmeGrid1.getDevicePointer(), (float2*) pmeGrid2.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
}
}
else {
fft->execFFT(pmeGrid1, pmeGrid2, true);
}
fft->execFFT(true);
if (includeEnergy) {
void* computeEnergyArgs[] = {&pmeGrid2.getDevicePointer(), usePmeStream ? &pmeEnergyBuffer.getDevicePointer() : &cu.getEnergyBuffer().getDevicePointer(),
......@@ -1242,20 +1190,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]};
cu.executeKernel(pmeConvolutionKernel, convolutionArgs, gridSizeX*gridSizeY*gridSizeZ, 256);
if (useHipFFT) {
if (cu.getUseDoublePrecision()) {
hipfftResult result = hipfftExecZ2D(fftBackward, (double2*) pmeGrid2.getDevicePointer(), (double*) pmeGrid1.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
} else {
hipfftResult result = hipfftExecC2R(fftBackward, (float2*) pmeGrid2.getDevicePointer(), (float*) pmeGrid1.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
}
}
else {
fft->execFFT(pmeGrid2, pmeGrid1, false);
}
fft->execFFT(false);
void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &pmeGrid1.getDevicePointer(), cu.getPeriodicBoxSizePointer(),
cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(),
......@@ -1285,20 +1220,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
void* finishSpreadArgs[] = {&pmeGrid2.getDevicePointer(), &pmeGrid1.getDevicePointer()};
cu.executeKernelFlat(pmeDispersionFinishSpreadChargeKernel, finishSpreadArgs, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256);
if (useHipFFT) {
if (cu.getUseDoublePrecision()) {
hipfftResult result = hipfftExecD2Z(dispersionFftForward, (double*) pmeGrid1.getDevicePointer(), (double2*) pmeGrid2.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
} else {
hipfftResult result = hipfftExecR2C(dispersionFftForward, (float*) pmeGrid1.getDevicePointer(), (float2*) pmeGrid2.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
}
}
else {
dispersionFft->execFFT(pmeGrid1, pmeGrid2, true);
}
dispersionFft->execFFT(true);
if (includeEnergy) {
void* computeEnergyArgs[] = {&pmeGrid2.getDevicePointer(), usePmeStream ? &pmeEnergyBuffer.getDevicePointer() : &cu.getEnergyBuffer().getDevicePointer(),
......@@ -1312,20 +1234,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
recipBoxVectorPointer[0], recipBoxVectorPointer[1], recipBoxVectorPointer[2]};
cu.executeKernel(pmeDispersionConvolutionKernel, convolutionArgs, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, 256);
if (useHipFFT) {
if (cu.getUseDoublePrecision()) {
hipfftResult result = hipfftExecZ2D(dispersionFftBackward, (double2*) pmeGrid2.getDevicePointer(), (double*) pmeGrid1.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
} else {
hipfftResult result = hipfftExecC2R(dispersionFftBackward, (float2*) pmeGrid2.getDevicePointer(), (float*) pmeGrid1.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cu.intToString(result));
}
}
else {
dispersionFft->execFFT(pmeGrid2, pmeGrid1, false);
}
dispersionFft->execFFT(false);
void* interpolateArgs[] = {&cu.getPosq().getDevicePointer(), &cu.getForce().getDevicePointer(), &pmeGrid1.getDevicePointer(), cu.getPeriodicBoxSizePointer(),
cu.getInvPeriodicBoxSizePointer(), cu.getPeriodicBoxVecXPointer(), cu.getPeriodicBoxVecYPointer(), cu.getPeriodicBoxVecZPointer(),
......
static __inline__ __device__ real2 multiplyComplex(real2 c1, real2 c2) {
return make_real2(c1.x*c2.x-c1.y*c2.y, c1.x*c2.y+c1.y*c2.x);
}
/**
* Load a value from the half-complex grid produces by a real-to-complex transform.
*/
static __inline__ __device__ real2 loadComplexValue(const real2* __restrict__ in, int x, int y, int z) {
const int inputZSize = ZSIZE/2+1;
if (z < inputZSize)
return in[x*YSIZE*inputZSize+y*inputZSize+z];
int xp = (x == 0 ? 0 : XSIZE-x);
int yp = (y == 0 ? 0 : YSIZE-y);
real2 value = in[xp*YSIZE*inputZSize+yp*inputZSize+(ZSIZE-z)];
return make_real2(value.x, -value.y);
}
/**
* Perform a 1D FFT on each row along one axis.
*/
extern "C" __global__ void execFFT(const INPUT_TYPE* __restrict__ in, OUTPUT_TYPE* __restrict__ out) {
__shared__ real2 w[ZSIZE];
__shared__ real2 data0[BLOCKS_PER_GROUP*ZSIZE];
__shared__ real2 data1[BLOCKS_PER_GROUP*ZSIZE];
for (int i = threadIdx.x; i < ZSIZE; i += blockDim.x)
w[i] = make_real2(cos(-(SIGN)*i*2*M_PI/ZSIZE), sin(-(SIGN)*i*2*M_PI/ZSIZE));
__syncthreads();
const int block = threadIdx.x/THREADS_PER_BLOCK;
for (int baseIndex = blockIdx.x*BLOCKS_PER_GROUP; baseIndex < XSIZE*YSIZE; baseIndex += gridDim.x*BLOCKS_PER_GROUP) {
int index = baseIndex+block;
int x = index/YSIZE;
int y = index-x*YSIZE;
#if OUTPUT_IS_PACKED
if (x < XSIZE/2+1) {
#endif
if (index < XSIZE*YSIZE)
for (int i = threadIdx.x-block*THREADS_PER_BLOCK; i < ZSIZE; i += THREADS_PER_BLOCK)
#if INPUT_IS_REAL
data0[i+block*ZSIZE] = make_real2(in[x*(YSIZE*ZSIZE)+y*ZSIZE+i], 0);
#elif INPUT_IS_PACKED
data0[i+block*ZSIZE] = loadComplexValue(in, x, y, i);
#else
data0[i+block*ZSIZE] = in[x*(YSIZE*ZSIZE)+y*ZSIZE+i];
#endif
#if OUTPUT_IS_PACKED
}
#endif
__syncthreads();
COMPUTE_FFT
}
}
/**
* Combine the two halves of a real grid into a complex grid that is half as large.
*/
extern "C" __global__ void packForwardData(const real* __restrict__ in, real2* __restrict__ out) {
const int gridSize = PACKED_XSIZE*PACKED_YSIZE*PACKED_ZSIZE;
for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < gridSize; index += blockDim.x*gridDim.x) {
int x = index/(PACKED_YSIZE*PACKED_ZSIZE);
int remainder = index-x*(PACKED_YSIZE*PACKED_ZSIZE);
int y = remainder/PACKED_ZSIZE;
int z = remainder-y*PACKED_ZSIZE;
#if PACKED_AXIS == 0
real2 value = make_real2(in[2*x*YSIZE*ZSIZE+y*ZSIZE+z], in[(2*x+1)*YSIZE*ZSIZE+y*ZSIZE+z]);
#elif PACKED_AXIS == 1
real2 value = make_real2(in[x*YSIZE*ZSIZE+2*y*ZSIZE+z], in[x*YSIZE*ZSIZE+(2*y+1)*ZSIZE+z]);
#else
real2 value = make_real2(in[x*YSIZE*ZSIZE+y*ZSIZE+2*z], in[x*YSIZE*ZSIZE+y*ZSIZE+(2*z+1)]);
#endif
out[index] = value;
}
}
/**
* Split the transformed data back into a full sized, symmetric grid.
*/
extern "C" __global__ void unpackForwardData(const real2* __restrict__ in, real2* __restrict__ out) {
// Compute the phase factors.
#if PACKED_AXIS == 0
__shared__ real2 w[PACKED_XSIZE];
for (int i = threadIdx.x; i < PACKED_XSIZE; i += blockDim.x)
w[i] = make_real2(sin(i*2*M_PI/XSIZE), cos(i*2*M_PI/XSIZE));
#elif PACKED_AXIS == 1
__shared__ real2 w[PACKED_YSIZE];
for (int i = threadIdx.x; i < PACKED_YSIZE; i += blockDim.x)
w[i] = make_real2(sin(i*2*M_PI/YSIZE), cos(i*2*M_PI/YSIZE));
#else
__shared__ real2 w[PACKED_ZSIZE];
for (int i = threadIdx.x; i < PACKED_ZSIZE; i += blockDim.x)
w[i] = make_real2(sin(i*2*M_PI/ZSIZE), cos(i*2*M_PI/ZSIZE));
#endif
__syncthreads();
// Transform the data.
const int gridSize = PACKED_XSIZE*PACKED_YSIZE*PACKED_ZSIZE;
const int outputZSize = ZSIZE/2+1;
for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < gridSize; index += blockDim.x*gridDim.x) {
int x = index/(PACKED_YSIZE*PACKED_ZSIZE);
int remainder = index-x*(PACKED_YSIZE*PACKED_ZSIZE);
int y = remainder/PACKED_ZSIZE;
int z = remainder-y*PACKED_ZSIZE;
int xp = (x == 0 ? 0 : PACKED_XSIZE-x);
int yp = (y == 0 ? 0 : PACKED_YSIZE-y);
int zp = (z == 0 ? 0 : PACKED_ZSIZE-z);
real2 z1 = in[x*PACKED_YSIZE*PACKED_ZSIZE+y*PACKED_ZSIZE+z];
real2 z2 = in[xp*PACKED_YSIZE*PACKED_ZSIZE+yp*PACKED_ZSIZE+zp];
#if PACKED_AXIS == 0
real2 wfac = w[x];
#elif PACKED_AXIS == 1
real2 wfac = w[y];
#else
real2 wfac = w[z];
#endif
real2 output = make_real2((z1.x+z2.x - wfac.x*(z1.x-z2.x) + wfac.y*(z1.y+z2.y))/2, (z1.y-z2.y - wfac.y*(z1.x-z2.x) - wfac.x*(z1.y+z2.y))/2);
if (z < outputZSize)
out[x*YSIZE*outputZSize+y*outputZSize+z] = output;
xp = (x == 0 ? 0 : XSIZE-x);
yp = (y == 0 ? 0 : YSIZE-y);
zp = (z == 0 ? 0 : ZSIZE-z);
if (zp < outputZSize) {
#if PACKED_AXIS == 0
if (x == 0)
out[PACKED_XSIZE*YSIZE*outputZSize+yp*outputZSize+zp] = make_real2((z1.x-z1.y+z2.x-z2.y)/2, (-z1.x-z1.y+z2.x+z2.y)/2);
#elif PACKED_AXIS == 1
if (y == 0)
out[xp*YSIZE*outputZSize+PACKED_YSIZE*outputZSize+zp] = make_real2((z1.x-z1.y+z2.x-z2.y)/2, (-z1.x-z1.y+z2.x+z2.y)/2);
#else
if (z == 0)
out[xp*YSIZE*outputZSize+yp*outputZSize+PACKED_ZSIZE] = make_real2((z1.x-z1.y+z2.x-z2.y)/2, (-z1.x-z1.y+z2.x+z2.y)/2);
#endif
else
out[xp*YSIZE*outputZSize+yp*outputZSize+zp] = make_real2(output.x, -output.y);
}
}
}
/**
* Load a value from the half-complex grid produced by a real-to-complex transform.
*/
static __inline__ __device__ real2 loadComplexValue(const real2* __restrict__ in, int x, int y, int z) {
const int inputZSize = ZSIZE/2+1;
if (z < inputZSize)
return in[x*YSIZE*inputZSize+y*inputZSize+z];
int xp = (x == 0 ? 0 : XSIZE-x);
int yp = (y == 0 ? 0 : YSIZE-y);
real2 value = in[xp*YSIZE*inputZSize+yp*inputZSize+(ZSIZE-z)];
return make_real2(value.x, -value.y);
}
/**
* Repack the symmetric complex grid into one half as large in preparation for doing an inverse complex-to-real transform.
*/
extern "C" __global__ void packBackwardData(const real2* __restrict__ in, real2* __restrict__ out) {
// Compute the phase factors.
#if PACKED_AXIS == 0
__shared__ real2 w[PACKED_XSIZE];
for (int i = threadIdx.x; i < PACKED_XSIZE; i += blockDim.x)
w[i] = make_real2(cos(i*2*M_PI/XSIZE), sin(i*2*M_PI/XSIZE));
#elif PACKED_AXIS == 1
__shared__ real2 w[PACKED_YSIZE];
for (int i = threadIdx.x; i < PACKED_YSIZE; i += blockDim.x)
w[i] = make_real2(cos(i*2*M_PI/YSIZE), sin(i*2*M_PI/YSIZE));
#else
__shared__ real2 w[PACKED_ZSIZE];
for (int i = threadIdx.x; i < PACKED_ZSIZE; i += blockDim.x)
w[i] = make_real2(cos(i*2*M_PI/ZSIZE), sin(i*2*M_PI/ZSIZE));
#endif
__syncthreads();
// Transform the data.
const int gridSize = PACKED_XSIZE*PACKED_YSIZE*PACKED_ZSIZE;
for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < gridSize; index += blockDim.x*gridDim.x) {
int x = index/(PACKED_YSIZE*PACKED_ZSIZE);
int remainder = index-x*(PACKED_YSIZE*PACKED_ZSIZE);
int y = remainder/PACKED_ZSIZE;
int z = remainder-y*PACKED_ZSIZE;
int xp = (x == 0 ? 0 : PACKED_XSIZE-x);
int yp = (y == 0 ? 0 : PACKED_YSIZE-y);
int zp = (z == 0 ? 0 : PACKED_ZSIZE-z);
real2 z1 = loadComplexValue(in, x, y, z);
#if PACKED_AXIS == 0
real2 wfac = w[x];
real2 z2 = loadComplexValue(in, PACKED_XSIZE-x, yp, zp);
#elif PACKED_AXIS == 1
real2 wfac = w[y];
real2 z2 = loadComplexValue(in, xp, PACKED_YSIZE-y, zp);
#else
real2 wfac = w[z];
real2 z2 = loadComplexValue(in, xp, yp, PACKED_ZSIZE-z);
#endif
real2 even = make_real2((z1.x+z2.x)/2, (z1.y-z2.y)/2);
real2 odd = make_real2((z1.x-z2.x)/2, (z1.y+z2.y)/2);
odd = make_real2(odd.x*wfac.x-odd.y*wfac.y, odd.y*wfac.x+odd.x*wfac.y);
out[x*PACKED_YSIZE*PACKED_ZSIZE+y*PACKED_ZSIZE+z] = make_real2(even.x-odd.y, even.y+odd.x);
}
}
/**
* Split the data back into a full sized, real grid after an inverse transform.
*/
extern "C" __global__ void unpackBackwardData(const real2* __restrict__ in, real* __restrict__ out) {
const int gridSize = PACKED_XSIZE*PACKED_YSIZE*PACKED_ZSIZE;
for (int index = blockIdx.x*blockDim.x+threadIdx.x; index < gridSize; index += blockDim.x*gridDim.x) {
int x = index/(PACKED_YSIZE*PACKED_ZSIZE);
int remainder = index-x*(PACKED_YSIZE*PACKED_ZSIZE);
int y = remainder/PACKED_ZSIZE;
int z = remainder-y*PACKED_ZSIZE;
real2 value = 2*in[index];
#if PACKED_AXIS == 0
out[2*x*YSIZE*ZSIZE+y*ZSIZE+z] = value.x;
out[(2*x+1)*YSIZE*ZSIZE+y*ZSIZE+z] = value.y;
#elif PACKED_AXIS == 1
out[x*YSIZE*ZSIZE+2*y*ZSIZE+z] = value.x;
out[x*YSIZE*ZSIZE+(2*y+1)*ZSIZE+z] = value.y;
#else
out[x*YSIZE*ZSIZE+y*ZSIZE+2*z] = value.x;
out[x*YSIZE*ZSIZE+y*ZSIZE+(2*z+1)] = value.y;
#endif
}
}
......@@ -31,14 +31,13 @@
* -------------------------------------------------------------------------- */
/**
* This tests the Hip implementation of sorting.
* This tests the Hip implementation of FFT.
*/
#include "openmm/internal/AssertionUtilities.h"
#include "HipArray.h"
#include "HipContext.h"
#include "HipFFT3D.h"
#include "HipSort.h"
#include "fftpack.h"
#include "sfmt/SFMT.h"
#include "openmm/System.h"
......@@ -52,7 +51,7 @@ using namespace std;
static HipPlatform platform;
template <class Real2>
void testTransform(bool realToComplex, int xsize, int ysize, int zsize) {
void testTransform(bool realToComplex, int xsize, int ysize, int zsize, double eps = 1) {
System system;
system.addParticle(0.0);
HipPlatform::PlatformData platformData(NULL, system, "", "true", platform.getPropertyDefaultValue("HipPrecision"), "false",
......@@ -60,6 +59,11 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize) {
platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false", 1, NULL);
HipContext& context = *platformData.contexts[0];
context.initialize();
context.setAsCurrent();
xsize = context.findLegalFFTDimension(xsize);
ysize = context.findLegalFFTDimension(ysize);
zsize = context.findLegalFFTDimension(zsize);
cout << "realToComplex: " << realToComplex << " xsize: " << xsize << " ysize: " << ysize << " zsize: " << zsize << endl;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
vector<Real2> original(xsize*ysize*zsize);
......@@ -80,11 +84,11 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize) {
HipArray grid1(context, original.size(), sizeof(Real2), "grid1");
HipArray grid2(context, original.size(), sizeof(Real2), "grid2");
grid1.upload(original);
HipFFT3D fft(context, xsize, ysize, zsize, realToComplex);
HipFFT3D fft(context, xsize, ysize, zsize, realToComplex, context.getCurrentStream(), grid1, grid2);
// Perform a forward FFT, then verify the result is correct.
fft.execFFT(grid1, grid2, true);
fft.execFFT(true);
vector<Real2> result;
grid2.download(result);
fftpack_t plan;
......@@ -96,23 +100,24 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize) {
for (int z = 0; z < outputZSize; z++) {
int index1 = x*ysize*zsize + y*zsize + z;
int index2 = x*ysize*outputZSize + y*outputZSize + z;
ASSERT_EQUAL_TOL(reference[index1].re, result[index2].x, 1e-3);
ASSERT_EQUAL_TOL(reference[index1].im, result[index2].y, 1e-3);
ASSERT_EQUAL_TOL(reference[index1].re, result[index2].x, 1e-3 * eps);
ASSERT_EQUAL_TOL(reference[index1].im, result[index2].y, 1e-3 * eps);
}
fftpack_destroy(plan);
// Perform a backward transform and see if we get the original values.
fft.execFFT(grid2, grid1, false);
fft.execFFT(false);
grid1.download(result);
double scale = 1.0/(xsize*ysize*zsize);
int valuesToCheck = (realToComplex ? original.size()/2 : original.size());
for (int i = 0; i < valuesToCheck; ++i) {
ASSERT_EQUAL_TOL(original[i].x, scale*result[i].x, 1e-4);
ASSERT_EQUAL_TOL(original[i].y, scale*result[i].y, 1e-4);
ASSERT_EQUAL_TOL(original[i].x, scale*result[i].x, 1e-4 * eps);
ASSERT_EQUAL_TOL(original[i].y, scale*result[i].y, 1e-4 * eps);
}
}
int main(int argc, char* argv[]) {
try {
if (argc > 1)
......@@ -123,6 +128,17 @@ int main(int argc, char* argv[]) {
testTransform<double2>(true, 25, 28, 25);
testTransform<double2>(true, 25, 25, 28);
testTransform<double2>(true, 21, 25, 27);
testTransform<double2>(true, 49, 98, 14);
testTransform<double2>(true, 7, 21, 98);
testTransform<double2>(true, 98, 21, 21);
testTransform<double2>(true, 18, 98, 6);
testTransform<double2>(true, 50, 50, 50);
testTransform<double2>(true, 60, 60, 60);
testTransform<double2>(false, 64, 64, 64);
testTransform<double2>(false, 100, 140, 88);
testTransform<double2>(true, 120, 243, 120);
testTransform<double2>(true, 216, 216, 116);
testTransform<double2>(true, 98, 98, 98);
}
else {
testTransform<float2>(false, 28, 25, 30);
......@@ -130,6 +146,17 @@ int main(int argc, char* argv[]) {
testTransform<float2>(true, 25, 28, 25);
testTransform<float2>(true, 25, 25, 28);
testTransform<float2>(true, 21, 25, 27);
testTransform<float2>(true, 49, 98, 14);
testTransform<float2>(true, 7, 21, 98);
testTransform<float2>(true, 98, 21, 21);
testTransform<float2>(true, 18, 98, 6);
testTransform<float2>(true, 50, 50, 50);
testTransform<float2>(true, 60, 60, 60);
testTransform<float2>(false, 64, 64, 64);
testTransform<float2>(false, 100, 140, 88, 1e+1);
testTransform<float2>(true, 120, 243, 120, 1e+1);
testTransform<float2>(true, 216, 216, 116, 1e+1);
testTransform<float2>(true, 98, 98, 98, 1e+1);
}
}
catch(const exception& e) {
......
......@@ -37,7 +37,6 @@
#include "openmm/internal/AmoebaVdwForceImpl.h"
#include "openmm/internal/NonbondedForceImpl.h"
#include "HipBondedUtilities.h"
#include "HipFFT3D.h"
#include "HipForceInfo.h"
#include "HipKernelSources.h"
#include "SimTKOpenMMRealType.h"
......@@ -52,70 +51,27 @@
using namespace OpenMM;
using namespace std;
static void setPeriodicBoxArgs(ComputeContext& cc, ComputeKernel kernel, int index) {
Vec3 a, b, c;
cc.getPeriodicBoxVectors(a, b, c);
if (cc.getUseDoublePrecision()) {
kernel->setArg(index++, mm_double4(a[0], b[1], c[2], 0.0));
kernel->setArg(index++, mm_double4(1.0/a[0], 1.0/b[1], 1.0/c[2], 0.0));
kernel->setArg(index++, mm_double4(a[0], a[1], a[2], 0.0));
kernel->setArg(index++, mm_double4(b[0], b[1], b[2], 0.0));
kernel->setArg(index, mm_double4(c[0], c[1], c[2], 0.0));
}
else {
kernel->setArg(index++, mm_float4((float) a[0], (float) b[1], (float) c[2], 0.0f));
kernel->setArg(index++, mm_float4(1.0f/(float) a[0], 1.0f/(float) b[1], 1.0f/(float) c[2], 0.0f));
kernel->setArg(index++, mm_float4((float) a[0], (float) a[1], (float) a[2], 0.0f));
kernel->setArg(index++, mm_float4((float) b[0], (float) b[1], (float) b[2], 0.0f));
kernel->setArg(index, mm_float4((float) c[0], (float) c[1], (float) c[2], 0.0f));
}
}
/* -------------------------------------------------------------------------- *
* AmoebaMultipole *
* -------------------------------------------------------------------------- */
HipCalcAmoebaMultipoleForceKernel::~HipCalcAmoebaMultipoleForceKernel() {
cc.setAsCurrent();
if (hasInitializedFFT)
hipfftDestroy(fft);
if (fft != NULL)
delete fft;
}
void HipCalcAmoebaMultipoleForceKernel::initialize(const System& system, const AmoebaMultipoleForce& force) {
CommonCalcAmoebaMultipoleForceKernel::initialize(system, force);
if (usePME) {
hipfftResult result = hipfftPlan3d(&fft, gridSizeX, gridSizeY, gridSizeZ, cc.getUseDoublePrecision() ? HIPFFT_Z2Z : HIPFFT_C2C);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+cc.intToString(result));
hasInitializedFFT = true;
HipArray& grid1 = cu.unwrap(pmeGrid1);
HipArray& grid2 = cu.unwrap(pmeGrid2);
fft = cu.createFFT(gridSizeX, gridSizeY, gridSizeZ, false, cu.getCurrentStream(), grid1, grid2);
}
}
void HipCalcAmoebaMultipoleForceKernel::computeFFT(bool forward) {
HipArray& grid1 = dynamic_cast<HipContext&>(cc).unwrap(pmeGrid1);
HipArray& grid2 = dynamic_cast<HipContext&>(cc).unwrap(pmeGrid2);
if (forward) {
if (cc.getUseDoublePrecision()) {
hipfftResult result = hipfftExecZ2Z(fft, (double2*) grid1.getDevicePointer(), (double2*) grid2.getDevicePointer(), HIPFFT_FORWARD);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cc.intToString(result));
} else {
hipfftResult result = hipfftExecC2C(fft, (float2*) grid1.getDevicePointer(), (float2*) grid2.getDevicePointer(), HIPFFT_FORWARD);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cc.intToString(result));
}
}
else {
if (cc.getUseDoublePrecision()) {
hipfftResult result = hipfftExecZ2Z(fft, (double2*) grid2.getDevicePointer(), (double2*) grid1.getDevicePointer(), HIPFFT_BACKWARD);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cc.intToString(result));
} else {
hipfftResult result = hipfftExecC2C(fft, (float2*) grid2.getDevicePointer(), (float2*) grid1.getDevicePointer(), HIPFFT_BACKWARD);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cc.intToString(result));
}
}
fft->execFFT(forward);
}
/* -------------------------------------------------------------------------- *
......@@ -126,60 +82,29 @@ HipCalcHippoNonbondedForceKernel::~HipCalcHippoNonbondedForceKernel() {
cc.setAsCurrent();
if (sort != NULL)
delete sort;
if (hasInitializedFFT) {
hipfftDestroy(fftForward);
hipfftDestroy(fftBackward);
hipfftDestroy(dfftForward);
hipfftDestroy(dfftBackward);
}
if (fft != NULL)
delete fft;
if (dfft != NULL)
delete dfft;
}
void HipCalcHippoNonbondedForceKernel::initialize(const System& system, const HippoNonbondedForce& force) {
CommonCalcHippoNonbondedForceKernel::initialize(system, force);
if (usePME) {
sort = new HipSort(cu, new SortTrait(), cc.getNumAtoms());
hipfftResult result = hipfftPlan3d(&fftForward, gridSizeX, gridSizeY, gridSizeZ, cc.getUseDoublePrecision() ? HIPFFT_D2Z : HIPFFT_R2C);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+cc.intToString(result));
result = hipfftPlan3d(&fftBackward, gridSizeX, gridSizeY, gridSizeZ, cc.getUseDoublePrecision() ? HIPFFT_Z2D : HIPFFT_C2R);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+cc.intToString(result));
result = hipfftPlan3d(&dfftForward, dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, cc.getUseDoublePrecision() ? HIPFFT_D2Z : HIPFFT_R2C);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+cc.intToString(result));
result = hipfftPlan3d(&dfftBackward, dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, cc.getUseDoublePrecision() ? HIPFFT_Z2D : HIPFFT_C2R);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error initializing FFT: "+cc.intToString(result));
hasInitializedFFT = true;
HipArray& grid1 = cu.unwrap(pmeGrid1);
HipArray& grid2 = cu.unwrap(pmeGrid2);
fft = cu.createFFT(gridSizeX, gridSizeY, gridSizeZ, true, cu.getCurrentStream(), grid1, grid2);
dfft = cu.createFFT(dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true, cu.getCurrentStream(), grid1, grid2);
}
}
void HipCalcHippoNonbondedForceKernel::computeFFT(bool forward, bool dispersion) {
HipArray& grid1 = dynamic_cast<HipContext&>(cc).unwrap(pmeGrid1);
HipArray& grid2 = dynamic_cast<HipContext&>(cc).unwrap(pmeGrid2);
if (forward) {
hipfftHandle fft = dispersion ? dfftForward : fftForward;
if (cc.getUseDoublePrecision()) {
hipfftResult result = hipfftExecD2Z(fft, (double*) grid1.getDevicePointer(), (double2*) grid2.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cc.intToString(result));
} else {
hipfftResult result = hipfftExecR2C(fft, (float*) grid1.getDevicePointer(), (float2*) grid2.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cc.intToString(result));
}
if (dispersion) {
dfft->execFFT(forward);
}
else {
hipfftHandle fft = dispersion ? dfftBackward : fftBackward;
if (cc.getUseDoublePrecision()) {
hipfftResult result = hipfftExecZ2D(fft, (double2*) grid2.getDevicePointer(), (double*) grid1.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cc.intToString(result));
} else {
hipfftResult result = hipfftExecC2R(fft, (float2*) grid2.getDevicePointer(), (float*) grid1.getDevicePointer());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing FFT: "+cc.intToString(result));
}
fft->execFFT(forward);
}
}
......
......@@ -34,8 +34,8 @@
#include "HipContext.h"
#include "HipNonbondedUtilities.h"
#include "HipSort.h"
#include "HipFFT3D.h"
#include "AmoebaCommonKernels.h"
#include <hipfft.h>
namespace OpenMM {
......@@ -45,7 +45,7 @@ namespace OpenMM {
class HipCalcAmoebaMultipoleForceKernel : public CommonCalcAmoebaMultipoleForceKernel {
public:
HipCalcAmoebaMultipoleForceKernel(const std::string& name, const Platform& platform, HipContext& cu, const System& system) :
CommonCalcAmoebaMultipoleForceKernel(name, platform, cu, system), hasInitializedFFT(false) {
CommonCalcAmoebaMultipoleForceKernel(name, platform, cu, system), cu(cu), fft(NULL) {
}
~HipCalcAmoebaMultipoleForceKernel();
/**
......@@ -66,8 +66,8 @@ public:
return cc.getUseDoublePrecision() || !dynamic_cast<HipContext&>(cc).getSupportsHardwareFloatGlobalAtomicAdd();
}
private:
bool hasInitializedFFT;
hipfftHandle fft;
HipContext& cu;
HipFFT3D* fft;
};
/**
......@@ -76,7 +76,7 @@ private:
class HipCalcHippoNonbondedForceKernel : public CommonCalcHippoNonbondedForceKernel {
public:
HipCalcHippoNonbondedForceKernel(const std::string& name, const Platform& platform, HipContext& cu, const System& system) :
CommonCalcHippoNonbondedForceKernel(name, platform, cu, system), cu(cu), sort(NULL), hasInitializedFFT(false) {
CommonCalcHippoNonbondedForceKernel(name, platform, cu, system), cu(cu), sort(NULL), fft(NULL), dfft(NULL) {
}
~HipCalcHippoNonbondedForceKernel();
/**
......@@ -112,9 +112,9 @@ private:
const char* getSortKey() const {return "value.y";}
};
HipContext& cu;
bool hasInitializedFFT;
HipSort* sort;
hipfftHandle fftForward, fftBackward, dfftForward, dfftBackward;
HipFFT3D* fft;
HipFFT3D* dfft;
};
} // namespace OpenMM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment