Commit 284ea303 authored by one's avatar one
Browse files

[OpenMM] Add a patch file for gfx936

parent 18467adf
diff --git a/examples/benchmarks/benchmark.py b/examples/benchmarks/benchmark.py
index f3eb95a88..f9fea45c2 100644
--- a/examples/benchmarks/benchmark.py
+++ b/examples/benchmarks/benchmark.py
@@ -354,6 +354,8 @@ def runOneTest(testName, options):
initialSteps = 250
if options.disable_pme_stream:
properties['DisablePmeStream'] = 'true'
+ if options.platform == 'HIP' and options.fft_backend is not None and 'FFTBackend' in platform.getPropertyNames():
+ properties['FFTBackend'] = options.fft_backend
if options.opencl_platform is not None and 'OpenCLPlatformIndex' in platform.getPropertyNames():
properties['OpenCLPlatformIndex'] = options.opencl_platform
if (options.precision is not None) and ('Precision' in platform.getPropertyNames()):
@@ -481,6 +483,7 @@ parser.add_argument('--polarization', default='mutual', dest='polarization', cho
parser.add_argument('--mutual-epsilon', default=1e-5, dest='epsilon', type=float, help='mutual induced epsilon for AMOEBA [default: 1e-5]')
parser.add_argument('--bond-constraints', default='hbonds', dest='bond_constraints', help=f'hbonds: constrain bonds to hydrogen, use 1.5*amu H mass; allbonds: constrain all bonds, use 4*amu H mass, and use larger timestep. This option is ignored for AMOEBA: {BOND_CONSTRAINTS} [default: hbonds]')
parser.add_argument('--disable-pme-stream', default=False, action='store_true', dest='disable_pme_stream', help='disable use of a separate GPU stream for PME')
+parser.add_argument('--fft-backend', default=None, dest='fft_backend', help='FFT backend for HIP platform (for example: vkfft or hipfft)')
parser.add_argument('--device', default=None, dest='device', help='device index for CUDA, HIP, or OpenCL')
parser.add_argument('--opencl-platform', default=None, dest='opencl_platform', help='platform index for OpenCL')
parser.add_argument('--precision', default='single', dest='precision', help=f'precision modes for CUDA, HIP, or OpenCL: {PRECISIONS} [default: single]')
diff --git a/platforms/common/src/CommonCalcNonbondedForce.cpp b/platforms/common/src/CommonCalcNonbondedForce.cpp
index bab0679b7..e75dabdc0 100644
--- a/platforms/common/src/CommonCalcNonbondedForce.cpp
+++ b/platforms/common/src/CommonCalcNonbondedForce.cpp
@@ -962,7 +962,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
// Execute the reciprocal space kernels.
if (hasCoulomb) {
- if (stepsToSort <= 0 || doLJPME || cc.getNumAtoms() > 15000) {
+ if (stepsToSort <= 0 || doLJPME) {
setPeriodicBoxArgs(cc, pmeGridIndexKernel, 2);
if (cc.getUseDoublePrecision()) {
pmeGridIndexKernel->setArg(7, recipBoxVectors[0]);
diff --git a/platforms/hip/CMakeLists.txt b/platforms/hip/CMakeLists.txt
index 7c8dfca8f..2ca6f35d7 100644
--- a/platforms/hip/CMakeLists.txt
+++ b/platforms/hip/CMakeLists.txt
@@ -12,7 +12,23 @@
# libOpenMMHIP_static.a
#----------------------------------------------------
-FIND_PACKAGE(HIPRTC CONFIG)
+IF(NOT TARGET hiprtc::hiprtc)
+ add_library(hiprtc::hiprtc SHARED IMPORTED)
+ set_target_properties(hiprtc::hiprtc PROPERTIES
+ IMPORTED_LOCATION "/opt/dtk/hip/lib/libhiprtc.so"
+ INTERFACE_INCLUDE_DIRECTORIES "/opt/dtk/hip/include"
+ )
+ENDIF()
+
+FIND_PACKAGE(HIPFFT CONFIG QUIET)
+
+IF(NOT TARGET hiprtc::hiprtc)
+ add_library(hiprtc::hiprtc SHARED IMPORTED)
+ set_target_properties(hiprtc::hiprtc PROPERTIES
+ IMPORTED_LOCATION "/opt/dtk/hip/lib/libhiprtc.so"
+ INTERFACE_INCLUDE_DIRECTORIES "/opt/dtk/hip/include"
+ )
+ENDIF()
SET(OPENMM_BUILD_HIP_TESTS TRUE CACHE BOOL "Whether to build HIP test cases")
IF(BUILD_TESTING AND OPENMM_BUILD_HIP_TESTS)
@@ -103,6 +119,10 @@ IF(OPENMM_BUILD_SHARED_LIB)
TARGET_LINK_LIBRARIES(${SHARED_TARGET} PUBLIC ${OPENMM_LIBRARY_NAME} hip::host hiprtc::hiprtc)
SET_TARGET_PROPERTIES(${SHARED_TARGET} PROPERTIES COMPILE_FLAGS "${EXTRA_COMPILE_FLAGS} -DOPENMM_COMMON_BUILDING_SHARED_LIBRARY")
SET_TARGET_PROPERTIES(${SHARED_TARGET} PROPERTIES LINK_FLAGS "${EXTRA_LINK_FLAGS}")
+ IF(HIPFFT_FOUND)
+ TARGET_LINK_LIBRARIES(${SHARED_TARGET} PUBLIC hip::hipfft)
+ TARGET_COMPILE_OPTIONS(${SHARED_TARGET} PUBLIC "-DOPENMM_HIP_WITH_HIPFFT")
+ ENDIF()
INSTALL_TARGETS(/lib/plugins RUNTIME_DIRECTORY /lib/plugins ${SHARED_TARGET})
ENDIF(OPENMM_BUILD_SHARED_LIB)
@@ -116,6 +136,10 @@ IF(OPENMM_BUILD_STATIC_LIB)
TARGET_LINK_LIBRARIES(${STATIC_TARGET} ${OPENMM_LIBRARY_NAME} hip::host hiprtc::hiprtc)
SET_TARGET_PROPERTIES(${STATIC_TARGET} PROPERTIES COMPILE_FLAGS "${EXTRA_COMPILE_FLAGS} -DOPENMM_COMMON_BUILDING_STATIC_LIBRARY")
SET_TARGET_PROPERTIES(${STATIC_TARGET} PROPERTIES LINK_FLAGS "${EXTRA_LINK_FLAGS}")
+ IF(HIPFFT_FOUND)
+ TARGET_LINK_LIBRARIES(${STATIC_TARGET} PUBLIC hip::hipfft)
+ TARGET_COMPILE_OPTIONS(${STATIC_TARGET} PUBLIC "-DOPENMM_HIP_WITH_HIPFFT")
+ ENDIF()
INSTALL_TARGETS(/lib/plugins RUNTIME_DIRECTORY /lib/plugins ${STATIC_TARGET})
ENDIF(OPENMM_BUILD_STATIC_LIB)
diff --git a/platforms/hip/include/HipHipFFT3D.h b/platforms/hip/include/HipHipFFT3D.h
new file mode 100644
index 000000000..af509e4fc
--- /dev/null
+++ b/platforms/hip/include/HipHipFFT3D.h
@@ -0,0 +1,60 @@
+#ifndef __OPENMM_HIPHIPFFT3D_H__
+#define __OPENMM_HIPHIPFFT3D_H__
+
+/* -------------------------------------------------------------------------- *
+ * OpenMM *
+ * -------------------------------------------------------------------------- *
+ * This is part of the OpenMM molecular simulation toolkit. *
+ * See https://openmm.org/development. *
+ * *
+ * Portions copyright (c) 2009-2026 Stanford University and the Authors. *
+ * Contributors: *
+ * *
+ * This program is free software: you can redistribute it and/or modify *
+ * it under the terms of the GNU Lesser General Public License as published *
+ * by the Free Software Foundation, either version 3 of the License, or *
+ * (at your option) any later version. *
+ * *
+ * This program is distributed in the hope that it will be useful, *
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of *
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
+ * GNU Lesser General Public License for more details. *
+ * *
+ * You should have received a copy of the GNU Lesser General Public License *
+ * along with this program. If not, see <http://www.gnu.org/licenses/>. *
+ * -------------------------------------------------------------------------- */
+
+#include "openmm/common/windowsExportCommon.h"
+#include "openmm/common/FFT3D.h"
+#include "openmm/common/ArrayInterface.h"
+
+#ifdef OPENMM_HIP_WITH_HIPFFT
+#if __has_include(<hipfft/hipfft.h>)
+#include <hipfft/hipfft.h>
+#else
+#include <hipfft.h>
+#endif
+#endif
+
+namespace OpenMM {
+
+class HipContext;
+
+class OPENMM_EXPORT_COMMON HipHipFFT3D : public FFT3DImpl {
+public:
+ HipHipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex);
+ ~HipHipFFT3D();
+ void execFFT(ArrayInterface& in, ArrayInterface& out, bool forward=true) override;
+private:
+ HipContext& context;
+ bool realToComplex;
+ bool doublePrecision;
+#ifdef OPENMM_HIP_WITH_HIPFFT
+ hipfftHandle forwardPlan;
+ hipfftHandle backwardPlan;
+#endif
+};
+
+} // namespace OpenMM
+
+#endif // __OPENMM_HIPHIPFFT3D_H__
diff --git a/platforms/hip/include/HipPlatform.h b/platforms/hip/include/HipPlatform.h
index a0665d45c..43eba16af 100644
--- a/platforms/hip/include/HipPlatform.h
+++ b/platforms/hip/include/HipPlatform.h
@@ -111,13 +111,21 @@ public:
static const std::string key = "DeterministicForces";
return key;
}
+ /**
+ * This is the name of the parameter for selecting which FFT backend to use.
+ */
+ static const std::string& HipFFTBackend() {
+ static const std::string key = "FFTBackend";
+ return key;
+ }
};
class OPENMM_EXPORT_COMMON HipPlatform::PlatformData {
public:
PlatformData(ContextImpl* context, const System& system, const std::string& deviceIndexProperty, const std::string& blockingProperty, const std::string& precisionProperty,
const std::string& cpuPmeProperty, const std::string& tempProperty,
- const std::string& pmeStreamProperty, const std::string& deterministicForcesProperty, int numThreads, ContextImpl* originalContext);
+ const std::string& pmeStreamProperty, const std::string& deterministicForcesProperty, const std::string& fftBackendProperty,
+ int numThreads, ContextImpl* originalContext);
~PlatformData();
void initializeContexts(const System& system);
void syncContexts();
diff --git a/platforms/hip/src/HipContext.cpp b/platforms/hip/src/HipContext.cpp
index 0050a8c46..dc4290b3c 100644
--- a/platforms/hip/src/HipContext.cpp
+++ b/platforms/hip/src/HipContext.cpp
@@ -31,6 +31,7 @@
#include "HipBondedUtilities.h"
#include "HipEvent.h"
#include "HipFFT3D.h"
+#include "HipHipFFT3D.h"
#include "HipIntegrationUtilities.h"
#include "HipKernels.h"
#include "HipKernelSources.h"
@@ -80,6 +81,12 @@
using namespace OpenMM;
using namespace std;
+static string normalizeFftBackendName(const string& value) {
+ string result = value;
+ transform(result.begin(), result.end(), result.begin(), ::tolower);
+ return result;
+}
+
const int HipContext::ThreadBlockSize = 64;
const int HipContext::TileSize = 32;
bool HipContext::hasInitializedHip = false;
@@ -712,7 +719,12 @@ ComputeSort HipContext::createSort(ComputeSortImpl::SortTrait* trait, unsigned i
}
FFT3D HipContext::createFFT(int xsize, int ysize, int zsize, bool realToComplex) {
- return FFT3D(new HipFFT3D(*this, xsize, ysize, zsize, realToComplex));
+ string backend = normalizeFftBackendName(platformData.propertyValues[HipPlatform::HipFFTBackend()]);
+ if (backend.empty() || backend == "vkfft")
+ return FFT3D(new HipFFT3D(*this, xsize, ysize, zsize, realToComplex));
+ if (backend == "hipfft")
+ return FFT3D(new HipHipFFT3D(*this, xsize, ysize, zsize, realToComplex));
+ throw OpenMMException("Illegal value for FFTBackend: "+backend);
}
int HipContext::findLegalFFTDimension(int minimum) {
diff --git a/platforms/hip/src/HipHipFFT3D.cpp b/platforms/hip/src/HipHipFFT3D.cpp
new file mode 100644
index 000000000..a0a408c00
--- /dev/null
+++ b/platforms/hip/src/HipHipFFT3D.cpp
@@ -0,0 +1,100 @@
+/* -------------------------------------------------------------------------- *
+ * OpenMM *
+ * -------------------------------------------------------------------------- *
+ * This is part of the OpenMM molecular simulation toolkit. *
+ * See https://openmm.org/development. *
+ * *
+ * Portions copyright (c) 2009-2026 Stanford University and the Authors. *
+ * Contributors: *
+ * *
+ * This program is free software: you can redistribute it and/or modify *
+ * it under the terms of the GNU Lesser General Public License as published *
+ * by the Free Software Foundation, either version 3 of the License, or *
+ * (at your option) any later version. *
+ * *
+ * This program is distributed in the hope that it will be useful, *
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of *
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
+ * GNU Lesser General Public License for more details. *
+ * *
+ * You should have received a copy of the GNU Lesser General Public License *
+ * along with this program. If not, see <http://www.gnu.org/licenses/>. *
+ * -------------------------------------------------------------------------- */
+
+#include "HipHipFFT3D.h"
+#include "HipContext.h"
+#include "openmm/OpenMMException.h"
+
+using namespace OpenMM;
+using namespace std;
+
+HipHipFFT3D::HipHipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex) :
+ context(context), realToComplex(realToComplex), doublePrecision(context.getUseDoublePrecision())
+#ifdef OPENMM_HIP_WITH_HIPFFT
+ , forwardPlan(0), backwardPlan(0)
+#endif
+{
+#ifndef OPENMM_HIP_WITH_HIPFFT
+ throw OpenMMException("hipFFT backend requested but OpenMM was built without HIPFFT support");
+#else
+ hipfftType forwardType, backwardType;
+ if (realToComplex) {
+ forwardType = doublePrecision ? HIPFFT_D2Z : HIPFFT_R2C;
+ backwardType = doublePrecision ? HIPFFT_Z2D : HIPFFT_C2R;
+ }
+ else {
+ forwardType = doublePrecision ? HIPFFT_Z2Z : HIPFFT_C2C;
+ backwardType = forwardType;
+ }
+ hipfftResult result = hipfftPlan3d(&forwardPlan, xsize, ysize, zsize, forwardType);
+ if (result != HIPFFT_SUCCESS)
+ throw OpenMMException("Error initializing hipFFT forward plan: "+context.intToString(result));
+ result = hipfftPlan3d(&backwardPlan, xsize, ysize, zsize, backwardType);
+ if (result != HIPFFT_SUCCESS) {
+ hipfftDestroy(forwardPlan);
+ throw OpenMMException("Error initializing hipFFT backward plan: "+context.intToString(result));
+ }
+#endif
+}
+
+HipHipFFT3D::~HipHipFFT3D() {
+#ifdef OPENMM_HIP_WITH_HIPFFT
+ if (forwardPlan != 0)
+ hipfftDestroy(forwardPlan);
+ if (backwardPlan != 0)
+ hipfftDestroy(backwardPlan);
+#endif
+}
+
+void HipHipFFT3D::execFFT(ArrayInterface& in, ArrayInterface& out, bool forward) {
+#ifndef OPENMM_HIP_WITH_HIPFFT
+ throw OpenMMException("hipFFT backend requested but OpenMM was built without HIPFFT support");
+#else
+ hipfftHandle plan = forward ? forwardPlan : backwardPlan;
+ hipfftResult result = hipfftSetStream(plan, context.getCurrentStream());
+ if (result != HIPFFT_SUCCESS)
+ throw OpenMMException("Error setting hipFFT stream: "+context.intToString(result));
+ if (realToComplex) {
+ if (forward) {
+ if (doublePrecision)
+ result = hipfftExecD2Z(plan, (double*) context.unwrap(in).getDevicePointer(), (double2*) context.unwrap(out).getDevicePointer());
+ else
+ result = hipfftExecR2C(plan, (float*) context.unwrap(in).getDevicePointer(), (float2*) context.unwrap(out).getDevicePointer());
+ }
+ else {
+ if (doublePrecision)
+ result = hipfftExecZ2D(plan, (double2*) context.unwrap(in).getDevicePointer(), (double*) context.unwrap(out).getDevicePointer());
+ else
+ result = hipfftExecC2R(plan, (float2*) context.unwrap(in).getDevicePointer(), (float*) context.unwrap(out).getDevicePointer());
+ }
+ }
+ else {
+ if (doublePrecision)
+ result = hipfftExecZ2Z(plan, (double2*) context.unwrap(in).getDevicePointer(), (double2*) context.unwrap(out).getDevicePointer(), forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD);
+ else
+ result = hipfftExecC2C(plan, (float2*) context.unwrap(in).getDevicePointer(), (float2*) context.unwrap(out).getDevicePointer(), forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD);
+ }
+ if (result != HIPFFT_SUCCESS)
+ throw OpenMMException("Error executing hipFFT: "+context.intToString(result));
+#endif
+}
diff --git a/platforms/hip/src/HipNonbondedUtilities.cpp b/platforms/hip/src/HipNonbondedUtilities.cpp
index 01cc7f826..60548054d 100644
--- a/platforms/hip/src/HipNonbondedUtilities.cpp
+++ b/platforms/hip/src/HipNonbondedUtilities.cpp
@@ -65,16 +65,15 @@ HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(cont
string errorMessage = "Error initializing nonbonded utilities";
CHECK_RESULT(hipEventCreateWithFlags(&downloadCountEvent, context.getEventFlags()));
CHECK_RESULT(hipHostMalloc((void**) &pinnedCountBuffer, 2*sizeof(unsigned int), context.getHostMallocFlags()));
- numForceThreadBlocks = 5*4*context.getMultiprocessors();
- forceThreadBlockSize = 64;
- findInteractingBlocksThreadBlockSize = context.getSIMDWidth();
+ numForceThreadBlocks = 16*4*context.getMultiprocessors();
+ forceThreadBlockSize = 256;
+ findInteractingBlocksThreadBlockSize = 128;
// When building the neighbor list, we can optionally use large blocks (32 * warpSize atoms) to
// accelerate the process. This makes building the neighbor list faster, but it prevents
// us from sorting atom blocks by size, which leads to a slightly less efficient neighbor
// list. We guess based on system size which will be faster.
-
- useLargeBlocks = (context.getNumAtoms() > 90000);
+ useLargeBlocks = false;
setKernelSource(HipKernelSources::nonbonded);
}
@@ -400,7 +399,7 @@ double HipNonbondedUtilities::getMaxCutoffDistance() {
}
double HipNonbondedUtilities::padCutoff(double cutoff) {
- double padding = (usePadding ? 0.08*cutoff : 0.0);
+ double padding = (usePadding ? 0.12*cutoff : 0.0);
return cutoff+padding;
}
diff --git a/platforms/hip/src/HipPlatform.cpp b/platforms/hip/src/HipPlatform.cpp
index e6e68e2dd..3c3936da6 100644
--- a/platforms/hip/src/HipPlatform.cpp
+++ b/platforms/hip/src/HipPlatform.cpp
@@ -69,6 +69,7 @@ HipPlatform::HipPlatform() {
deprecatedPropertyReplacements["HipTempDirectory"] = HipTempDirectory();
deprecatedPropertyReplacements["HipDisablePmeStream"] = HipDisablePmeStream();
deprecatedPropertyReplacements["HipDeterministicForces"] = HipDeterministicForces();
+ deprecatedPropertyReplacements["HipFFTBackend"] = HipFFTBackend();
HipKernelFactory* factory = new HipKernelFactory();
registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory);
registerKernelFactory(UpdateStateDataKernel::Name(), factory);
@@ -122,6 +123,7 @@ HipPlatform::HipPlatform() {
platformProperties.push_back(HipTempDirectory());
platformProperties.push_back(HipDisablePmeStream());
platformProperties.push_back(HipDeterministicForces());
+ platformProperties.push_back(HipFFTBackend());
setPropertyDefaultValue(HipDeviceIndex(), "");
setPropertyDefaultValue(HipDeviceName(), "");
setPropertyDefaultValue(HipUseBlockingSync(), "true");
@@ -129,6 +131,7 @@ HipPlatform::HipPlatform() {
setPropertyDefaultValue(HipUseCpuPme(), "false");
setPropertyDefaultValue(HipDisablePmeStream(), "false");
setPropertyDefaultValue(HipDeterministicForces(), "false");
+ setPropertyDefaultValue(HipFFTBackend(), "vkfft");
#ifdef _MSC_VER
setPropertyDefaultValue(HipTempDirectory(), string(getenv("TEMP")));
#else
@@ -213,11 +216,14 @@ void HipPlatform::contextCreated(ContextImpl& context, const map<string, string>
getPropertyDefaultValue(HipDisablePmeStream()) : properties.find(HipDisablePmeStream())->second);
string deterministicForcesValue = (properties.find(HipDeterministicForces()) == properties.end() ?
getPropertyDefaultValue(HipDeterministicForces()) : properties.find(HipDeterministicForces())->second);
+ string fftBackendValue = (properties.find(HipFFTBackend()) == properties.end() ?
+ getPropertyDefaultValue(HipFFTBackend()) : properties.find(HipFFTBackend())->second);
transform(blockingPropValue.begin(), blockingPropValue.end(), blockingPropValue.begin(), ::tolower);
transform(precisionPropValue.begin(), precisionPropValue.end(), precisionPropValue.begin(), ::tolower);
transform(cpuPmePropValue.begin(), cpuPmePropValue.end(), cpuPmePropValue.begin(), ::tolower);
transform(pmeStreamPropValue.begin(), pmeStreamPropValue.end(), pmeStreamPropValue.begin(), ::tolower);
transform(deterministicForcesValue.begin(), deterministicForcesValue.end(), deterministicForcesValue.begin(), ::tolower);
+ transform(fftBackendValue.begin(), fftBackendValue.end(), fftBackendValue.begin(), ::tolower);
vector<string> pmeKernelName;
pmeKernelName.push_back(CalcPmeReciprocalForceKernel::Name());
if (!supportsKernels(pmeKernelName))
@@ -227,7 +233,7 @@ void HipPlatform::contextCreated(ContextImpl& context, const map<string, string>
if (threadsEnv != NULL)
stringstream(threadsEnv) >> threads;
context.setPlatformData(new PlatformData(&context, context.getSystem(), devicePropValue, blockingPropValue, precisionPropValue, cpuPmePropValue, tempPropValue,
- pmeStreamPropValue, deterministicForcesValue, threads, NULL));
+ pmeStreamPropValue, deterministicForcesValue, fftBackendValue, threads, NULL));
}
void HipPlatform::linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const {
@@ -239,9 +245,10 @@ void HipPlatform::linkedContextCreated(ContextImpl& context, ContextImpl& origin
string tempPropValue = platform.getPropertyValue(originalContext.getOwner(), HipTempDirectory());
string pmeStreamPropValue = platform.getPropertyValue(originalContext.getOwner(), HipDisablePmeStream());
string deterministicForcesValue = platform.getPropertyValue(originalContext.getOwner(), HipDeterministicForces());
+ string fftBackendValue = platform.getPropertyValue(originalContext.getOwner(), HipFFTBackend());
int threads = reinterpret_cast<PlatformData*>(originalContext.getPlatformData())->threads.getNumThreads();
context.setPlatformData(new PlatformData(&context, context.getSystem(), devicePropValue, blockingPropValue, precisionPropValue, cpuPmePropValue, tempPropValue,
- pmeStreamPropValue, deterministicForcesValue, threads, &originalContext));
+ pmeStreamPropValue, deterministicForcesValue, fftBackendValue, threads, &originalContext));
}
void HipPlatform::contextDestroyed(ContextImpl& context) const {
@@ -250,8 +257,8 @@ void HipPlatform::contextDestroyed(ContextImpl& context) const {
}
HipPlatform::PlatformData::PlatformData(ContextImpl* context, const System& system, const string& deviceIndexProperty, const string& blockingProperty, const string& precisionProperty,
- const string& cpuPmeProperty, const string& tempProperty, const string& pmeStreamProperty,
- const string& deterministicForcesProperty, int numThreads, ContextImpl* originalContext) :
+ const string& cpuPmeProperty, const string& tempProperty, const string& pmeStreamProperty,
+ const string& deterministicForcesProperty, const string& fftBackendProperty, int numThreads, ContextImpl* originalContext) :
context(context), removeCM(false), stepCount(0), computeForceCount(0), time(0.0), hasInitializedContexts(false),
threads(numThreads) {
bool blocking = (blockingProperty == "true");
@@ -306,6 +313,7 @@ HipPlatform::PlatformData::PlatformData(ContextImpl* context, const System& syst
propertyValues[HipPlatform::HipTempDirectory()] = tempProperty;
propertyValues[HipPlatform::HipDisablePmeStream()] = disablePmeStream ? "true" : "false";
propertyValues[HipPlatform::HipDeterministicForces()] = deterministicForces ? "true" : "false";
+ propertyValues[HipPlatform::HipFFTBackend()] = fftBackendProperty;
contextEnergy.resize(contexts.size());
// Determine whether peer-to-peer copying is supported, and enable it if so.
diff --git a/platforms/hip/tests/TestHipFFT3D.cpp b/platforms/hip/tests/TestHipFFT3D.cpp
index 170ad1fe0..8aff3112e 100644
--- a/platforms/hip/tests/TestHipFFT3D.cpp
+++ b/platforms/hip/tests/TestHipFFT3D.cpp
@@ -58,7 +58,8 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize, double e
system.addParticle(0.0);
HipPlatform::PlatformData platformData(NULL, system, "", "true", platform.getPropertyDefaultValue("HipPrecision"), "false",
platform.getPropertyDefaultValue(HipPlatform::HipTempDirectory()),
- platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false", 1, NULL);
+ platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false",
+ platform.getPropertyDefaultValue(HipPlatform::HipFFTBackend()), 1, NULL);
HipContext& context = *platformData.contexts[0];
context.initialize();
context.setAsCurrent();
diff --git a/platforms/hip/tests/TestHipRandom.cpp b/platforms/hip/tests/TestHipRandom.cpp
index 94212702c..54f869b35 100644
--- a/platforms/hip/tests/TestHipRandom.cpp
+++ b/platforms/hip/tests/TestHipRandom.cpp
@@ -55,7 +55,8 @@ void testGaussian() {
system.addParticle(1.0);
HipPlatform::PlatformData platformData(NULL, system, "", "true", platform.getPropertyDefaultValue("HipPrecision"), "false",
platform.getPropertyDefaultValue(HipPlatform::HipTempDirectory()),
- platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false", 1, NULL);
+ platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false",
+ platform.getPropertyDefaultValue(HipPlatform::HipFFTBackend()), 1, NULL);
HipContext& context = *platformData.contexts[0];
context.initialize();
context.setAsCurrent();
diff --git a/platforms/hip/tests/TestHipSort.cpp b/platforms/hip/tests/TestHipSort.cpp
index fe8b1c56a..03e7ee2c5 100644
--- a/platforms/hip/tests/TestHipSort.cpp
+++ b/platforms/hip/tests/TestHipSort.cpp
@@ -65,7 +65,8 @@ void verifySorting(vector<float> array, bool uniform) {
system.addParticle(0.0);
HipPlatform::PlatformData platformData(NULL, system, "", "true", platform.getPropertyDefaultValue("HipPrecision"), "false",
platform.getPropertyDefaultValue(HipPlatform::HipTempDirectory()),
- platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false", 1, NULL);
+ platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false",
+ platform.getPropertyDefaultValue(HipPlatform::HipFFTBackend()), 1, NULL);
HipContext& context = *platformData.contexts[0];
context.initialize();
context.setAsCurrent();
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment