Commit b1a1c54c authored by one's avatar one
Browse files

Add hipFFT backend for testing

parent ee4ca894
......@@ -354,6 +354,8 @@ def runOneTest(testName, options):
initialSteps = 250
if options.disable_pme_stream:
properties['DisablePmeStream'] = 'true'
if options.platform == 'HIP' and options.fft_backend is not None and 'FFTBackend' in platform.getPropertyNames():
properties['FFTBackend'] = options.fft_backend
if options.opencl_platform is not None and 'OpenCLPlatformIndex' in platform.getPropertyNames():
properties['OpenCLPlatformIndex'] = options.opencl_platform
if (options.precision is not None) and ('Precision' in platform.getPropertyNames()):
......@@ -481,6 +483,7 @@ parser.add_argument('--polarization', default='mutual', dest='polarization', cho
parser.add_argument('--mutual-epsilon', default=1e-5, dest='epsilon', type=float, help='mutual induced epsilon for AMOEBA [default: 1e-5]')
parser.add_argument('--bond-constraints', default='hbonds', dest='bond_constraints', help=f'hbonds: constrain bonds to hydrogen, use 1.5*amu H mass; allbonds: constrain all bonds, use 4*amu H mass, and use larger timestep. This option is ignored for AMOEBA: {BOND_CONSTRAINTS} [default: hbonds]')
parser.add_argument('--disable-pme-stream', default=False, action='store_true', dest='disable_pme_stream', help='disable use of a separate GPU stream for PME')
parser.add_argument('--fft-backend', default=None, dest='fft_backend', help='FFT backend for HIP platform (for example: vkfft or hipfft)')
parser.add_argument('--device', default=None, dest='device', help='device index for CUDA, HIP, or OpenCL')
parser.add_argument('--opencl-platform', default=None, dest='opencl_platform', help='platform index for OpenCL')
parser.add_argument('--precision', default='single', dest='precision', help=f'precision modes for CUDA, HIP, or OpenCL: {PRECISIONS} [default: single]')
......
......@@ -20,6 +20,8 @@ IF(NOT TARGET hiprtc::hiprtc)
)
ENDIF()
FIND_PACKAGE(HIPFFT CONFIG QUIET)
IF(NOT TARGET hiprtc::hiprtc)
add_library(hiprtc::hiprtc SHARED IMPORTED)
set_target_properties(hiprtc::hiprtc PROPERTIES
......@@ -117,6 +119,10 @@ IF(OPENMM_BUILD_SHARED_LIB)
TARGET_LINK_LIBRARIES(${SHARED_TARGET} PUBLIC ${OPENMM_LIBRARY_NAME} hip::host hiprtc::hiprtc)
SET_TARGET_PROPERTIES(${SHARED_TARGET} PROPERTIES COMPILE_FLAGS "${EXTRA_COMPILE_FLAGS} -DOPENMM_COMMON_BUILDING_SHARED_LIBRARY")
SET_TARGET_PROPERTIES(${SHARED_TARGET} PROPERTIES LINK_FLAGS "${EXTRA_LINK_FLAGS}")
IF(HIPFFT_FOUND)
TARGET_LINK_LIBRARIES(${SHARED_TARGET} PUBLIC hip::hipfft)
TARGET_COMPILE_OPTIONS(${SHARED_TARGET} PUBLIC "-DOPENMM_HIP_WITH_HIPFFT")
ENDIF()
INSTALL_TARGETS(/lib/plugins RUNTIME_DIRECTORY /lib/plugins ${SHARED_TARGET})
ENDIF(OPENMM_BUILD_SHARED_LIB)
......@@ -130,6 +136,10 @@ IF(OPENMM_BUILD_STATIC_LIB)
TARGET_LINK_LIBRARIES(${STATIC_TARGET} ${OPENMM_LIBRARY_NAME} hip::host hiprtc::hiprtc)
SET_TARGET_PROPERTIES(${STATIC_TARGET} PROPERTIES COMPILE_FLAGS "${EXTRA_COMPILE_FLAGS} -DOPENMM_COMMON_BUILDING_STATIC_LIBRARY")
SET_TARGET_PROPERTIES(${STATIC_TARGET} PROPERTIES LINK_FLAGS "${EXTRA_LINK_FLAGS}")
IF(HIPFFT_FOUND)
TARGET_LINK_LIBRARIES(${STATIC_TARGET} PUBLIC hip::hipfft)
TARGET_COMPILE_OPTIONS(${STATIC_TARGET} PUBLIC "-DOPENMM_HIP_WITH_HIPFFT")
ENDIF()
INSTALL_TARGETS(/lib/plugins RUNTIME_DIRECTORY /lib/plugins ${STATIC_TARGET})
ENDIF(OPENMM_BUILD_STATIC_LIB)
......
#ifndef __OPENMM_HIPHIPFFT3D_H__
#define __OPENMM_HIPHIPFFT3D_H__
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2009-2026 Stanford University and the Authors. *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "openmm/common/windowsExportCommon.h"
#include "openmm/common/FFT3D.h"
#include "openmm/common/ArrayInterface.h"
#ifdef OPENMM_HIP_WITH_HIPFFT
#if __has_include(<hipfft/hipfft.h>)
#include <hipfft/hipfft.h>
#else
#include <hipfft.h>
#endif
#endif
namespace OpenMM {
class HipContext;
class OPENMM_EXPORT_COMMON HipHipFFT3D : public FFT3DImpl {
public:
HipHipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex);
~HipHipFFT3D();
void execFFT(ArrayInterface& in, ArrayInterface& out, bool forward=true) override;
private:
HipContext& context;
bool realToComplex;
bool doublePrecision;
#ifdef OPENMM_HIP_WITH_HIPFFT
hipfftHandle forwardPlan;
hipfftHandle backwardPlan;
#endif
};
} // namespace OpenMM
#endif // __OPENMM_HIPHIPFFT3D_H__
......@@ -111,13 +111,21 @@ public:
static const std::string key = "DeterministicForces";
return key;
}
/**
* This is the name of the parameter for selecting which FFT backend to use.
*/
static const std::string& HipFFTBackend() {
static const std::string key = "FFTBackend";
return key;
}
};
class OPENMM_EXPORT_COMMON HipPlatform::PlatformData {
public:
PlatformData(ContextImpl* context, const System& system, const std::string& deviceIndexProperty, const std::string& blockingProperty, const std::string& precisionProperty,
const std::string& cpuPmeProperty, const std::string& tempProperty,
const std::string& pmeStreamProperty, const std::string& deterministicForcesProperty, int numThreads, ContextImpl* originalContext);
const std::string& pmeStreamProperty, const std::string& deterministicForcesProperty, const std::string& fftBackendProperty,
int numThreads, ContextImpl* originalContext);
~PlatformData();
void initializeContexts(const System& system);
void syncContexts();
......
......@@ -31,6 +31,7 @@
#include "HipBondedUtilities.h"
#include "HipEvent.h"
#include "HipFFT3D.h"
#include "HipHipFFT3D.h"
#include "HipIntegrationUtilities.h"
#include "HipKernels.h"
#include "HipKernelSources.h"
......@@ -80,6 +81,12 @@
using namespace OpenMM;
using namespace std;
static string normalizeFftBackendName(const string& value) {
string result = value;
transform(result.begin(), result.end(), result.begin(), ::tolower);
return result;
}
const int HipContext::ThreadBlockSize = 64;
const int HipContext::TileSize = 32;
bool HipContext::hasInitializedHip = false;
......@@ -712,7 +719,12 @@ ComputeSort HipContext::createSort(ComputeSortImpl::SortTrait* trait, unsigned i
}
FFT3D HipContext::createFFT(int xsize, int ysize, int zsize, bool realToComplex) {
string backend = normalizeFftBackendName(platformData.propertyValues[HipPlatform::HipFFTBackend()]);
if (backend.empty() || backend == "vkfft")
return FFT3D(new HipFFT3D(*this, xsize, ysize, zsize, realToComplex));
if (backend == "hipfft")
return FFT3D(new HipHipFFT3D(*this, xsize, ysize, zsize, realToComplex));
throw OpenMMException("Illegal value for FFTBackend: "+backend);
}
int HipContext::findLegalFFTDimension(int minimum) {
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2009-2026 Stanford University and the Authors. *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "HipHipFFT3D.h"
#include "HipContext.h"
#include "openmm/OpenMMException.h"
using namespace OpenMM;
using namespace std;
HipHipFFT3D::HipHipFFT3D(HipContext& context, int xsize, int ysize, int zsize, bool realToComplex) :
context(context), realToComplex(realToComplex), doublePrecision(context.getUseDoublePrecision())
#ifdef OPENMM_HIP_WITH_HIPFFT
, forwardPlan(0), backwardPlan(0)
#endif
{
#ifndef OPENMM_HIP_WITH_HIPFFT
throw OpenMMException("hipFFT backend requested but OpenMM was built without HIPFFT support");
#else
hipfftType forwardType, backwardType;
if (realToComplex) {
forwardType = doublePrecision ? HIPFFT_D2Z : HIPFFT_R2C;
backwardType = doublePrecision ? HIPFFT_Z2D : HIPFFT_C2R;
}
else {
forwardType = doublePrecision ? HIPFFT_Z2Z : HIPFFT_C2C;
backwardType = forwardType;
}
hipfftResult result = hipfftPlan3d(&forwardPlan, xsize, ysize, zsize, forwardType);
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error initializing hipFFT forward plan: "+context.intToString(result));
result = hipfftPlan3d(&backwardPlan, xsize, ysize, zsize, backwardType);
if (result != HIPFFT_SUCCESS) {
hipfftDestroy(forwardPlan);
throw OpenMMException("Error initializing hipFFT backward plan: "+context.intToString(result));
}
#endif
}
HipHipFFT3D::~HipHipFFT3D() {
#ifdef OPENMM_HIP_WITH_HIPFFT
if (forwardPlan != 0)
hipfftDestroy(forwardPlan);
if (backwardPlan != 0)
hipfftDestroy(backwardPlan);
#endif
}
void HipHipFFT3D::execFFT(ArrayInterface& in, ArrayInterface& out, bool forward) {
#ifndef OPENMM_HIP_WITH_HIPFFT
throw OpenMMException("hipFFT backend requested but OpenMM was built without HIPFFT support");
#else
hipfftHandle plan = forward ? forwardPlan : backwardPlan;
hipfftResult result = hipfftSetStream(plan, context.getCurrentStream());
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error setting hipFFT stream: "+context.intToString(result));
if (realToComplex) {
if (forward) {
if (doublePrecision)
result = hipfftExecD2Z(plan, (double*) context.unwrap(in).getDevicePointer(), (double2*) context.unwrap(out).getDevicePointer());
else
result = hipfftExecR2C(plan, (float*) context.unwrap(in).getDevicePointer(), (float2*) context.unwrap(out).getDevicePointer());
}
else {
if (doublePrecision)
result = hipfftExecZ2D(plan, (double2*) context.unwrap(in).getDevicePointer(), (double*) context.unwrap(out).getDevicePointer());
else
result = hipfftExecC2R(plan, (float2*) context.unwrap(in).getDevicePointer(), (float*) context.unwrap(out).getDevicePointer());
}
}
else {
if (doublePrecision)
result = hipfftExecZ2Z(plan, (double2*) context.unwrap(in).getDevicePointer(), (double2*) context.unwrap(out).getDevicePointer(), forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD);
else
result = hipfftExecC2C(plan, (float2*) context.unwrap(in).getDevicePointer(), (float2*) context.unwrap(out).getDevicePointer(), forward ? HIPFFT_FORWARD : HIPFFT_BACKWARD);
}
if (result != HIPFFT_SUCCESS)
throw OpenMMException("Error executing hipFFT: "+context.intToString(result));
#endif
}
......@@ -69,6 +69,7 @@ HipPlatform::HipPlatform() {
deprecatedPropertyReplacements["HipTempDirectory"] = HipTempDirectory();
deprecatedPropertyReplacements["HipDisablePmeStream"] = HipDisablePmeStream();
deprecatedPropertyReplacements["HipDeterministicForces"] = HipDeterministicForces();
deprecatedPropertyReplacements["HipFFTBackend"] = HipFFTBackend();
HipKernelFactory* factory = new HipKernelFactory();
registerKernelFactory(CalcForcesAndEnergyKernel::Name(), factory);
registerKernelFactory(UpdateStateDataKernel::Name(), factory);
......@@ -122,6 +123,7 @@ HipPlatform::HipPlatform() {
platformProperties.push_back(HipTempDirectory());
platformProperties.push_back(HipDisablePmeStream());
platformProperties.push_back(HipDeterministicForces());
platformProperties.push_back(HipFFTBackend());
setPropertyDefaultValue(HipDeviceIndex(), "");
setPropertyDefaultValue(HipDeviceName(), "");
setPropertyDefaultValue(HipUseBlockingSync(), "true");
......@@ -129,6 +131,7 @@ HipPlatform::HipPlatform() {
setPropertyDefaultValue(HipUseCpuPme(), "false");
setPropertyDefaultValue(HipDisablePmeStream(), "false");
setPropertyDefaultValue(HipDeterministicForces(), "false");
setPropertyDefaultValue(HipFFTBackend(), "vkfft");
#ifdef _MSC_VER
setPropertyDefaultValue(HipTempDirectory(), string(getenv("TEMP")));
#else
......@@ -213,11 +216,14 @@ void HipPlatform::contextCreated(ContextImpl& context, const map<string, string>
getPropertyDefaultValue(HipDisablePmeStream()) : properties.find(HipDisablePmeStream())->second);
string deterministicForcesValue = (properties.find(HipDeterministicForces()) == properties.end() ?
getPropertyDefaultValue(HipDeterministicForces()) : properties.find(HipDeterministicForces())->second);
string fftBackendValue = (properties.find(HipFFTBackend()) == properties.end() ?
getPropertyDefaultValue(HipFFTBackend()) : properties.find(HipFFTBackend())->second);
transform(blockingPropValue.begin(), blockingPropValue.end(), blockingPropValue.begin(), ::tolower);
transform(precisionPropValue.begin(), precisionPropValue.end(), precisionPropValue.begin(), ::tolower);
transform(cpuPmePropValue.begin(), cpuPmePropValue.end(), cpuPmePropValue.begin(), ::tolower);
transform(pmeStreamPropValue.begin(), pmeStreamPropValue.end(), pmeStreamPropValue.begin(), ::tolower);
transform(deterministicForcesValue.begin(), deterministicForcesValue.end(), deterministicForcesValue.begin(), ::tolower);
transform(fftBackendValue.begin(), fftBackendValue.end(), fftBackendValue.begin(), ::tolower);
vector<string> pmeKernelName;
pmeKernelName.push_back(CalcPmeReciprocalForceKernel::Name());
if (!supportsKernels(pmeKernelName))
......@@ -227,7 +233,7 @@ void HipPlatform::contextCreated(ContextImpl& context, const map<string, string>
if (threadsEnv != NULL)
stringstream(threadsEnv) >> threads;
context.setPlatformData(new PlatformData(&context, context.getSystem(), devicePropValue, blockingPropValue, precisionPropValue, cpuPmePropValue, tempPropValue,
pmeStreamPropValue, deterministicForcesValue, threads, NULL));
pmeStreamPropValue, deterministicForcesValue, fftBackendValue, threads, NULL));
}
void HipPlatform::linkedContextCreated(ContextImpl& context, ContextImpl& originalContext) const {
......@@ -239,9 +245,10 @@ void HipPlatform::linkedContextCreated(ContextImpl& context, ContextImpl& origin
string tempPropValue = platform.getPropertyValue(originalContext.getOwner(), HipTempDirectory());
string pmeStreamPropValue = platform.getPropertyValue(originalContext.getOwner(), HipDisablePmeStream());
string deterministicForcesValue = platform.getPropertyValue(originalContext.getOwner(), HipDeterministicForces());
string fftBackendValue = platform.getPropertyValue(originalContext.getOwner(), HipFFTBackend());
int threads = reinterpret_cast<PlatformData*>(originalContext.getPlatformData())->threads.getNumThreads();
context.setPlatformData(new PlatformData(&context, context.getSystem(), devicePropValue, blockingPropValue, precisionPropValue, cpuPmePropValue, tempPropValue,
pmeStreamPropValue, deterministicForcesValue, threads, &originalContext));
pmeStreamPropValue, deterministicForcesValue, fftBackendValue, threads, &originalContext));
}
void HipPlatform::contextDestroyed(ContextImpl& context) const {
......@@ -251,7 +258,7 @@ void HipPlatform::contextDestroyed(ContextImpl& context) const {
HipPlatform::PlatformData::PlatformData(ContextImpl* context, const System& system, const string& deviceIndexProperty, const string& blockingProperty, const string& precisionProperty,
const string& cpuPmeProperty, const string& tempProperty, const string& pmeStreamProperty,
const string& deterministicForcesProperty, int numThreads, ContextImpl* originalContext) :
const string& deterministicForcesProperty, const string& fftBackendProperty, int numThreads, ContextImpl* originalContext) :
context(context), removeCM(false), stepCount(0), computeForceCount(0), time(0.0), hasInitializedContexts(false),
threads(numThreads) {
bool blocking = (blockingProperty == "true");
......@@ -306,6 +313,7 @@ HipPlatform::PlatformData::PlatformData(ContextImpl* context, const System& syst
propertyValues[HipPlatform::HipTempDirectory()] = tempProperty;
propertyValues[HipPlatform::HipDisablePmeStream()] = disablePmeStream ? "true" : "false";
propertyValues[HipPlatform::HipDeterministicForces()] = deterministicForces ? "true" : "false";
propertyValues[HipPlatform::HipFFTBackend()] = fftBackendProperty;
contextEnergy.resize(contexts.size());
// Determine whether peer-to-peer copying is supported, and enable it if so.
......
......@@ -58,7 +58,8 @@ void testTransform(bool realToComplex, int xsize, int ysize, int zsize, double e
system.addParticle(0.0);
HipPlatform::PlatformData platformData(NULL, system, "", "true", platform.getPropertyDefaultValue("HipPrecision"), "false",
platform.getPropertyDefaultValue(HipPlatform::HipTempDirectory()),
platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false", 1, NULL);
platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false",
platform.getPropertyDefaultValue(HipPlatform::HipFFTBackend()), 1, NULL);
HipContext& context = *platformData.contexts[0];
context.initialize();
context.setAsCurrent();
......
......@@ -55,7 +55,8 @@ void testGaussian() {
system.addParticle(1.0);
HipPlatform::PlatformData platformData(NULL, system, "", "true", platform.getPropertyDefaultValue("HipPrecision"), "false",
platform.getPropertyDefaultValue(HipPlatform::HipTempDirectory()),
platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false", 1, NULL);
platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false",
platform.getPropertyDefaultValue(HipPlatform::HipFFTBackend()), 1, NULL);
HipContext& context = *platformData.contexts[0];
context.initialize();
context.setAsCurrent();
......
......@@ -65,7 +65,8 @@ void verifySorting(vector<float> array, bool uniform) {
system.addParticle(0.0);
HipPlatform::PlatformData platformData(NULL, system, "", "true", platform.getPropertyDefaultValue("HipPrecision"), "false",
platform.getPropertyDefaultValue(HipPlatform::HipTempDirectory()),
platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false", 1, NULL);
platform.getPropertyDefaultValue(HipPlatform::HipDisablePmeStream()), "false",
platform.getPropertyDefaultValue(HipPlatform::HipFFTBackend()), 1, NULL);
HipContext& context = *platformData.contexts[0];
context.initialize();
context.setAsCurrent();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment