Commit b1a21fd4 authored by one
Browse files

Enable split PME streams for HIP LJPME

Run Coulomb and dispersion reciprocal PME work on separate HIP queues for
LJPME when PME streams are enabled.  Use separate grids, sorters, events, and
energy buffers so the two reciprocal branches can overlap safely.

Keep the behavior HIP-only: CUDA profiling on an RTX 4090 showed that the same
split increased PME spread/list contention and regressed apoa1ljpme.
parent c26187aa
......@@ -46,7 +46,8 @@ class CommonCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public:
CommonCalcNonbondedForceKernel(std::string name, const Platform& platform, ComputeContext& cc, const System& system) : CalcNonbondedForceKernel(name, platform),
hasInitializedKernel(false), cc(cc), pmeio(NULL), stepsToSort(0), dispersionStepsToSort(0),
pmeGridIndexBlockSize(-1), pmeSpreadChargeBlockSize(-1), pmeFinishSpreadChargeBlockSize(-1) {
pmeGridIndexBlockSize(-1), pmeSpreadChargeBlockSize(-1), pmeFinishSpreadChargeBlockSize(-1),
useSplitLJPMEStream(false), dispersionSyncQueue(NULL) {
}
~CommonCalcNonbondedForceKernel();
/**
......@@ -134,6 +135,8 @@ private:
ComputeArray cosSinSums;
ComputeArray pmeGrid1;
ComputeArray pmeGrid2;
ComputeArray pmeDispersionGrid1;
ComputeArray pmeDispersionGrid2;
ComputeArray pmeBsplineModuliX;
ComputeArray pmeBsplineModuliY;
ComputeArray pmeBsplineModuliZ;
......@@ -143,10 +146,13 @@ private:
ComputeArray pmeAtomGridIndex;
ComputeArray pmeDispersionAtomGridIndex;
ComputeArray pmeEnergyBuffer;
ComputeArray pmeDispersionEnergyBuffer;
ComputeArray chargeBuffer;
ComputeSort sort;
ComputeSort pmeDispersionSort;
ComputeQueue pmeQueue;
ComputeEvent pmeSyncEvent, paramsSyncEvent;
ComputeQueue pmeDispersionQueue;
ComputeEvent pmeSyncEvent, pmeDispersionSyncEvent, paramsSyncEvent;
FFT3D fft, dispersionFft;
Kernel cpuPme;
PmeIO* pmeio;
......@@ -172,6 +178,8 @@ private:
int dispersionStepsToSort;
bool usePmeQueue, deviceIsCpu, useFixedPointChargeSpreading, useCpuPme;
int pmeGridIndexBlockSize, pmeSpreadChargeBlockSize, pmeFinishSpreadChargeBlockSize;
bool useSplitLJPMEStream;
SyncQueuePostComputation* dispersionSyncQueue;
bool hasCoulomb, hasLJ, doLJPME, usePosqCharges, recomputeParams, hasOffsets;
NonbondedMethod nonbondedMethod;
static const int PmeOrder = 5;
......
......@@ -286,6 +286,7 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons
doLJPME = (nonbondedMethod == LJPME && hasLJ);
usePosqCharges = hasCoulomb ? cc.requestPosqCharges() : false;
bool isHip = (getPlatform().getName() == "HIP");
useSplitLJPMEStream = (isHip && cc.getSIMDWidth() == 64 && usePmeQueue && doLJPME && hasCoulomb);
bool useLargeHipPmeBlocks = (isHip && cc.getNumAtomBlocks() >= 2000);
pmeGridIndexBlockSize = useLargeHipPmeBlocks ? 128 : -1;
pmeSpreadChargeBlockSize = useLargeHipPmeBlocks ? 128 : -1;
......@@ -443,16 +444,25 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons
// Create required data structures.
int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
int gridElements = gridSizeX*gridSizeY*gridSizeZ;
if (doLJPME) {
gridElements = max(gridElements, dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ);
}
int coulombGridElements = gridSizeX*gridSizeY*gridSizeZ;
int dispersionGridElements = (doLJPME ? dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ : 0);
int gridElements = (useSplitLJPMEStream ? coulombGridElements : max(coulombGridElements, dispersionGridElements));
pmeGrid1.initialize(cc, gridElements, 2*elementSize, "pmeGrid1");
pmeGrid2.initialize(cc, gridElements, 2*elementSize, "pmeGrid2");
if (useSplitLJPMEStream) {
pmeDispersionGrid1.initialize(cc, dispersionGridElements, 2*elementSize, "pmeDispersionGrid1");
pmeDispersionGrid2.initialize(cc, dispersionGridElements, 2*elementSize, "pmeDispersionGrid2");
}
if (useFixedPointChargeSpreading)
cc.addAutoclearBuffer(pmeGrid2);
else
cc.addAutoclearBuffer(pmeGrid1);
if (useSplitLJPMEStream) {
if (useFixedPointChargeSpreading)
cc.addAutoclearBuffer(pmeDispersionGrid2);
else
cc.addAutoclearBuffer(pmeDispersionGrid1);
}
pmeBsplineModuliX.initialize(cc, gridSizeX, elementSize, "pmeBsplineModuliX");
pmeBsplineModuliY.initialize(cc, gridSizeY, elementSize, "pmeBsplineModuliY");
pmeBsplineModuliZ.initialize(cc, gridSizeZ, elementSize, "pmeBsplineModuliZ");
......@@ -462,12 +472,20 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons
pmeDispersionBsplineModuliZ.initialize(cc, dispersionGridSizeZ, elementSize, "pmeDispersionBsplineModuliZ");
}
pmeAtomGridIndex.initialize<mm_int2>(cc, numParticles, "pmeAtomGridIndex");
// Keep a separate LJPME atom-grid index for dispersion sorting.
if (doLJPME)
pmeDispersionAtomGridIndex.initialize<mm_int2>(cc, numParticles, "pmeDispersionAtomGridIndex");
int energyElementSize = (cc.getUseDoublePrecision() || cc.getUseMixedPrecision() ? sizeof(double) : sizeof(float));
pmeEnergyBuffer.initialize(cc, cc.getNumThreadBlocks()*ComputeContext::ThreadBlockSize, energyElementSize, "pmeEnergyBuffer");
cc.clearBuffer(pmeEnergyBuffer);
if (useSplitLJPMEStream) {
pmeDispersionEnergyBuffer.initialize(cc, cc.getNumThreadBlocks()*ComputeContext::ThreadBlockSize, energyElementSize, "pmeDispersionEnergyBuffer");
cc.clearBuffer(pmeDispersionEnergyBuffer);
}
sort = cc.createSort(new SortTrait(), cc.getNumAtoms());
// Use a separate sorter because ComputeSort owns scratch buffers.
if (useSplitLJPMEStream)
pmeDispersionSort = cc.createSort(new SortTrait(), cc.getNumAtoms());
fft = cc.createFFT(gridSizeX, gridSizeY, gridSizeZ, true);
if (doLJPME)
dispersionFft = cc.createFFT(dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true);
......@@ -482,6 +500,14 @@ void CommonCalcNonbondedForceKernel::commonInitialize(const System& system, cons
paramsSyncEvent = cc.createEvent();
cc.addPreComputation(new SyncQueuePreComputation(cc, pmeQueue, pmeSyncEvent, recipForceGroup));
cc.addPostComputation(syncQueue = new SyncQueuePostComputation(cc, pmeSyncEvent, pmeEnergyBuffer, recipForceGroup));
if (useSplitLJPMEStream) {
// Create synchronization for the dispersion PME queue.
pmeDispersionQueue = cc.createQueue();
ComputeEvent pmeDispersionStartEvent = cc.createEvent();
pmeDispersionSyncEvent = cc.createEvent();
cc.addPreComputation(new SyncQueuePreComputation(cc, pmeDispersionQueue, pmeDispersionStartEvent, recipForceGroup));
cc.addPostComputation(dispersionSyncQueue = new SyncQueuePostComputation(cc, pmeDispersionSyncEvent, pmeDispersionEnergyBuffer, recipForceGroup));
}
}
// Initialize the b-spline moduli.
......@@ -822,8 +848,11 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeFinishSpreadChargeKernel->addArg(pmeGrid2);
pmeFinishSpreadChargeKernel->addArg(pmeGrid1);
}
if (usePmeQueue)
if (usePmeQueue) {
syncQueue->setKernel(program->createKernel("addEnergy"));
if (useSplitLJPMEStream)
dispersionSyncQueue->setKernel(program->createKernel("addEnergy"));
}
if (doLJPME) {
// Create kernels for LJ PME.
......@@ -837,6 +866,8 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeDefines["USE_LJPME"] = "1";
pmeDefines["CHARGE_FROM_SIGEPS"] = "1";
program = cc.compileProgram(CommonKernelSources::pme, pmeDefines);
ComputeArray& dispersionGrid1 = (useSplitLJPMEStream ? pmeDispersionGrid1 : pmeGrid1);
ComputeArray& dispersionGrid2 = (useSplitLJPMEStream ? pmeDispersionGrid2 : pmeGrid2);
pmeDispersionGridIndexKernel = program->createKernel("findAtomGridIndex");
pmeDispersionSpreadChargeKernel = program->createKernel("gridSpreadCharge");
pmeDispersionConvolutionKernel = program->createKernel("reciprocalConvolution");
......@@ -848,21 +879,23 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeDispersionGridIndexKernel->addArg();
pmeDispersionSpreadChargeKernel->addArg(cc.getPosq());
if (useFixedPointChargeSpreading)
pmeDispersionSpreadChargeKernel->addArg(pmeGrid2);
pmeDispersionSpreadChargeKernel->addArg(dispersionGrid2);
else
pmeDispersionSpreadChargeKernel->addArg(pmeGrid1);
pmeDispersionSpreadChargeKernel->addArg(dispersionGrid1);
for (int i = 0; i < 8; i++)
pmeDispersionSpreadChargeKernel->addArg();
pmeDispersionSpreadChargeKernel->addArg(pmeDispersionAtomGridIndex);
pmeDispersionSpreadChargeKernel->addArg(sigmaEpsilon);
pmeDispersionConvolutionKernel->addArg(pmeGrid2);
pmeDispersionConvolutionKernel->addArg(dispersionGrid2);
pmeDispersionConvolutionKernel->addArg(pmeDispersionBsplineModuliX);
pmeDispersionConvolutionKernel->addArg(pmeDispersionBsplineModuliY);
pmeDispersionConvolutionKernel->addArg(pmeDispersionBsplineModuliZ);
for (int i = 0; i < 3; i++)
pmeDispersionConvolutionKernel->addArg();
pmeDispersionEvalEnergyKernel->addArg(pmeGrid2);
if (usePmeQueue)
pmeDispersionEvalEnergyKernel->addArg(dispersionGrid2);
if (useSplitLJPMEStream)
pmeDispersionEvalEnergyKernel->addArg(pmeDispersionEnergyBuffer);
else if (usePmeQueue)
pmeDispersionEvalEnergyKernel->addArg(pmeEnergyBuffer);
else
pmeDispersionEvalEnergyKernel->addArg(cc.getEnergyBuffer());
......@@ -873,15 +906,15 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeDispersionEvalEnergyKernel->addArg();
pmeDispersionInterpolateForceKernel->addArg(cc.getPosq());
pmeDispersionInterpolateForceKernel->addArg(cc.getLongForceBuffer());
pmeDispersionInterpolateForceKernel->addArg(pmeGrid1);
pmeDispersionInterpolateForceKernel->addArg(dispersionGrid1);
for (int i = 0; i < 8; i++)
pmeDispersionInterpolateForceKernel->addArg();
pmeDispersionInterpolateForceKernel->addArg(pmeDispersionAtomGridIndex);
pmeDispersionInterpolateForceKernel->addArg(sigmaEpsilon);
if (useFixedPointChargeSpreading) {
pmeDispersionFinishSpreadChargeKernel = program->createKernel("finishSpreadCharge");
pmeDispersionFinishSpreadChargeKernel->addArg(pmeGrid2);
pmeDispersionFinishSpreadChargeKernel->addArg(pmeGrid1);
pmeDispersionFinishSpreadChargeKernel->addArg(dispersionGrid2);
pmeDispersionFinishSpreadChargeKernel->addArg(dispersionGrid1);
}
}
}
......@@ -911,6 +944,8 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
if (usePmeQueue) {
paramsSyncEvent->enqueue();
paramsSyncEvent->queueWait(pmeQueue);
if (useSplitLJPMEStream)
paramsSyncEvent->queueWait(pmeDispersionQueue);
}
if (hasOffsets) {
// The Ewald self energy was computed in the kernel.
......@@ -968,6 +1003,8 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
// Execute the reciprocal space kernels.
int coulombGridElements = gridSizeX*gridSizeY*gridSizeZ;
int dispersionGridElements = (doLJPME ? dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ : 0);
if (hasCoulomb) {
if (stepsToSort <= 0) {
setPeriodicBoxArgs(cc, pmeGridIndexKernel, 2);
......@@ -1000,7 +1037,7 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
}
pmeSpreadChargeKernel->execute(cc.getNumAtoms(), pmeSpreadChargeBlockSize);
if (useFixedPointChargeSpreading)
pmeFinishSpreadChargeKernel->execute(gridSizeX*gridSizeY*gridSizeZ, pmeFinishSpreadChargeBlockSize);
pmeFinishSpreadChargeKernel->execute(coulombGridElements, pmeFinishSpreadChargeBlockSize);
fft->execFFT(pmeGrid1, pmeGrid2, true);
if (cc.getUseDoublePrecision()) {
pmeConvolutionKernel->setArg<mm_double4>(4, recipBoxVectors[0]);
......@@ -1019,8 +1056,8 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeEvalEnergyKernel->setArg<mm_float4>(7, recipBoxVectorsFloat[2]);
}
if (includeEnergy)
pmeEvalEnergyKernel->execute(gridSizeX*gridSizeY*gridSizeZ);
pmeConvolutionKernel->execute(gridSizeX*gridSizeY*gridSizeZ);
pmeEvalEnergyKernel->execute(coulombGridElements);
pmeConvolutionKernel->execute(coulombGridElements);
fft->execFFT(pmeGrid2, pmeGrid1, false);
setPeriodicBoxArgs(cc, pmeInterpolateForceKernel, 3);
if (cc.getUseDoublePrecision()) {
......@@ -1039,7 +1076,14 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeInterpolateForceKernel->execute(cc.getNumAtoms());
}
if (useSplitLJPMEStream) {
// Record the Coulomb PME completion event before switching queues.
pmeSyncEvent->enqueue();
cc.setCurrentQueue(pmeDispersionQueue);
}
if (doLJPME && hasLJ) {
ComputeArray& dispersionGrid1 = (useSplitLJPMEStream ? pmeDispersionGrid1 : pmeGrid1);
ComputeArray& dispersionGrid2 = (useSplitLJPMEStream ? pmeDispersionGrid2 : pmeGrid2);
if (dispersionStepsToSort <= 0) {
setPeriodicBoxArgs(cc, pmeDispersionGridIndexKernel, 2);
if (cc.getUseDoublePrecision()) {
......@@ -1053,15 +1097,20 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeDispersionGridIndexKernel->setArg(9, recipBoxVectorsFloat[2]);
}
pmeDispersionGridIndexKernel->execute(cc.getNumAtoms(), pmeGridIndexBlockSize);
sort->sort(pmeDispersionAtomGridIndex);
if (useSplitLJPMEStream)
pmeDispersionSort->sort(pmeDispersionAtomGridIndex);
else
sort->sort(pmeDispersionAtomGridIndex);
dispersionStepsToSort = 3;
}
else
dispersionStepsToSort--;
if (useFixedPointChargeSpreading)
cc.clearBuffer(pmeGrid2);
else
cc.clearBuffer(pmeGrid1);
if (!useSplitLJPMEStream) {
if (useFixedPointChargeSpreading)
cc.clearBuffer(dispersionGrid2);
else
cc.clearBuffer(dispersionGrid1);
}
setPeriodicBoxArgs(cc, pmeDispersionSpreadChargeKernel, 2);
if (cc.getUseDoublePrecision()) {
pmeDispersionSpreadChargeKernel->setArg(7, recipBoxVectors[0]);
......@@ -1075,8 +1124,8 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
}
pmeDispersionSpreadChargeKernel->execute(cc.getNumAtoms(), pmeSpreadChargeBlockSize);
if (useFixedPointChargeSpreading)
pmeDispersionFinishSpreadChargeKernel->execute(dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ, pmeFinishSpreadChargeBlockSize);
dispersionFft->execFFT(pmeGrid1, pmeGrid2, true);
pmeDispersionFinishSpreadChargeKernel->execute(dispersionGridElements, pmeFinishSpreadChargeBlockSize);
dispersionFft->execFFT(dispersionGrid1, dispersionGrid2, true);
if (cc.getUseDoublePrecision()) {
pmeDispersionConvolutionKernel->setArg(4, recipBoxVectors[0]);
pmeDispersionConvolutionKernel->setArg(5, recipBoxVectors[1]);
......@@ -1093,12 +1142,14 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeDispersionEvalEnergyKernel->setArg(6, recipBoxVectorsFloat[1]);
pmeDispersionEvalEnergyKernel->setArg(7, recipBoxVectorsFloat[2]);
}
if (!hasCoulomb)
if (useSplitLJPMEStream && includeEnergy)
cc.clearBuffer(pmeDispersionEnergyBuffer);
else if (!hasCoulomb)
cc.clearBuffer(pmeEnergyBuffer);
if (includeEnergy)
pmeDispersionEvalEnergyKernel->execute(dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ);
pmeDispersionConvolutionKernel->execute(dispersionGridSizeX*dispersionGridSizeY*dispersionGridSizeZ);
dispersionFft->execFFT(pmeGrid2, pmeGrid1, false);
pmeDispersionEvalEnergyKernel->execute(dispersionGridElements);
pmeDispersionConvolutionKernel->execute(dispersionGridElements);
dispersionFft->execFFT(dispersionGrid2, dispersionGrid1, false);
setPeriodicBoxArgs(cc, pmeDispersionInterpolateForceKernel, 3);
if (cc.getUseDoublePrecision()) {
pmeDispersionInterpolateForceKernel->setArg(8, recipBoxVectors[0]);
......@@ -1116,7 +1167,11 @@ double CommonCalcNonbondedForceKernel::execute(ContextImpl& context, bool includ
pmeDispersionInterpolateForceKernel->execute(cc.getNumAtoms());
}
if (usePmeQueue) {
pmeSyncEvent->enqueue();
// The Coulomb completion event was recorded before switching queues.
if (useSplitLJPMEStream)
pmeDispersionSyncEvent->enqueue();
else
pmeSyncEvent->enqueue();
cc.restoreDefaultQueue();
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment