Unverified Commit dd320bcf authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Unified interface for queues (#4913)

* Unified interface for queues

* Simplified stream handling in CudaFFT3D

* HIP implementation of ComputeQueue
parent baf7942c
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
#include "openmm/common/ComputeEvent.h" #include "openmm/common/ComputeEvent.h"
#include "openmm/common/ComputeForceInfo.h" #include "openmm/common/ComputeForceInfo.h"
#include "openmm/common/ComputeProgram.h" #include "openmm/common/ComputeProgram.h"
#include "openmm/common/ComputeQueue.h"
#include "openmm/common/ComputeVectorTypes.h" #include "openmm/common/ComputeVectorTypes.h"
#include "openmm/common/FFT3D.h" #include "openmm/common/FFT3D.h"
#include "openmm/common/IntegrationUtilities.h" #include "openmm/common/IntegrationUtilities.h"
...@@ -143,6 +144,22 @@ public: ...@@ -143,6 +144,22 @@ public:
* multiple devices. * multiple devices.
*/ */
virtual double& getEnergyWorkspace() = 0; virtual double& getEnergyWorkspace() = 0;
/**
* Create a new ComputeQueue for use with this context.
*/
virtual ComputeQueue createQueue() = 0;
/**
* Get the ComputeQueue currently being used for execution.
*/
ComputeQueue getCurrentQueue();
/**
* Set the ComputeQueue to use for execution.
*/
void setCurrentQueue(ComputeQueue queue);
/**
* Reset the context to using the default queue for execution.
*/
void restoreDefaultQueue();
/** /**
* Construct an uninitialized array of the appropriate class for this platform. The returned * Construct an uninitialized array of the appropriate class for this platform. The returned
* value should be created on the heap with the "new" operator. * value should be created on the heap with the "new" operator.
...@@ -560,6 +577,7 @@ protected: ...@@ -560,6 +577,7 @@ protected:
int numAtoms, paddedNumAtoms, computeForceCount, stepsSinceReorder; int numAtoms, paddedNumAtoms, computeForceCount, stepsSinceReorder;
long long stepCount; long long stepCount;
bool forceNextReorder, atomsWereReordered, forcesValid; bool forceNextReorder, atomsWereReordered, forcesValid;
ComputeQueue defaultQueue, currentQueue;
std::vector<ComputeForceInfo*> forces; std::vector<ComputeForceInfo*> forces;
std::vector<Molecule> molecules; std::vector<Molecule> molecules;
std::vector<MoleculeGroup> moleculeGroups; std::vector<MoleculeGroup> moleculeGroups;
......
#ifndef OPENMM_COMPUTEQUEUE_H_
#define OPENMM_COMPUTEQUEUE_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "openmm/common/windowsExportCommon.h"
#include <memory>
namespace OpenMM {
/**
* This abstract class represents a queue within which kernels can be executed. Call
* createQueue() on a ComputeContext to create an instance of a platform-specific
* subclass. You can then pass it to the ComputeContext's setQueue() method to cause
* kernels to be launched on it.
*
* Instead of referring to this class directly, it is best to use ComputeQueue, which is
* a typedef for a shared_ptr to a ComputeQueueImpl. This allows you to treat it as having
* value semantics, and frees you from having to manage memory.
*/
class OPENMM_EXPORT_COMMON ComputeQueueImpl {
public:
virtual ~ComputeQueueImpl() {
}
};
typedef std::shared_ptr<ComputeQueueImpl> ComputeQueue;
} // namespace OpenMM
#endif /*OPENMM_COMPUTEQUEUE_H_*/
...@@ -52,6 +52,18 @@ ComputeContext::ComputeContext(const System& system) : system(system), time(0.0) ...@@ -52,6 +52,18 @@ ComputeContext::ComputeContext(const System& system) : system(system), time(0.0)
ComputeContext::~ComputeContext() { ComputeContext::~ComputeContext() {
} }
ComputeQueue ComputeContext::getCurrentQueue() {
return currentQueue;
}
void ComputeContext::setCurrentQueue(ComputeQueue queue) {
currentQueue = queue;
}
void ComputeContext::restoreDefaultQueue() {
currentQueue = defaultQueue;
}
void ComputeContext::addForce(ComputeForceInfo* force) { void ComputeContext::addForce(ComputeForceInfo* force) {
forces.push_back(force); forces.push_back(force);
} }
......
...@@ -46,6 +46,7 @@ ...@@ -46,6 +46,7 @@
#include "CudaIntegrationUtilities.h" #include "CudaIntegrationUtilities.h"
#include "CudaNonbondedUtilities.h" #include "CudaNonbondedUtilities.h"
#include "CudaPlatform.h" #include "CudaPlatform.h"
#include "CudaQueue.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/common/ComputeContext.h" #include "openmm/common/ComputeContext.h"
#include "openmm/Kernel.h" #include "openmm/Kernel.h"
...@@ -159,17 +160,13 @@ public: ...@@ -159,17 +160,13 @@ public:
*/ */
double& getEnergyWorkspace(); double& getEnergyWorkspace();
/** /**
* Get the stream currently being used for execution. * Create a new ComputeQueue for use with this context.
*/
CUstream getCurrentStream();
/**
* Set the stream to use for execution.
*/ */
void setCurrentStream(CUstream stream); ComputeQueue createQueue();
/** /**
* Reset the context to using the default stream for execution. * Get the stream currently being used for execution.
*/ */
void restoreDefaultStream(); CUstream getCurrentStream();
/** /**
* Construct an uninitialized array of the appropriate class for this platform. The returned * Construct an uninitialized array of the appropriate class for this platform. The returned
* value should be created on the heap with the "new" operator. * value should be created on the heap with the "new" operator.
...@@ -587,7 +584,6 @@ private: ...@@ -587,7 +584,6 @@ private:
std::map<std::string, std::string> compilationDefines; std::map<std::string, std::string> compilationDefines;
CUcontext context; CUcontext context;
CUdevice device; CUdevice device;
CUstream currentStream;
CUfunction clearBufferKernel; CUfunction clearBufferKernel;
CUfunction clearTwoBuffersKernel; CUfunction clearTwoBuffersKernel;
CUfunction clearThreeBuffersKernel; CUfunction clearThreeBuffersKernel;
......
...@@ -63,10 +63,6 @@ public: ...@@ -63,10 +63,6 @@ public:
*/ */
CudaFFT3D(CudaContext& context, int xsize, int ysize, int zsize, bool realToComplex=false); CudaFFT3D(CudaContext& context, int xsize, int ysize, int zsize, bool realToComplex=false);
~CudaFFT3D(); ~CudaFFT3D();
/**
* Set the stream to perform the FFT on.
*/
void setStream(CUstream stream);
/** /**
* Perform a Fourier transform. The transform cannot be done in-place: the input and output * Perform a Fourier transform. The transform cannot be done in-place: the input and output
* arrays must be different. Also, the input array is used as workspace, so its contents * arrays must be different. Also, the input array is used as workspace, so its contents
......
...@@ -185,7 +185,7 @@ private: ...@@ -185,7 +185,7 @@ private:
CudaSort* sort; CudaSort* sort;
Kernel cpuPme; Kernel cpuPme;
PmeIO* pmeio; PmeIO* pmeio;
CUstream pmeStream; ComputeQueue pmeQueue;
CUevent pmeSyncEvent, paramsSyncEvent; CUevent pmeSyncEvent, paramsSyncEvent;
CudaFFT3D* fft; CudaFFT3D* fft;
CudaFFT3D* dispersionFft; CudaFFT3D* dispersionFft;
......
#ifndef OPENMM_CUDAQUEUE_H_
#define OPENMM_CUDAQUEUE_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "openmm/common/ComputeQueue.h"
#include <cuda.h>
namespace OpenMM {
/**
* This is the CUDA implementation of the ComputeQueue interface. It wraps a CUstream.
*/
class CudaQueue : public ComputeQueueImpl {
public:
/**
* Create a CudaQueue that wraps an existing CUstream.
*/
CudaQueue(CUstream stream);
/**
* Create a CudaQueue that create a new CUstream.
*/
CudaQueue();
~CudaQueue();
/**
* Get the CUstream.
*/
CUstream getStream() {
return stream;
}
private:
CUstream stream;
bool initialized;
};
} // namespace OpenMM
#endif /*OPENMM_CUDAQUEUE_H_*/
...@@ -84,7 +84,7 @@ const int CudaContext::TileSize = sizeof(tileflags)*8; ...@@ -84,7 +84,7 @@ const int CudaContext::TileSize = sizeof(tileflags)*8;
bool CudaContext::hasInitializedCuda = false; bool CudaContext::hasInitializedCuda = false;
CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlockingSync, const string& precision, const string& tempDir, CudaPlatform::PlatformData& platformData, CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlockingSync, const string& precision, const string& tempDir, CudaPlatform::PlatformData& platformData,
CudaContext* originalContext) : ComputeContext(system), currentStream(0), platformData(platformData), contextIsValid(false), hasAssignedPosqCharges(false), CudaContext* originalContext) : ComputeContext(system), platformData(platformData), contextIsValid(false), hasAssignedPosqCharges(false),
pinnedBuffer(NULL), integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), useBlockingSync(useBlockingSync) { pinnedBuffer(NULL), integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), useBlockingSync(useBlockingSync) {
int cudaDriverVersion; int cudaDriverVersion;
cuDriverGetVersion(&cudaDriverVersion); cuDriverGetVersion(&cudaDriverVersion);
...@@ -200,6 +200,8 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking ...@@ -200,6 +200,8 @@ CudaContext::CudaContext(const System& system, int deviceIndex, bool useBlocking
CHECK_RESULT(cuCtxEnablePeerAccess(platformData.contexts[0]->getContext(), 0)); CHECK_RESULT(cuCtxEnablePeerAccess(platformData.contexts[0]->getContext(), 0));
} }
} }
defaultQueue = shared_ptr<ComputeQueueImpl>(new CudaQueue(0));
currentQueue = defaultQueue;
numAtoms = system.getNumParticles(); numAtoms = system.getNumParticles();
paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize); paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize; numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
...@@ -649,16 +651,12 @@ double& CudaContext::getEnergyWorkspace() { ...@@ -649,16 +651,12 @@ double& CudaContext::getEnergyWorkspace() {
return platformData.contextEnergy[contextIndex]; return platformData.contextEnergy[contextIndex];
} }
CUstream CudaContext::getCurrentStream() { ComputeQueue CudaContext::createQueue() {
return currentStream; return shared_ptr<ComputeQueueImpl>(new CudaQueue());
}
void CudaContext::setCurrentStream(CUstream stream) {
currentStream = stream;
} }
void CudaContext::restoreDefaultStream() { CUstream CudaContext::getCurrentStream() {
setCurrentStream(0); return dynamic_cast<CudaQueue*>(currentQueue.get())->getStream();
} }
CudaArray* CudaContext::createArray() { CudaArray* CudaContext::createArray() {
...@@ -697,7 +695,7 @@ void CudaContext::executeKernel(CUfunction kernel, void** arguments, int threads ...@@ -697,7 +695,7 @@ void CudaContext::executeKernel(CUfunction kernel, void** arguments, int threads
if (blockSize == -1) if (blockSize == -1)
blockSize = ThreadBlockSize; blockSize = ThreadBlockSize;
int gridSize = std::min((threads+blockSize-1)/blockSize, numThreadBlocks); int gridSize = std::min((threads+blockSize-1)/blockSize, numThreadBlocks);
CUresult result = cuLaunchKernel(kernel, gridSize, 1, 1, blockSize, 1, 1, sharedSize, currentStream, arguments, NULL); CUresult result = cuLaunchKernel(kernel, gridSize, 1, 1, blockSize, 1, 1, sharedSize, getCurrentStream(), arguments, NULL);
if (result != CUDA_SUCCESS) { if (result != CUDA_SUCCESS) {
stringstream str; stringstream str;
str<<"Error invoking kernel: "<<getErrorString(result)<<" ("<<result<<")"; str<<"Error invoking kernel: "<<getErrorString(result)<<" ("<<result<<")";
......
...@@ -64,16 +64,12 @@ CudaFFT3D::~CudaFFT3D() { ...@@ -64,16 +64,12 @@ CudaFFT3D::~CudaFFT3D() {
} }
} }
void CudaFFT3D::setStream(CUstream stream) {
cufftSetStream(fftForward, stream);
cufftSetStream(fftBackward, stream);
}
void CudaFFT3D::execFFT(ArrayInterface& in, ArrayInterface& out, bool forward) { void CudaFFT3D::execFFT(ArrayInterface& in, ArrayInterface& out, bool forward) {
CUdeviceptr in2 = context.unwrap(in).getDevicePointer(); CUdeviceptr in2 = context.unwrap(in).getDevicePointer();
CUdeviceptr out2 = context.unwrap(out).getDevicePointer(); CUdeviceptr out2 = context.unwrap(out).getDevicePointer();
cufftResult result; cufftResult result;
if (forward) { if (forward) {
cufftSetStream(fftForward, context.getCurrentStream());
if (realToComplex) { if (realToComplex) {
if (context.getUseDoublePrecision()) if (context.getUseDoublePrecision())
result = cufftExecD2Z(fftForward, (double*) in2, (double2*) out2); result = cufftExecD2Z(fftForward, (double*) in2, (double2*) out2);
...@@ -88,6 +84,7 @@ void CudaFFT3D::execFFT(ArrayInterface& in, ArrayInterface& out, bool forward) { ...@@ -88,6 +84,7 @@ void CudaFFT3D::execFFT(ArrayInterface& in, ArrayInterface& out, bool forward) {
} }
} }
else { else {
cufftSetStream(fftBackward, context.getCurrentStream());
if (realToComplex) { if (realToComplex) {
if (context.getUseDoublePrecision()) if (context.getUseDoublePrecision())
result = cufftExecZ2D(fftBackward, (double2*) in2, (double*) out2); result = cufftExecZ2D(fftBackward, (double2*) in2, (double*) out2);
......
...@@ -207,17 +207,17 @@ private: ...@@ -207,17 +207,17 @@ private:
class CudaCalcNonbondedForceKernel::SyncStreamPreComputation : public CudaContext::ForcePreComputation { class CudaCalcNonbondedForceKernel::SyncStreamPreComputation : public CudaContext::ForcePreComputation {
public: public:
SyncStreamPreComputation(CudaContext& cu, CUstream stream, CUevent event, int forceGroup) : cu(cu), stream(stream), event(event), forceGroup(forceGroup) { SyncStreamPreComputation(CudaContext& cu, ComputeQueue queue, CUevent event, int forceGroup) : cu(cu), queue(queue), event(event), forceGroup(forceGroup) {
} }
void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) { void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<forceGroup)) != 0) { if ((groups&(1<<forceGroup)) != 0) {
cuEventRecord(event, cu.getCurrentStream()); cuEventRecord(event, cu.getCurrentStream());
cuStreamWaitEvent(stream, event, 0); cuStreamWaitEvent(dynamic_cast<CudaQueue*>(queue.get())->getStream(), event, 0);
} }
} }
private: private:
CudaContext& cu; CudaContext& cu;
CUstream stream; ComputeQueue queue;
CUevent event; CUevent event;
int forceGroup; int forceGroup;
}; };
...@@ -256,8 +256,6 @@ CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() { ...@@ -256,8 +256,6 @@ CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() {
delete dispersionFft; delete dispersionFft;
if (pmeio != NULL) if (pmeio != NULL)
delete pmeio; delete pmeio;
if (hasInitializedFFT && usePmeStream)
cuStreamDestroy(pmeStream);
} }
void CudaCalcNonbondedForceKernel::initialize(const System& system, const NonbondedForce& force) { void CudaCalcNonbondedForceKernel::initialize(const System& system, const NonbondedForce& force) {
...@@ -544,16 +542,13 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon ...@@ -544,16 +542,13 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
// Prepare for doing PME on its own stream. // Prepare for doing PME on its own stream.
if (usePmeStream) { if (usePmeStream) {
cuStreamCreate(&pmeStream, CU_STREAM_NON_BLOCKING); pmeQueue = cu.createQueue();
fft->setStream(pmeStream);
if (doLJPME)
dispersionFft->setStream(pmeStream);
CHECK_RESULT(cuEventCreate(&pmeSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce"); CHECK_RESULT(cuEventCreate(&pmeSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce");
CHECK_RESULT(cuEventCreate(&paramsSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce"); CHECK_RESULT(cuEventCreate(&paramsSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce");
int recipForceGroup = force.getReciprocalSpaceForceGroup(); int recipForceGroup = force.getReciprocalSpaceForceGroup();
if (recipForceGroup < 0) if (recipForceGroup < 0)
recipForceGroup = force.getForceGroup(); recipForceGroup = force.getForceGroup();
cu.addPreComputation(new SyncStreamPreComputation(cu, pmeStream, pmeSyncEvent, recipForceGroup)); cu.addPreComputation(new SyncStreamPreComputation(cu, pmeQueue, pmeSyncEvent, recipForceGroup));
cu.addPostComputation(new SyncStreamPostComputation(cu, pmeSyncEvent, cu.getKernel(module, "addEnergy"), pmeEnergyBuffer, recipForceGroup)); cu.addPostComputation(new SyncStreamPostComputation(cu, pmeSyncEvent, cu.getKernel(module, "addEnergy"), pmeEnergyBuffer, recipForceGroup));
} }
hasInitializedFFT = true; hasInitializedFFT = true;
...@@ -857,7 +852,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF ...@@ -857,7 +852,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
} }
if (usePmeStream) { if (usePmeStream) {
cuEventRecord(paramsSyncEvent, cu.getCurrentStream()); cuEventRecord(paramsSyncEvent, cu.getCurrentStream());
cuStreamWaitEvent(pmeStream, paramsSyncEvent, 0); cuStreamWaitEvent(dynamic_cast<CudaQueue*>(pmeQueue.get())->getStream(), paramsSyncEvent, 0);
} }
if (hasOffsets) { if (hasOffsets) {
// The Ewald self energy was computed in the kernel. // The Ewald self energy was computed in the kernel.
...@@ -893,7 +888,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF ...@@ -893,7 +888,7 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
} }
if (pmeGrid1.isInitialized() && includeReciprocal) { if (pmeGrid1.isInitialized() && includeReciprocal) {
if (usePmeStream) if (usePmeStream)
cu.setCurrentStream(pmeStream); cu.setCurrentQueue(pmeQueue);
// Invert the periodic box vectors. // Invert the periodic box vectors.
...@@ -1015,8 +1010,8 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF ...@@ -1015,8 +1010,8 @@ double CudaCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeF
cu.executeKernel(pmeInterpolateDispersionForceKernel, interpolateArgs, cu.getNumAtoms(), 128); cu.executeKernel(pmeInterpolateDispersionForceKernel, interpolateArgs, cu.getNumAtoms(), 128);
} }
if (usePmeStream) { if (usePmeStream) {
cuEventRecord(pmeSyncEvent, pmeStream); cuEventRecord(pmeSyncEvent, dynamic_cast<CudaQueue*>(pmeQueue.get())->getStream());
cu.restoreDefaultStream(); cu.restoreDefaultQueue();
} }
} }
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "CudaQueue.h"
#include "CudaContext.h"
#include "openmm/OpenMMException.h"
using namespace OpenMM;
CudaQueue::CudaQueue(CUstream stream) : stream(stream), initialized(false) {
}
CudaQueue::CudaQueue() : initialized(false) {
CUresult result = cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING);
if (result != CUDA_SUCCESS)
throw OpenMMException("Error creating CUDA stream: "+CudaContext::getErrorString(result));
initialized = true;
}
CudaQueue::~CudaQueue() {
if (initialized)
cuStreamDestroy(stream);
}
...@@ -162,17 +162,13 @@ public: ...@@ -162,17 +162,13 @@ public:
*/ */
double& getEnergyWorkspace(); double& getEnergyWorkspace();
/** /**
* Get the stream currently being used for execution. * Create a new ComputeQueue for use with this context.
*/
hipStream_t getCurrentStream();
/**
* Set the stream to use for execution.
*/ */
void setCurrentStream(hipStream_t stream); ComputeQueue createQueue();
/** /**
* Reset the context to using the default stream for execution. * Get the stream currently being used for execution.
*/ */
void restoreDefaultStream(); hipStream_t getCurrentStream();
/** /**
* Construct an uninitialized array of the appropriate class for this platform. The returned * Construct an uninitialized array of the appropriate class for this platform. The returned
* value should be created on the heap with the "new" operator. * value should be created on the heap with the "new" operator.
...@@ -632,8 +628,6 @@ private: ...@@ -632,8 +628,6 @@ private:
std::map<std::string, std::string> compilationDefines; std::map<std::string, std::string> compilationDefines;
std::vector<hipModule_t> loadedModules; std::vector<hipModule_t> loadedModules;
hipDevice_t device; hipDevice_t device;
hipStream_t currentStream;
hipStream_t defaultStream;
hipFunction_t clearBufferKernel; hipFunction_t clearBufferKernel;
hipFunction_t clearTwoBuffersKernel; hipFunction_t clearTwoBuffersKernel;
hipFunction_t clearThreeBuffersKernel; hipFunction_t clearThreeBuffersKernel;
......
...@@ -186,7 +186,7 @@ private: ...@@ -186,7 +186,7 @@ private:
HipSort* sort; HipSort* sort;
Kernel cpuPme; Kernel cpuPme;
PmeIO* pmeio; PmeIO* pmeio;
hipStream_t pmeStream; ComputeQueue pmeQueue;
hipEvent_t pmeSyncEvent, paramsSyncEvent; hipEvent_t pmeSyncEvent, paramsSyncEvent;
HipFFT3D* fft; HipFFT3D* fft;
HipFFT3D* dispersionFft; HipFFT3D* dispersionFft;
......
#ifndef OPENMM_HIPQUEUE_H_
#define OPENMM_HIPQUEUE_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "openmm/common/ComputeQueue.h"
#include <hip/hip_runtime.h>
namespace OpenMM {
/**
* This is the HIP implementation of the ComputeQueue interface. It wraps a hipStream_t.
*/
class HipQueue : public ComputeQueueImpl {
public:
/**
* Create a HipQueue that wraps an existing hipStream_t.
*/
HipQueue(hipStream_t stream);
/**
* Create a HipQueue that create a new hipStream_t.
*/
HipQueue();
~HipQueue();
/**
* Get the CUstream.
*/
hipStream_t getStream() {
return stream;
}
private:
hipStream_t stream;
bool initialized;
};
} // namespace OpenMM
#endif /*OPENMM_HIPQUEUE_H_*/
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2024 Stanford University and the Authors. * * Portions copyright (c) 2009-2025 Stanford University and the Authors. *
* Portions copyright (c) 2020-2023 Advanced Micro Devices, Inc. * * Portions copyright (c) 2020-2023 Advanced Micro Devices, Inc. *
* Authors: Peter Eastman, Nicholas Curtis * * Authors: Peter Eastman, Nicholas Curtis *
* Contributors: * * Contributors: *
...@@ -32,12 +32,13 @@ ...@@ -32,12 +32,13 @@
#include "HipArray.h" #include "HipArray.h"
#include "HipBondedUtilities.h" #include "HipBondedUtilities.h"
#include "HipEvent.h" #include "HipEvent.h"
#include "HipFFT3D.h"
#include "HipIntegrationUtilities.h" #include "HipIntegrationUtilities.h"
#include "HipKernels.h" #include "HipKernels.h"
#include "HipKernelSources.h" #include "HipKernelSources.h"
#include "HipNonbondedUtilities.h" #include "HipNonbondedUtilities.h"
#include "HipProgram.h" #include "HipProgram.h"
#include "HipFFT3D.h" #include "HipQueue.h"
#include "openmm/common/ComputeArray.h" #include "openmm/common/ComputeArray.h"
#include "openmm/common/ContextSelector.h" #include "openmm/common/ContextSelector.h"
#include "SHA1.h" #include "SHA1.h"
...@@ -86,7 +87,7 @@ bool HipContext::hasInitializedHip = false; ...@@ -86,7 +87,7 @@ bool HipContext::hasInitializedHip = false;
HipContext::HipContext(const System& system, int deviceIndex, bool useBlockingSync, const string& precision, const string& tempDir, HipPlatform::PlatformData& platformData, HipContext::HipContext(const System& system, int deviceIndex, bool useBlockingSync, const string& precision, const string& tempDir, HipPlatform::PlatformData& platformData,
HipContext* originalContext) : ComputeContext(system), currentStream(0), defaultStream(0), platformData(platformData), contextIsValid(false), hasAssignedPosqCharges(false), HipContext* originalContext) : ComputeContext(system), platformData(platformData), contextIsValid(false), hasAssignedPosqCharges(false),
pinnedBuffer(NULL), integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL), pinnedBuffer(NULL), integration(NULL), expression(NULL), bonded(NULL), nonbonded(NULL),
useBlockingSync(useBlockingSync), supportsHardwareFloatGlobalAtomicAdd(false) { useBlockingSync(useBlockingSync), supportsHardwareFloatGlobalAtomicAdd(false) {
if (!hasInitializedHip) { if (!hasInitializedHip) {
...@@ -149,15 +150,15 @@ HipContext::HipContext(const System& system, int deviceIndex, bool useBlockingSy ...@@ -149,15 +150,15 @@ HipContext::HipContext(const System& system, int deviceIndex, bool useBlockingSy
else else
throw OpenMMException("No compatible HIP device is available"); throw OpenMMException("No compatible HIP device is available");
} }
CHECK_RESULT(hipStreamCreateWithFlags(&defaultStream, hipStreamNonBlocking)); defaultQueue = shared_ptr<ComputeQueueImpl>(new HipQueue());
} }
else { else {
isLinkedContext = true; isLinkedContext = true;
this->deviceIndex = originalContext->deviceIndex; this->deviceIndex = originalContext->deviceIndex;
this->device = originalContext->device; this->device = originalContext->device;
defaultStream = originalContext->defaultStream; defaultQueue = originalContext->defaultQueue;
} }
currentStream = defaultStream; currentQueue = defaultQueue;
hipDeviceProp_t props; hipDeviceProp_t props;
CHECK_RESULT(hipGetDeviceProperties(&props, device)); CHECK_RESULT(hipGetDeviceProperties(&props, device));
...@@ -373,8 +374,6 @@ HipContext::~HipContext() { ...@@ -373,8 +374,6 @@ HipContext::~HipContext() {
delete nonbonded; delete nonbonded;
for (auto module : loadedModules) for (auto module : loadedModules)
hipModuleUnload(module); hipModuleUnload(module);
if (!isLinkedContext)
hipStreamDestroy(defaultStream);
popAsCurrent(); popAsCurrent();
contextIsValid = false; contextIsValid = false;
} }
...@@ -676,16 +675,12 @@ double& HipContext::getEnergyWorkspace() { ...@@ -676,16 +675,12 @@ double& HipContext::getEnergyWorkspace() {
return platformData.contextEnergy[contextIndex]; return platformData.contextEnergy[contextIndex];
} }
hipStream_t HipContext::getCurrentStream() { ComputeQueue HipContext::createQueue() {
return currentStream; return shared_ptr<ComputeQueueImpl>(new HipQueue());
}
void HipContext::setCurrentStream(hipStream_t stream) {
currentStream = stream;
} }
void HipContext::restoreDefaultStream() { hipStream_t HipContext::getCurrentStream() {
currentStream = defaultStream; return dynamic_cast<HipQueue*>(currentQueue.get())->getStream();
} }
HipArray* HipContext::createArray() { HipArray* HipContext::createArray() {
...@@ -729,7 +724,7 @@ void HipContext::executeKernel(hipFunction_t kernel, void** arguments, int threa ...@@ -729,7 +724,7 @@ void HipContext::executeKernel(hipFunction_t kernel, void** arguments, int threa
if (blockSize == -1) if (blockSize == -1)
blockSize = ThreadBlockSize; blockSize = ThreadBlockSize;
int gridSize = std::min((threads+blockSize-1)/blockSize, numThreadBlocks); int gridSize = std::min((threads+blockSize-1)/blockSize, numThreadBlocks);
hipError_t result = hipModuleLaunchKernel(kernel, gridSize, 1, 1, blockSize, 1, 1, sharedSize, currentStream, arguments, NULL); hipError_t result = hipModuleLaunchKernel(kernel, gridSize, 1, 1, blockSize, 1, 1, sharedSize, getCurrentStream(), arguments, NULL);
if (result != hipSuccess) { if (result != hipSuccess) {
stringstream str; stringstream str;
str<<"Error invoking kernel: "<<getErrorString(result)<<" ("<<result<<")"; str<<"Error invoking kernel: "<<getErrorString(result)<<" ("<<result<<")";
...@@ -741,7 +736,7 @@ void HipContext::executeKernelFlat(hipFunction_t kernel, void** arguments, int t ...@@ -741,7 +736,7 @@ void HipContext::executeKernelFlat(hipFunction_t kernel, void** arguments, int t
if (blockSize == -1) if (blockSize == -1)
blockSize = ThreadBlockSize; blockSize = ThreadBlockSize;
int gridSize = (threads+blockSize-1)/blockSize; int gridSize = (threads+blockSize-1)/blockSize;
hipError_t result = hipModuleLaunchKernel(kernel, gridSize, 1, 1, blockSize, 1, 1, sharedSize, currentStream, arguments, NULL); hipError_t result = hipModuleLaunchKernel(kernel, gridSize, 1, 1, blockSize, 1, 1, sharedSize, getCurrentStream(), arguments, NULL);
if (result != hipSuccess) { if (result != hipSuccess) {
stringstream str; stringstream str;
str<<"Error invoking kernel: "<<getErrorString(result)<<" ("<<result<<")"; str<<"Error invoking kernel: "<<getErrorString(result)<<" ("<<result<<")";
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#include "HipIntegrationUtilities.h" #include "HipIntegrationUtilities.h"
#include "HipNonbondedUtilities.h" #include "HipNonbondedUtilities.h"
#include "HipKernelSources.h" #include "HipKernelSources.h"
#include "HipQueue.h"
#include "SimTKOpenMMRealType.h" #include "SimTKOpenMMRealType.h"
#include "SimTKOpenMMUtilities.h" #include "SimTKOpenMMUtilities.h"
#include <algorithm> #include <algorithm>
...@@ -208,17 +209,17 @@ private: ...@@ -208,17 +209,17 @@ private:
class HipCalcNonbondedForceKernel::SyncStreamPreComputation : public HipContext::ForcePreComputation { class HipCalcNonbondedForceKernel::SyncStreamPreComputation : public HipContext::ForcePreComputation {
public: public:
SyncStreamPreComputation(HipContext& cu, hipStream_t stream, hipEvent_t event, int forceGroup) : cu(cu), stream(stream), event(event), forceGroup(forceGroup) { SyncStreamPreComputation(HipContext& cu, ComputeQueue queue, hipEvent_t event, int forceGroup) : cu(cu), queue(queue), event(event), forceGroup(forceGroup) {
} }
void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) { void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<forceGroup)) != 0) { if ((groups&(1<<forceGroup)) != 0) {
hipEventRecord(event, cu.getCurrentStream()); hipEventRecord(event, cu.getCurrentStream());
hipStreamWaitEvent(stream, event, 0); hipStreamWaitEvent(dynamic_cast<HipQueue*>(queue.get())->getStream(), event, 0);
} }
} }
private: private:
HipContext& cu; HipContext& cu;
hipStream_t stream; ComputeQueue queue;
hipEvent_t event; hipEvent_t event;
int forceGroup; int forceGroup;
}; };
...@@ -259,7 +260,6 @@ HipCalcNonbondedForceKernel::~HipCalcNonbondedForceKernel() { ...@@ -259,7 +260,6 @@ HipCalcNonbondedForceKernel::~HipCalcNonbondedForceKernel() {
delete pmeio; delete pmeio;
if (hasInitializedFFT) { if (hasInitializedFFT) {
if (usePmeStream) { if (usePmeStream) {
hipStreamDestroy(pmeStream);
hipEventDestroy(pmeSyncEvent); hipEventDestroy(pmeSyncEvent);
hipEventDestroy(paramsSyncEvent); hipEventDestroy(paramsSyncEvent);
} }
...@@ -542,17 +542,16 @@ void HipCalcNonbondedForceKernel::initialize(const System& system, const Nonbond ...@@ -542,17 +542,16 @@ void HipCalcNonbondedForceKernel::initialize(const System& system, const Nonbond
// Prepare for doing PME on its own stream. // Prepare for doing PME on its own stream.
if (usePmeStream) { if (usePmeStream) {
CHECK_RESULT(hipStreamCreateWithFlags(&pmeStream, hipStreamNonBlocking), "Error creating stream for NonbondedForce"); pmeQueue = cu.createQueue();
CHECK_RESULT(hipEventCreateWithFlags(&pmeSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce"); CHECK_RESULT(hipEventCreateWithFlags(&pmeSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce");
CHECK_RESULT(hipEventCreateWithFlags(&paramsSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce"); CHECK_RESULT(hipEventCreateWithFlags(&paramsSyncEvent, cu.getEventFlags()), "Error creating event for NonbondedForce");
int recipForceGroup = force.getReciprocalSpaceForceGroup(); int recipForceGroup = force.getReciprocalSpaceForceGroup();
if (recipForceGroup < 0) if (recipForceGroup < 0)
recipForceGroup = force.getForceGroup(); recipForceGroup = force.getForceGroup();
cu.addPreComputation(new SyncStreamPreComputation(cu, pmeStream, pmeSyncEvent, recipForceGroup)); cu.addPreComputation(new SyncStreamPreComputation(cu, pmeQueue, pmeSyncEvent, recipForceGroup));
cu.addPostComputation(new SyncStreamPostComputation(cu, pmeSyncEvent, cu.getKernel(module, "addEnergy"), pmeEnergyBuffer, recipForceGroup)); cu.addPostComputation(new SyncStreamPostComputation(cu, pmeSyncEvent, cu.getKernel(module, "addEnergy"), pmeEnergyBuffer, recipForceGroup));
} }
hipStream_t fftStream = usePmeStream ? pmeStream : cu.getCurrentStream();
fft = cu.createFFT(gridSizeX, gridSizeY, gridSizeZ, true); fft = cu.createFFT(gridSizeX, gridSizeY, gridSizeZ, true);
if (doLJPME) if (doLJPME)
dispersionFft = cu.createFFT(dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true); dispersionFft = cu.createFFT(dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true);
...@@ -857,7 +856,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -857,7 +856,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
} }
if (usePmeStream) { if (usePmeStream) {
hipEventRecord(paramsSyncEvent, cu.getCurrentStream()); hipEventRecord(paramsSyncEvent, cu.getCurrentStream());
hipStreamWaitEvent(pmeStream, paramsSyncEvent, 0); hipStreamWaitEvent(dynamic_cast<HipQueue*>(pmeQueue.get())->getStream(), paramsSyncEvent, 0);
} }
if (hasOffsets) { if (hasOffsets) {
// The Ewald self energy was computed in the kernel. // The Ewald self energy was computed in the kernel.
...@@ -893,7 +892,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -893,7 +892,7 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
} }
if (pmeGrid1.isInitialized() && includeReciprocal) { if (pmeGrid1.isInitialized() && includeReciprocal) {
if (usePmeStream) if (usePmeStream)
cu.setCurrentStream(pmeStream); cu.setCurrentQueue(pmeQueue);
// Invert the periodic box vectors. // Invert the periodic box vectors.
...@@ -1015,8 +1014,8 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -1015,8 +1014,8 @@ double HipCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
cu.executeKernelFlat(pmeInterpolateDispersionForceKernel, interpolateArgs, cu.getNumAtoms(), 128); cu.executeKernelFlat(pmeInterpolateDispersionForceKernel, interpolateArgs, cu.getNumAtoms(), 128);
} }
if (usePmeStream) { if (usePmeStream) {
hipEventRecord(pmeSyncEvent, pmeStream); hipEventRecord(pmeSyncEvent, dynamic_cast<HipQueue*>(pmeQueue.get())->getStream());
cu.restoreDefaultStream(); cu.restoreDefaultQueue();
} }
} }
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "HipQueue.h"
#include "HipContext.h"
#include "openmm/OpenMMException.h"
using namespace OpenMM;
HipQueue::HipQueue(hipStream_t stream) : stream(stream), initialized(false) {
}
HipQueue::HipQueue() : initialized(false) {
hipError_t result = hipStreamCreateWithFlags(&stream, hipStreamNonBlocking);
if (result != hipSuccess)
throw OpenMMException("Error creating HIP stream: "+HipContext::getErrorString(result));
initialized = true;
}
HipQueue::~HipQueue() {
if (initialized)
hipStreamDestroy(stream);
}
...@@ -187,6 +187,10 @@ public: ...@@ -187,6 +187,10 @@ public:
* Get the context this array belongs to. * Get the context this array belongs to.
*/ */
ComputeContext& getContext(); ComputeContext& getContext();
/**
* Get the queue in which to perform transfers.
*/
cl::CommandQueue getQueue() const;
/** /**
* Get the OpenCL Buffer object. * Get the OpenCL Buffer object.
*/ */
......
...@@ -203,17 +203,13 @@ public: ...@@ -203,17 +203,13 @@ public:
*/ */
double& getEnergyWorkspace(); double& getEnergyWorkspace();
/** /**
* Get the cl::CommandQueue currently being used for execution. * Create a new ComputeQueue for use with this context.
*/
cl::CommandQueue& getQueue();
/**
* Set the cl::ComandQueue to use for execution.
*/ */
void setQueue(cl::CommandQueue& queue); ComputeQueue createQueue();
/** /*
* Reset the context to using the default queue for execution. * Get the cl::CommandQueue currently being used for execution.
*/ */
void restoreDefaultQueue(); cl::CommandQueue getQueue();
/** /**
* Construct an uninitialized array of the appropriate class for this platform. The returned * Construct an uninitialized array of the appropriate class for this platform. The returned
* value should be created on the heap with the "new" operator. * value should be created on the heap with the "new" operator.
...@@ -706,7 +702,6 @@ private: ...@@ -706,7 +702,6 @@ private:
std::map<std::string, std::string> compilationDefines; std::map<std::string, std::string> compilationDefines;
cl::Context context; cl::Context context;
cl::Device device; cl::Device device;
cl::CommandQueue defaultQueue, currentQueue;
cl::Kernel clearBufferKernel; cl::Kernel clearBufferKernel;
cl::Kernel clearTwoBuffersKernel; cl::Kernel clearTwoBuffersKernel;
cl::Kernel clearThreeBuffersKernel; cl::Kernel clearThreeBuffersKernel;
......
...@@ -185,7 +185,7 @@ private: ...@@ -185,7 +185,7 @@ private:
OpenCLArray pmeEnergyBuffer; OpenCLArray pmeEnergyBuffer;
OpenCLArray chargeBuffer; OpenCLArray chargeBuffer;
OpenCLSort* sort; OpenCLSort* sort;
cl::CommandQueue pmeQueue; ComputeQueue pmeQueue;
cl::Event pmeSyncEvent; cl::Event pmeSyncEvent;
OpenCLFFT3D* fft; OpenCLFFT3D* fft;
OpenCLFFT3D* dispersionFft; OpenCLFFT3D* dispersionFft;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment