Unverified Commit a96534c1 authored by Anton Gorenko's avatar Anton Gorenko
Browse files

Optimize findInteractingBlocks

Optimize findBlocksWithInteractions

* Replace volatile shared mem accesses with shuffles;
* Add NUM_TILES_IN_BATCH for processing block1 by multiple warps
  (for small systems);
* Cherry-pick missing changes from .cu;
* Tune MAX_BITS_FOR_PAIRS depending on device and the system size;
* Store single pairs immediately (if there are any); this avoids having to
  store flags to shared memory and to filter buffer and flagsBuffer after
  saving single pairs;
* Use fma explicitly and sign bit for better device code;
* Use CDNA's MFMA with single/mixed precision;
* On CDNA the coarse grained stage processes warpSize blocks for
  one block1, the fine grained stage checks atoms of two block2 vs atoms
  of the same block1, singlePairs and interactingAtoms are also stored
  by warps, not half-warps;

Optimize findBlockBounds

* Use shuffles;
* Use executeKernelFlat;
* Process 2 tiles per warp 64 on CDNA;
* Use more uniformly distributed keys when sorting blocks;

Use compareInt2LargeSIMD when tile size < SIMD width

Fix exclusion tiles sorting on AMD CDNA (64 threads per wave)

    The nonbonded kernel uses USE_NEIGHBOR_LIST (useNeighborList)
    so host code also must check it instead of useCutoff.

    See also https://github.com/openmm/openmm/issues/3462
parent 67f5644d
......@@ -284,7 +284,17 @@ public:
* @param blockSize the size of each thread block to use
* @param sharedSize the amount of dynamic shared memory to allocate for the kernel, in bytes
*/
void executeKernel(hipFunction_t kernel, void** arguments, int workUnits, int blockSize = -1, unsigned int sharedSize = 0);
void executeKernel(hipFunction_t kernel, void** arguments, int threads, int blockSize = -1, unsigned int sharedSize = 0);
/**
* Execute a kernel with full grid.
*
* @param kernel the kernel to execute
* @param arguments an array of pointers to the kernel arguments
* @param threads the total number of threads that should be used
* @param blockSize the size of each thread block to use
* @param sharedSize the amount of dynamic shared memory to allocate for the kernel, in bytes
*/
void executeKernelFlat(hipFunction_t kernel, void** arguments, int threads, int blockSize = -1, unsigned int sharedSize = 0);
/**
* Compute the largest thread block size that can be used for a kernel that requires a particular amount of
* shared memory per thread.
......
......@@ -349,7 +349,8 @@ private:
std::map<int, std::string> groupKernelSource;
double lastCutoff;
bool useCutoff, usePeriodic, anyExclusions, usePadding, forceRebuildNeighborList, canUsePairList;
int startTileIndex, startBlockIndex, numBlocks, maxTiles, maxSinglePairs, maxExclusions, numForceThreadBlocks, forceThreadBlockSize, numAtoms, groupFlags;
int startTileIndex, startBlockIndex, numBlocks, maxTiles, maxSinglePairs, numTilesInBatch, maxExclusions;
int numForceThreadBlocks, forceThreadBlockSize, findInteractingBlocksThreadBlockSize, numAtoms, groupFlags;
long long numTiles;
std::string kernelSource;
};
......
......@@ -73,6 +73,7 @@ HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(cont
CHECK_RESULT(hipHostMalloc((void**) &pinnedCountBuffer, 2*sizeof(unsigned int), hipHostMallocPortable));
numForceThreadBlocks = 5*4*context.getMultiprocessors();
forceThreadBlockSize = 64;
findInteractingBlocksThreadBlockSize = context.getSIMDWidth();
setKernelSource(HipKernelSources::nonbonded);
}
......@@ -170,6 +171,20 @@ static bool compareInt2(int2 a, int2 b) {
return ((a.y < b.y) || (a.y == b.y && a.x < b.x));
}
/**
 * Ordering predicate for tiles, intended for devices whose SIMD width is greater than the tile size.
 *
 * A tile is "diagonal" when its x and y components are equal. Diagonal tiles are ordered before
 * off-diagonal ones to reduce thread divergence. Diagonal tiles are ordered among themselves by x;
 * off-diagonal tiles are ordered by y first and by x second.
 *
 * @param a the first tile to compare
 * @param b the second tile to compare
 * @return true if a must be placed before b
 */
static bool compareInt2LargeSIMD(int2 a, int2 b) {
    const bool aIsDiagonal = (a.x == a.y);
    const bool bIsDiagonal = (b.x == b.y);
    // Exactly one of the tiles is diagonal: the diagonal one goes first.
    if (aIsDiagonal != bIsDiagonal)
        return aIsDiagonal;
    // Both tiles are diagonal: x (which equals y) alone decides.
    if (aIsDiagonal)
        return a.x < b.x;
    // Both tiles are off-diagonal: y is the primary key, x breaks ties.
    if (a.y != b.y)
        return a.y < b.y;
    return a.x < b.x;
}
void HipNonbondedUtilities::initialize(const System& system) {
string errorMessage = "Error initializing nonbonded utilities";
if (atomExclusions.size() == 0) {
......@@ -201,7 +216,7 @@ void HipNonbondedUtilities::initialize(const System& system) {
vector<int2> exclusionTilesVec;
for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter)
exclusionTilesVec.push_back(make_int2(iter->first, iter->second));
sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), compareInt2);
sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), context.getSIMDWidth() <= 32 || !useNeighborList ? compareInt2 : compareInt2LargeSIMD);
exclusionTiles.initialize<int2>(context, exclusionTilesVec.size(), "exclusionTiles");
exclusionTiles.upload(exclusionTilesVec);
map<pair<int, int>, int> exclusionTileMap;
......@@ -266,6 +281,8 @@ void HipNonbondedUtilities::initialize(const System& system) {
if (maxTiles < 1)
maxTiles = 1;
maxSinglePairs = 5*numAtoms;
// HIP-TODO: This may require tuning
numTilesInBatch = numAtomBlocks < 2000 ? 4 : 1;
interactingTiles.initialize<int>(context, maxTiles, "interactingTiles");
interactingAtoms.initialize<int>(context, HipContext::TileSize*maxTiles, "interactingAtoms");
interactionCount.initialize<unsigned int>(context, 2, "interactionCount");
......@@ -393,10 +410,10 @@ void HipNonbondedUtilities::prepareInteractions(int forceGroups) {
if (lastCutoff != kernels.cutoffDistance)
forceRebuildNeighborList = true;
context.executeKernel(kernels.findBlockBoundsKernel, &findBlockBoundsArgs[0], context.getNumAtoms());
context.executeKernelFlat(kernels.findBlockBoundsKernel, &findBlockBoundsArgs[0], context.getPaddedNumAtoms(), context.getSIMDWidth());
blockSorter->sort(sortedBlocks);
context.executeKernel(kernels.sortBoxDataKernel, &sortBoxDataArgs[0], context.getNumAtoms());
context.executeKernel(kernels.findInteractingBlocksKernel, &findInteractingBlocksArgs[0], context.getNumAtoms(), 256);
context.executeKernelFlat(kernels.sortBoxDataKernel, &sortBoxDataArgs[0], context.getNumAtoms(), 64);
context.executeKernelFlat(kernels.findInteractingBlocksKernel, &findInteractingBlocksArgs[0], context.getNumAtomBlocks() * context.getSIMDWidth() * numTilesInBatch, findInteractingBlocksThreadBlockSize);
forceRebuildNeighborList = false;
lastCutoff = kernels.cutoffDistance;
interactionCount.download(pinnedCountBuffer, false);
......@@ -488,6 +505,7 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) {
defines["TILE_SIZE"] = context.intToString(HipContext::TileSize);
defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
defines["PADDING"] = context.doubleToString(paddedCutoff-cutoff);
defines["PADDED_CUTOFF"] = context.doubleToString(paddedCutoff);
defines["PADDED_CUTOFF_SQUARED"] = context.doubleToString(paddedCutoff*paddedCutoff);
......@@ -497,7 +515,33 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) {
if (context.getBoxIsTriclinic())
defines["TRICLINIC"] = "1";
defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions);
defines["MAX_BITS_FOR_PAIRS"] = (canUsePairList ? "2" : "0");
int maxBits = 0;
if (canUsePairList) {
if (context.getUseDoublePrecision()) {
maxBits = 4;
}
else {
if (context.getSIMDWidth() > 32) {
// CDNA
if (context.getNumAtoms() < 100000)
maxBits = 4;
else // Large systems
maxBits = 0;
}
else {
// RDNA
if (context.getNumAtoms() < 100000)
maxBits = 4;
else if (context.getNumAtoms() < 500000)
maxBits = 2;
else // Very large systems
maxBits = 0;
}
}
}
defines["MAX_BITS_FOR_PAIRS"] = context.intToString(maxBits);
defines["NUM_TILES_IN_BATCH"] = context.intToString(numTilesInBatch);
defines["GROUP_SIZE"] = context.intToString(findInteractingBlocksThreadBlockSize);
hipModule_t interactingBlocksProgram = context.createModule(HipKernelSources::vectorOps+HipKernelSources::findInteractingBlocks, defines);
kernels.findBlockBoundsKernel = context.getKernel(interactingBlocksProgram, "findBlockBounds");
kernels.sortBoxDataKernel = context.getKernel(interactingBlocksProgram, "sortBoxData");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment