Unverified Commit a96534c1 authored by Anton Gorenko's avatar Anton Gorenko
Browse files

Optimize findInteractingBlocks

Optimize findBlocksWithInteractions

* Replace volatile shared mem accesses with shuffles;
* Add NUM_TILES_IN_BATCH for processing block1 by multiple warps
  (for small systems);
* Cherry-pick missing changes from .cu;
* Tune MAX_BITS_FOR_PAIRS depending on device and the system size;
* Store single pairs immediately (if there are any), this allows not to
  store flags to shared memory and filter buffer and flagsBuffer after
  saving single pairs;
* Use fma explicitly and sign bit for better device code;
* Use CDNA's MFMA with singe/mixed precision;
* On CDNA the coarse grained stage processes warpSize blocks for
  one block1, the fine grained stage checks atoms of two block2 vs atoms
  of the same block1, singlePairs and interactingAtoms are also stored
  by warps, not half-warps;

Optimize findBlockBounds

* Use shuffles;
* Use executeKernelFlat;
* Process 2 tiles per warp 64 on CDNA;
* Use more uniformly distributed keys when sorting blocks;

Use compareInt2LargeSIMD when tile size < SIMD width

Fix exclusion tiles sorting on AMD CDNA (64 threads per wave)

    The nonbonded kernel uses USE_NEIGHBOR_LIST (useNeighborList)
    so host code also must check it instead of useCutoff.

    See also https://github.com/openmm/openmm/issues/3462
parent 67f5644d
......@@ -284,7 +284,17 @@ public:
* @param blockSize the size of each thread block to use
* @param sharedSize the amount of dynamic shared memory to allocated for the kernel, in bytes
*/
void executeKernel(hipFunction_t kernel, void** arguments, int workUnits, int blockSize = -1, unsigned int sharedSize = 0);
void executeKernel(hipFunction_t kernel, void** arguments, int threads, int blockSize = -1, unsigned int sharedSize = 0);
/**
* Execute a kernel with full grid.
*
* @param kernel the kernel to execute
* @param arguments an array of pointers to the kernel arguments
* @param threads the total number of threads that should be used
* @param blockSize the size of each thread block to use
* @param sharedSize the amount of dynamic shared memory to allocated for the kernel, in bytes
*/
void executeKernelFlat(hipFunction_t kernel, void** arguments, int threads, int blockSize = -1, unsigned int sharedSize = 0);
/**
* Compute the largest thread block size that can be used for a kernel that requires a particular amount of
* shared memory per thread.
......
......@@ -349,7 +349,8 @@ private:
std::map<int, std::string> groupKernelSource;
double lastCutoff;
bool useCutoff, usePeriodic, anyExclusions, usePadding, forceRebuildNeighborList, canUsePairList;
int startTileIndex, startBlockIndex, numBlocks, maxTiles, maxSinglePairs, maxExclusions, numForceThreadBlocks, forceThreadBlockSize, numAtoms, groupFlags;
int startTileIndex, startBlockIndex, numBlocks, maxTiles, maxSinglePairs, numTilesInBatch, maxExclusions;
int numForceThreadBlocks, forceThreadBlockSize, findInteractingBlocksThreadBlockSize, numAtoms, groupFlags;
long long numTiles;
std::string kernelSource;
};
......
......@@ -73,6 +73,7 @@ HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(cont
CHECK_RESULT(hipHostMalloc((void**) &pinnedCountBuffer, 2*sizeof(unsigned int), hipHostMallocPortable));
numForceThreadBlocks = 5*4*context.getMultiprocessors();
forceThreadBlockSize = 64;
findInteractingBlocksThreadBlockSize = context.getSIMDWidth();
setKernelSource(HipKernelSources::nonbonded);
}
......@@ -170,6 +171,20 @@ static bool compareInt2(int2 a, int2 b) {
return ((a.y < b.y) || (a.y == b.y && a.x < b.x));
}
static bool compareInt2LargeSIMD(int2 a, int2 b) {
// This version is used on devices with SIMD width greater than tile size. It puts diagonal tiles before off-diagonal
// ones to reduce thread divergence.
if (a.x == a.y) {
if (b.x == b.y)
return (a.x < b.x);
return true;
}
if (b.x == b.y)
return false;
return ((a.y < b.y) || (a.y == b.y && a.x < b.x));
}
void HipNonbondedUtilities::initialize(const System& system) {
string errorMessage = "Error initializing nonbonded utilities";
if (atomExclusions.size() == 0) {
......@@ -201,7 +216,7 @@ void HipNonbondedUtilities::initialize(const System& system) {
vector<int2> exclusionTilesVec;
for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter)
exclusionTilesVec.push_back(make_int2(iter->first, iter->second));
sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), compareInt2);
sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), context.getSIMDWidth() <= 32 || !useNeighborList ? compareInt2 : compareInt2LargeSIMD);
exclusionTiles.initialize<int2>(context, exclusionTilesVec.size(), "exclusionTiles");
exclusionTiles.upload(exclusionTilesVec);
map<pair<int, int>, int> exclusionTileMap;
......@@ -266,6 +281,8 @@ void HipNonbondedUtilities::initialize(const System& system) {
if (maxTiles < 1)
maxTiles = 1;
maxSinglePairs = 5*numAtoms;
// HIP-TODO: This may require tuning
numTilesInBatch = numAtomBlocks < 2000 ? 4 : 1;
interactingTiles.initialize<int>(context, maxTiles, "interactingTiles");
interactingAtoms.initialize<int>(context, HipContext::TileSize*maxTiles, "interactingAtoms");
interactionCount.initialize<unsigned int>(context, 2, "interactionCount");
......@@ -393,10 +410,10 @@ void HipNonbondedUtilities::prepareInteractions(int forceGroups) {
if (lastCutoff != kernels.cutoffDistance)
forceRebuildNeighborList = true;
context.executeKernel(kernels.findBlockBoundsKernel, &findBlockBoundsArgs[0], context.getNumAtoms());
context.executeKernelFlat(kernels.findBlockBoundsKernel, &findBlockBoundsArgs[0], context.getPaddedNumAtoms(), context.getSIMDWidth());
blockSorter->sort(sortedBlocks);
context.executeKernel(kernels.sortBoxDataKernel, &sortBoxDataArgs[0], context.getNumAtoms());
context.executeKernel(kernels.findInteractingBlocksKernel, &findInteractingBlocksArgs[0], context.getNumAtoms(), 256);
context.executeKernelFlat(kernels.sortBoxDataKernel, &sortBoxDataArgs[0], context.getNumAtoms(), 64);
context.executeKernelFlat(kernels.findInteractingBlocksKernel, &findInteractingBlocksArgs[0], context.getNumAtomBlocks() * context.getSIMDWidth() * numTilesInBatch, findInteractingBlocksThreadBlockSize);
forceRebuildNeighborList = false;
lastCutoff = kernels.cutoffDistance;
interactionCount.download(pinnedCountBuffer, false);
......@@ -488,6 +505,7 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) {
defines["TILE_SIZE"] = context.intToString(HipContext::TileSize);
defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
defines["PADDING"] = context.doubleToString(paddedCutoff-cutoff);
defines["PADDED_CUTOFF"] = context.doubleToString(paddedCutoff);
defines["PADDED_CUTOFF_SQUARED"] = context.doubleToString(paddedCutoff*paddedCutoff);
......@@ -497,7 +515,33 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) {
if (context.getBoxIsTriclinic())
defines["TRICLINIC"] = "1";
defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions);
defines["MAX_BITS_FOR_PAIRS"] = (canUsePairList ? "2" : "0");
int maxBits = 0;
if (canUsePairList) {
if (context.getUseDoublePrecision()) {
maxBits = 4;
}
else {
if (context.getSIMDWidth() > 32) {
// CDNA
if (context.getNumAtoms() < 100000)
maxBits = 4;
else // Large systems
maxBits = 0;
}
else {
// RDNA
if (context.getNumAtoms() < 100000)
maxBits = 4;
else if (context.getNumAtoms() < 500000)
maxBits = 2;
else // Very large systems
maxBits = 0;
}
}
}
defines["MAX_BITS_FOR_PAIRS"] = context.intToString(maxBits);
defines["NUM_TILES_IN_BATCH"] = context.intToString(numTilesInBatch);
defines["GROUP_SIZE"] = context.intToString(findInteractingBlocksThreadBlockSize);
hipModule_t interactingBlocksProgram = context.createModule(HipKernelSources::vectorOps+HipKernelSources::findInteractingBlocks, defines);
kernels.findBlockBoundsKernel = context.getKernel(interactingBlocksProgram, "findBlockBounds");
kernels.sortBoxDataKernel = context.getKernel(interactingBlocksProgram, "sortBoxData");
......
#define GROUP_SIZE 256
#define BUFFER_SIZE 256
__device__ inline int __TILEFLAGS_FFS(unsigned long long int x) {
return __ffsll(x);
}
__device__ inline int __TILEFLAGS_CLZ(unsigned long long int x) {
return __clzll(x);
#if defined(AMD_RDNA)
typedef unsigned int warpflags;
__device__ inline int warpPopc(warpflags x) {
return __popc(x);
}
__device__ inline int __TILEFLAGS_POPC(unsigned long long int x) {
#else
typedef unsigned long long warpflags;
__device__ inline int warpPopc(warpflags x) {
return __popcll(x);
}
#endif
/**
* Find a bounding box for the atoms in each block.
*/
extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
const real4* __restrict__ posq, real4* __restrict__ blockCenter, real4* __restrict__ blockBoundingBox, int* __restrict__ rebuildNeighborList,
real2* __restrict__ sortedBlocks) {
int index = blockIdx.x*blockDim.x+threadIdx.x;
int base = index*TILE_SIZE;
while (base < numAtoms) {
real4 pos = posq[base];
#ifdef USE_PERIODIC
APPLY_PERIODIC_TO_POS(pos)
#endif
real4 minPos = pos;
real4 maxPos = pos;
int last = min(base+TILE_SIZE, numAtoms);
for (int i = base+1; i < last; i++) {
pos = posq[i];
const int indexInTile = threadIdx.x%TILE_SIZE;
const int index = warpSize == TILE_SIZE ? blockIdx.x : (blockIdx.x*(warpSize/TILE_SIZE) + threadIdx.x/TILE_SIZE);
const int base = index * TILE_SIZE;
if (base >= numAtoms)
return;
real4 tPos = posq[base + indexInTile < numAtoms ? base + indexInTile : 0];
#ifdef USE_PERIODIC
real4 center = 0.5f*(maxPos+minPos);
APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)
#endif
minPos = make_real4(min(minPos.x,pos.x), min(minPos.y,pos.y), min(minPos.z,pos.z), 0);
maxPos = make_real4(max(maxPos.x,pos.x), max(maxPos.y,pos.y), max(maxPos.z,pos.z), 0);
}
real4 blockSize = 0.5f*(maxPos-minPos);
real4 pos;
pos.x = SHFL(tPos.x, 0);
pos.y = SHFL(tPos.y, 0);
pos.z = SHFL(tPos.z, 0);
APPLY_PERIODIC_TO_POS(pos)
real4 minPos = pos;
real4 maxPos = pos;
for (int i = 1; i < TILE_SIZE; i++) {
pos.x = SHFL(tPos.x, i);
pos.y = SHFL(tPos.y, i);
pos.z = SHFL(tPos.z, i);
real4 center = 0.5f*(maxPos+minPos);
center.w = 0;
for (int i = base; i < last; i++) {
pos = posq[i];
real4 delta = posq[i]-center;
APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)
minPos = make_real4(min(minPos.x,pos.x), min(minPos.y,pos.y), min(minPos.z,pos.z), 0);
maxPos = make_real4(max(maxPos.x,pos.x), max(maxPos.y,pos.y), max(maxPos.z,pos.z), 0);
}
#else
real4 minPos = tPos;
real4 maxPos = tPos;
for (int i = TILE_SIZE >> 1; i > 0; i >>= 1) {
real4 tpos1, tpos2;
tpos1.x = __shfl_down(minPos.x, i, TILE_SIZE);
tpos1.y = __shfl_down(minPos.y, i, TILE_SIZE);
tpos1.z = __shfl_down(minPos.z, i, TILE_SIZE);
tpos2.x = __shfl_down(maxPos.x, i, TILE_SIZE);
tpos2.y = __shfl_down(maxPos.y, i, TILE_SIZE);
tpos2.z = __shfl_down(maxPos.z, i, TILE_SIZE);
minPos.x = min(minPos.x, tpos1.x);
minPos.y = min(minPos.y, tpos1.y);
minPos.z = min(minPos.z, tpos1.z);
maxPos.x = max(maxPos.x, tpos2.x);
maxPos.y = max(maxPos.y, tpos2.y);
maxPos.z = max(maxPos.z, tpos2.z);
}
minPos.x = SHFL(minPos.x, 0);
minPos.y = SHFL(minPos.y, 0);
minPos.z = SHFL(minPos.z, 0);
maxPos.x = SHFL(maxPos.x, 0);
maxPos.y = SHFL(maxPos.y, 0);
maxPos.z = SHFL(maxPos.z, 0);
#endif
real4 blockSize = 0.5f*(maxPos-minPos);
real4 center = 0.5f*(maxPos+minPos);
center.w = 0;
real4 delta = tPos - center;
#ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(delta)
APPLY_PERIODIC_TO_DELTA(delta)
#endif
center.w = max(center.w, delta.x*delta.x+delta.y*delta.y+delta.z*delta.z);
}
center.w = sqrt(center.w);
real tdelta = delta.x*delta.x+delta.y*delta.y+delta.z*delta.z;
real tcenter = max(center.w, tdelta);
for (int i = TILE_SIZE >> 1; i > 0; i >>= 1) {
real t = __shfl_down(tcenter, i, TILE_SIZE);
tcenter = max(tcenter, t);
}
if (indexInTile == 0) {
center.w = SQRT(tcenter);
blockBoundingBox[index] = blockSize;
blockCenter[index] = center;
sortedBlocks[index] = make_real2(blockSize.x+blockSize.y+blockSize.z, index);
index += blockDim.x*gridDim.x;
base = index*TILE_SIZE;
// blockSize.x+blockSize.y+blockSize.z has a distibution that looks like a normal distribution.
// This causes HipSort's buckets to have very non-uniform sizes, so a few very long buckets are
// sorted in global memory. -1/max(x, y, z) or -1/(x+y+z) have a "faster" distribution.
sortedBlocks[index] = make_real2(-RECIP(max(max(blockSize.x, blockSize.y), blockSize.z)), index);
}
if (blockIdx.x == 0 && threadIdx.x == 0)
rebuildNeighborList[0] = 0;
......@@ -65,7 +113,8 @@ extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, co
const real4* __restrict__ blockBoundingBox, real4* __restrict__ sortedBlockCenter,
real4* __restrict__ sortedBlockBoundingBox, const real4* __restrict__ posq, const real4* __restrict__ oldPositions,
unsigned int* __restrict__ interactionCount, int* __restrict__ rebuildNeighborList, bool forceRebuild) {
for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_BLOCKS; i += blockDim.x*gridDim.x) {
int i = threadIdx.x+blockIdx.x*blockDim.x;
if (i < NUM_BLOCKS) {
int index = (int) sortedBlock[i].y;
sortedBlockCenter[i] = blockCenter[index];
sortedBlockBoundingBox[i] = blockBoundingBox[index];
......@@ -74,11 +123,12 @@ extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, co
// Also check whether any atom has moved enough so that we really need to rebuild the neighbor list.
bool rebuild = forceRebuild;
for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_ATOMS; i += blockDim.x*gridDim.x) {
if (i < NUM_ATOMS) {
real4 delta = oldPositions[i]-posq[i];
if (delta.x*delta.x + delta.y*delta.y + delta.z*delta.z > 0.25f*PADDING*PADDING)
rebuild = true;
}
if (rebuild) {
rebuildNeighborList[0] = 1;
interactionCount[0] = 0;
......@@ -86,59 +136,59 @@ extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, co
}
}
__device__ int saveSinglePairs(int x, int* atoms, tileflags* flags, int length, unsigned int maxSinglePairs, unsigned int* singlePairCount, int2* singlePairs, int* sumBuffer, volatile int& pairStartIndex) {
// Record interactions that should be computed as single pairs rather than in blocks.
#if MAX_BITS_FOR_PAIRS > 0
const int indexInWarp = threadIdx.x%warpSize;
int sum = 0;
for (int i = indexInWarp; i < length; i += warpSize) {
int count = __TILEFLAGS_POPC(flags[i]);
sum += (count <= MAX_BITS_FOR_PAIRS ? count : 0);
}
for (int i = 1; i < warpSize; i *= 2) {
int n = SHFL(sum, indexInWarp - i);
if (indexInWarp >= i)
sum += n;
}
if (indexInWarp == warpSize - 1)
pairStartIndex = atomicAdd(singlePairCount,(unsigned int) sum);
SYNC_WARPS;
int prevSum = SHFL(sum, indexInWarp - 1);
int pairIndex = pairStartIndex + (indexInWarp > 0 ? prevSum : 0);
for (int i = indexInWarp; i < length; i += warpSize) {
int count = __TILEFLAGS_POPC(flags[i]);
if (count <= MAX_BITS_FOR_PAIRS && pairIndex+count <= maxSinglePairs) {
tileflags f = flags[i];
while (f != 0) {
singlePairs[pairIndex] = make_int2(atoms[i], x*TILE_SIZE+__TILEFLAGS_FFS(f)-1);
f &= f-1;
pairIndex++;
}
}
}
__device__ inline
void collectInteractions(unsigned int& interacts, float d, int bit) {
interacts |= (__float_as_uint(d) >> 31) << bit;
}
// Compact the remaining interactions.
const tileflags warpMask = (static_cast<tileflags>(1)<<indexInWarp)-1;
int numCompacted = 0;
for (int start = 0; start < length; start += warpSize) {
int i = start+indexInWarp;
int atom = atoms[i];
tileflags flag = flags[i];
bool include = (i < length && __TILEFLAGS_POPC(flags[i]) > MAX_BITS_FOR_PAIRS);
tileflags includeFlags = BALLOT(include);
if (include) {
int index = numCompacted+__TILEFLAGS_POPC(includeFlags&warpMask);
atoms[index] = atom;
flags[index] = flag;
}
numCompacted += __TILEFLAGS_POPC(includeFlags);
__device__ inline
void collectInteractions(unsigned int& interacts, double d, int bit) {
interacts |= ((unsigned int)__double2hiint(d) >> 31) << bit;
}
#else
// Simplified version that does not collect individual bits (they are not needed without single pairs),
// only sets the flag that there are any interactions.
__device__ inline
void collectInteractions(unsigned int& interacts, float d, int) {
interacts |= __float_as_uint(d) & (1 << 31);
}
__device__ inline
void collectInteractions(unsigned int& interacts, double d, int) {
interacts |= (unsigned int)__double2hiint(d) & (1 << 31);
}
#endif
#if !defined(USE_DOUBLE_PRECISION) && __has_builtin(__builtin_amdgcn_mfma_f32_4x4x1f32)
#define USE_MFMA
using vfloat = __attribute__((__vector_size__(4 * sizeof(float)))) float;
template<int BlockId>
inline __device__
void mfma4x4(const float4& pos1, const float4& pos2, const vfloat& c, unsigned int& interacts) {
vfloat d;
d = __builtin_amdgcn_mfma_f32_4x4x1f32(pos1.x, -pos2.x, c, 4, BlockId, 0);
d = __builtin_amdgcn_mfma_f32_4x4x1f32(pos1.y, -pos2.y, d, 4, BlockId, 0);
d = __builtin_amdgcn_mfma_f32_4x4x1f32(pos1.z, -pos2.z, d, 4, BlockId, 0);
d = __builtin_amdgcn_mfma_f32_4x4x1f32(pos1.w, 1.0f, d, 4, BlockId, 0);
#pragma unroll
for (int i = 0; i < 4; i++) {
collectInteractions(interacts, d[i], BlockId * 4 + i);
}
return numCompacted;
}
#endif
/**
* Compare the bounding boxes for each pair of atom blocks (comprised of warpSize atoms each), forming a tile. If the two
* Compare the bounding boxes for each pair of atom blocks (comprised of TILE_SIZE atoms each), forming a tile. If the two
* atom blocks are sufficiently far apart, mark them as non-interacting. There are two stages in the algorithm.
*
* STAGE 1:
......@@ -155,7 +205,7 @@ __device__ int saveSinglePairs(int x, int* atoms, tileflags* flags, int length,
* A fine grained atom block against interacting atoms neighbour list is constructed.
*
* The warp loops over atom blocks Y that were found to (possibly) interact with atom block X. Each thread
* in the warp loops over the warpSize atoms in X and compares their positions to one particular atom from block Y.
* in the warp loops over the TILE_SIZE atoms in X and compares their positions to one particular atom from block Y.
* If it finds one closer than the cutoff distance, the atom is added to the list of atoms interacting with block X.
* This continues until the buffer fills up, at which point the results are written to global memory.
*
......@@ -186,7 +236,7 @@ __device__ int saveSinglePairs(int x, int* atoms, tileflags* flags, int length,
* [in] rebuildNeighbourList - whether or not to execute this kernel
*
*/
extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
extern "C" __global__ __launch_bounds__(GROUP_SIZE) void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
unsigned int* __restrict__ interactionCount, int* __restrict__ interactingTiles, unsigned int* __restrict__ interactingAtoms,
int2* __restrict__ singlePairs, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int maxSinglePairs,
unsigned int startBlockIndex, unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter,
......@@ -196,27 +246,29 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea
if (rebuildNeighborList[0] == 0)
return; // The neighbor list doesn't need to be rebuilt.
constexpr int tilesPerWarp = warpSize/TILE_SIZE;
constexpr int warpsPerBlock = GROUP_SIZE/warpSize;
const int indexInWarp = threadIdx.x%warpSize;
const int warpStart = threadIdx.x-indexInWarp;
const int totalWarps = blockDim.x*gridDim.x/warpSize;
const int warpIndex = (blockIdx.x*blockDim.x+threadIdx.x)/warpSize;
const tileflags warpMask = (static_cast<tileflags>(1)<<indexInWarp)-1;
__shared__ int workgroupBuffer[BUFFER_SIZE*(GROUP_SIZE/warpSize)];
__shared__ tileflags workgroupFlagsBuffer[BUFFER_SIZE*(GROUP_SIZE/warpSize)];
__shared__ int warpExclusions[MAX_EXCLUSIONS*(GROUP_SIZE/warpSize)];
__shared__ real4 posBuffer[GROUP_SIZE];
__shared__ volatile int workgroupTileIndex[GROUP_SIZE/warpSize];
__shared__ int worksgroupPairStartIndex[GROUP_SIZE/warpSize];
int* sumBuffer = (int*) posBuffer; // Reuse the same buffer to save memory
int* buffer = workgroupBuffer+BUFFER_SIZE*(warpStart/warpSize);
tileflags* flagsBuffer = workgroupFlagsBuffer+BUFFER_SIZE*(warpStart/warpSize);
int* exclusionsForX = warpExclusions+MAX_EXCLUSIONS*(warpStart/warpSize);
volatile int& tileStartIndex = workgroupTileIndex[warpStart/warpSize];
volatile int& pairStartIndex = worksgroupPairStartIndex[warpStart/warpSize];
const int indexInTile = threadIdx.x%TILE_SIZE;
const int tileInWarp = tilesPerWarp == 1 ? 0 : (threadIdx.x/TILE_SIZE)%tilesPerWarp;
const int warpInBlock = warpsPerBlock == 1 ? 0 : threadIdx.x/warpSize;
const int warpIndex = blockIdx.x*warpsPerBlock + (warpsPerBlock == 1 ? 0 : warpInBlock);
const warpflags warpMask = (static_cast<warpflags>(1)<<indexInWarp)-1;
__shared__ int workgroupBuffer[BUFFER_SIZE*warpsPerBlock];
__shared__ real4 workgroupPosBuffer[TILE_SIZE*warpsPerBlock];
__shared__ int workgroupExclusions[MAX_EXCLUSIONS*warpsPerBlock];
__shared__ int workgroupBlock2Buffer[(warpSize+1)*warpsPerBlock];
int* buffer = workgroupBuffer+BUFFER_SIZE*warpInBlock;
real4* posBuffer = workgroupPosBuffer+TILE_SIZE*warpInBlock;
int* exclusionsForX = workgroupExclusions+MAX_EXCLUSIONS*warpInBlock;
int* block2Buffer = workgroupBlock2Buffer+(warpSize+1)*warpInBlock;
// Loop over blocks.
for (int block1 = startBlockIndex+warpIndex; block1 < startBlockIndex+numBlocks; block1 += totalWarps) {
int block1 = startBlockIndex+warpIndex/NUM_TILES_IN_BATCH;
if (block1 < startBlockIndex+numBlocks) {
// Load data for this block. Note that all threads in a warp are processing the same block.
real2 sortedKey = sortedBlocks[block1];
......@@ -224,7 +276,7 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea
real4 blockCenterX = sortedBlockCenter[block1];
real4 blockSizeX = sortedBlockBoundingBox[block1];
int neighborsInBuffer = 0;
real4 pos1 = posq[x*TILE_SIZE+indexInWarp];
real4 pos1 = posq[x*TILE_SIZE+indexInTile];
#ifdef USE_PERIODIC
const bool singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= PADDED_CUTOFF &&
0.5f*periodicBoxSize.y-blockSizeX.y >= PADDED_CUTOFF &&
......@@ -237,64 +289,86 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea
}
#endif
pos1.w = 0.5f * (pos1.x * pos1.x + pos1.y * pos1.y + pos1.z * pos1.z);
posBuffer[threadIdx.x] = pos1;
if (tileInWarp == 0) {
posBuffer[indexInTile] = pos1;
}
// Load exclusion data for block x.
const int exclusionStart = exclusionRowIndices[x];
const int exclusionEnd = exclusionRowIndices[x+1];
const int numExclusions = exclusionEnd-exclusionStart;
#pragma unroll 1
for (int j = indexInWarp; j < numExclusions; j += warpSize)
exclusionsForX[j] = exclusionIndices[exclusionStart+j];
if (MAX_EXCLUSIONS > warpSize)
__syncthreads();
// Loop over atom blocks to search for neighbors. The threads in a warp compare block1 against warpSize
// other blocks in parallel.
for (int block2Base = block1+1; block2Base < NUM_BLOCKS; block2Base += warpSize) {
// For small systems multiple warps (NUM_TILES_IN_BATCH = 4, 2...) process one block1 reducing the overall
// duration of the kernel because first blocks block1 have to process more block2 blocks so most of compute
// units are idle at the end of the kernel (the kernel works on the upper triangle of
// the NUM_BLOCKS x NUM_BLOCKS matrix).
int block2Count = 0;
// Load blocks from addresses aligned by warpSize for faster loading from sortedBlockCenter and sortedBlockBoundingBox.
for (int block2Base = ((block1+1)/warpSize + warpIndex%NUM_TILES_IN_BATCH)*warpSize; block2Base < NUM_BLOCKS; block2Base += warpSize*NUM_TILES_IN_BATCH) {
const bool lastIteration = block2Base + warpSize*NUM_TILES_IN_BATCH >= NUM_BLOCKS;
int block2 = block2Base+indexInWarp;
bool includeBlock2 = (block2 < NUM_BLOCKS);
bool includeBlock2 = (block1 < block2 && block2 < NUM_BLOCKS);
block2 = includeBlock2 ? block2 : block1;
bool forceInclude = false;
if (includeBlock2) {
real4 blockCenterY = sortedBlockCenter[block2];
real4 blockSizeY = sortedBlockBoundingBox[block2];
real4 blockDelta = blockCenterX-blockCenterY;
real4 blockCenterY = sortedBlockCenter[block2];
real4 blockDelta = blockCenterX-blockCenterY;
#ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(blockDelta)
APPLY_PERIODIC_TO_DELTA(blockDelta)
#endif
includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < (PADDED_CUTOFF+blockCenterX.w+blockCenterY.w)*(PADDED_CUTOFF+blockCenterX.w+blockCenterY.w));
#ifndef TRICLINIC
if (!lastIteration && __ballot(includeBlock2) == 0)
continue;
#endif
includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < (PADDED_CUTOFF+blockCenterX.w+blockCenterY.w)*(PADDED_CUTOFF+blockCenterX.w+blockCenterY.w));
blockDelta.x = max(0.0f, fabs(blockDelta.x)-blockSizeX.x-blockSizeY.x);
blockDelta.y = max(0.0f, fabs(blockDelta.y)-blockSizeX.y-blockSizeY.y);
blockDelta.z = max(0.0f, fabs(blockDelta.z)-blockSizeX.z-blockSizeY.z);
includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < PADDED_CUTOFF_SQUARED);
real4 blockSizeY = sortedBlockBoundingBox[block2];
blockDelta.x = max(0.0f, fabs(blockDelta.x)-blockSizeX.x-blockSizeY.x);
blockDelta.y = max(0.0f, fabs(blockDelta.y)-blockSizeX.y-blockSizeY.y);
blockDelta.z = max(0.0f, fabs(blockDelta.z)-blockSizeX.z-blockSizeY.z);
includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < PADDED_CUTOFF_SQUARED);
#ifdef TRICLINIC
// The calculation to find the nearest periodic copy is only guaranteed to work if the nearest copy is less than half a box width away.
// If there's any possibility we might have missed it, do a detailed check.
// The calculation to find the nearest periodic copy is only guaranteed to work if the nearest copy is less than half a box width away.
// If there's any possibility we might have missed it, do a detailed check.
if (periodicBoxSize.z/2-blockSizeX.z-blockSizeY.z < PADDED_CUTOFF || periodicBoxSize.y/2-blockSizeX.y-blockSizeY.y < PADDED_CUTOFF)
includeBlock2 = forceInclude = true;
if (periodicBoxSize.z/2-blockSizeX.z-blockSizeY.z < PADDED_CUTOFF || periodicBoxSize.y/2-blockSizeX.y-blockSizeY.y < PADDED_CUTOFF)
includeBlock2 = forceInclude = true;
#endif
if (includeBlock2) {
int y = (int) sortedBlocks[block2].y;
for (int k = 0; k < numExclusions; k++)
includeBlock2 &= (exclusionsForX[k] != y);
}
// Collect any blocks we identified as potentially containing neighbors.
warpflags includeBlockFlags = __ballot(includeBlock2);
if (includeBlock2) {
int index = block2Count + warpPopc(includeBlockFlags&warpMask);
block2Buffer[index] = (block2 << 1) | (forceInclude ? 1 : 0);
}
block2Count += warpPopc(includeBlockFlags);
// Loop over the collected candidates (each warp processes 2 blocks on CDNA or 1 block on RDNA).
// Process even number of blocks on CDNA so both half-warps have work to do (except for
// the last iteration of the for-block2Base loop when the left-over must be processed).
// Loop over any blocks we identified as potentially containing neighbors.
const int block2ToProcess = lastIteration ? block2Count : block2Count/tilesPerWarp*tilesPerWarp;
for (int block2Index = 0; block2Index < block2ToProcess; block2Index += tilesPerWarp) {
bool includeBlock2 = block2Index + tileInWarp < block2Count;
const int b = block2Buffer[min(block2Index + tileInWarp, block2Count - 1)];
const bool forceInclude = b & 1;
const int block2 = b >> 1;
int y = (int) sortedBlocks[block2].y;
tileflags includeBlockFlags = BALLOT(includeBlock2);
tileflags forceIncludeFlags = BALLOT(forceInclude);
while (includeBlockFlags != 0) {
int i = __TILEFLAGS_FFS(includeBlockFlags)-1;
includeBlockFlags &= includeBlockFlags-1;
forceInclude = (forceIncludeFlags>>i) & 1;
int y = (int) sortedBlocks[block2Base+i].y;
#pragma unroll 1
for (int k = indexInTile; k < numExclusions; k += TILE_SIZE)
includeBlock2 &= (exclusionsForX[k] != y);
includeBlock2 = BALLOT(!includeBlock2) == 0;
// Check each atom in block Y for interactions.
int atom2 = y*TILE_SIZE+indexInWarp;
int atom2 = y*TILE_SIZE+indexInTile;
real4 pos2 = posq[atom2];
#ifdef USE_PERIODIC
if (singlePeriodicCopy) {
......@@ -302,92 +376,145 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea
}
#endif
pos2.w = 0.5f * (pos2.x * pos2.x + pos2.y * pos2.y + pos2.z * pos2.z);
real4 blockCenterY = sortedBlockCenter[block2Base+i];
real3 atomDelta = trimTo3(posBuffer[warpStart+indexInWarp])-trimTo3(blockCenterY);
real4 blockCenterY = sortedBlockCenter[block2];
real3 atomDelta = trimTo3(pos1)-trimTo3(blockCenterY);
#ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(atomDelta)
#endif
tileflags atomFlags = BALLOT(forceInclude || atomDelta.x*atomDelta.x+atomDelta.y*atomDelta.y+atomDelta.z*atomDelta.z < (PADDED_CUTOFF+blockCenterY.w)*(PADDED_CUTOFF+blockCenterY.w));
tileflags interacts = 0;
if (atom2 < NUM_ATOMS && atomFlags != 0) {
int first = __TILEFLAGS_FFS(atomFlags)-1;
int last = warpSize-__TILEFLAGS_CLZ(atomFlags);
// The condition `posj.w + pos2.w - posj.x*pos2.x - posj.y*pos2.y - posj.z*pos2.z < 0.5f * PADDED_CUTOFF_SQUARED` is expressed as
// `posj.x*pos2.x - posj.y*pos2.y - posj.z*pos2.z - posj.w - 0.5f * PADDED_CUTOFF_SQUARED - pos2.w` and computed using fma
// (it saves 1 instruction).
// Sign bit is used directly instead of `halfDist2 < 0.5f * PADDED_CUTOFF_SQUARED ? 1<<j : 0`.
#ifdef USE_PERIODIC
if (!singlePeriodicCopy) {
for (int j = first; j < last; j++) {
real3 delta = trimTo3(pos2)-trimTo3(posBuffer[warpStart+j]);
APPLY_PERIODIC_TO_DELTA(delta)
interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? static_cast<tileflags>(1)<<j : 0);
}
if (!singlePeriodicCopy) {
while (atomFlags) {
int j = __ffs(atomFlags)-1;
atomFlags = atomFlags ^ (static_cast<tileflags>(1) << j);
real3 delta = trimTo3(pos2)-trimTo3(posBuffer[j]);
APPLY_PERIODIC_TO_DELTA(delta)
real d = delta.x*delta.x+delta.y*delta.y+delta.z*delta.z - PADDED_CUTOFF_SQUARED;
collectInteractions(interacts, d, j);
}
else {
}
else {
#endif
#pragma unroll
for (int j = 0; j < warpSize; j++) {
real4 posj = posBuffer[warpStart+j];
real halfDist2 = posj.w + pos2.w - posj.x*pos2.x - posj.y*pos2.y - posj.z*pos2.z;
interacts |= (halfDist2 < 0.5f * PADDED_CUTOFF_SQUARED ? static_cast<tileflags>(1)<<j : 0);
}
#ifdef USE_PERIODIC
const real lim = 0.5f * PADDED_CUTOFF_SQUARED - pos2.w;
#if defined(USE_MFMA)
const vfloat c = { -lim, -lim, -lim, -lim };
mfma4x4<0>(pos1, pos2, c, interacts);
mfma4x4<1>(pos1, pos2, c, interacts);
mfma4x4<2>(pos1, pos2, c, interacts);
mfma4x4<3>(pos1, pos2, c, interacts);
mfma4x4<4>(pos1, pos2, c, interacts);
mfma4x4<5>(pos1, pos2, c, interacts);
mfma4x4<6>(pos1, pos2, c, interacts);
mfma4x4<7>(pos1, pos2, c, interacts);
#else
while (atomFlags) {
int j = __ffs(atomFlags)-1;
atomFlags = atomFlags ^ (static_cast<tileflags>(1) << j);
real4 posj = posBuffer[j];
real d = fma(-posj.x, pos2.x, fma(-posj.y, pos2.y, fma(-posj.z, pos2.z, posj.w - lim)));
collectInteractions(interacts, d, j);
}
#endif
#ifdef USE_PERIODIC
}
#endif
if (atom2 >= NUM_ATOMS || !includeBlock2) {
interacts = 0;
}
#if MAX_BITS_FOR_PAIRS > 0
const unsigned int interactCount = __popc(interacts);
// Record interactions that should be computed as single pairs rather than in blocks.
const bool storeAsSinglePair = interactCount > 0 && interactCount <= MAX_BITS_FOR_PAIRS;
if (__ballot(storeAsSinglePair)) {
unsigned int sum = 0;
unsigned int prevSum = 0;
for (int i = 1; i <= MAX_BITS_FOR_PAIRS; i++) {
warpflags b = __ballot(interactCount == i);
sum += warpPopc(b) * i;
prevSum += warpPopc(b&warpMask) * i;
}
unsigned int pairStartIndex = 0;
if (indexInWarp == 0)
pairStartIndex = atomicAdd(&interactionCount[1], sum);
pairStartIndex = __shfl(pairStartIndex, 0);
unsigned int pairIndex = pairStartIndex + prevSum;
if (storeAsSinglePair && pairIndex+interactCount <= maxSinglePairs) {
while (interacts != 0) {
int j = __ffs(interacts)-1;
singlePairs[pairIndex] = make_int2(atom2, x*TILE_SIZE+j);
interacts = interacts ^ (static_cast<tileflags>(1) << j);
pairIndex++;
}
}
}
#else
const unsigned int interactCount = interacts;
#endif
// Add any interacting atoms to the buffer.
tileflags includeAtomFlags = BALLOT(interacts);
if (interacts) {
int index = neighborsInBuffer+__TILEFLAGS_POPC(includeAtomFlags&warpMask);
warpflags includeAtomFlags = __ballot(interactCount > MAX_BITS_FOR_PAIRS);
if (interactCount > MAX_BITS_FOR_PAIRS) {
int index = neighborsInBuffer+warpPopc(includeAtomFlags&warpMask);
buffer[index] = atom2;
flagsBuffer[index] = interacts;
}
neighborsInBuffer += __TILEFLAGS_POPC(includeAtomFlags);
if (neighborsInBuffer > BUFFER_SIZE-TILE_SIZE) {
neighborsInBuffer += warpPopc(includeAtomFlags);
if (neighborsInBuffer > BUFFER_SIZE-warpSize) {
// Store the new tiles to memory.
#if MAX_BITS_FOR_PAIRS > 0
neighborsInBuffer = saveSinglePairs(x, buffer, flagsBuffer, neighborsInBuffer, maxSinglePairs, &interactionCount[1], singlePairs, sumBuffer+warpStart, pairStartIndex);
#endif
int tilesToStore = neighborsInBuffer/TILE_SIZE;
if (tilesToStore > 0) {
if (indexInWarp == 0)
tileStartIndex = atomicAdd(&interactionCount[0], tilesToStore);
int newTileStartIndex = tileStartIndex;
if (newTileStartIndex+tilesToStore <= maxTiles) {
if (indexInWarp < tilesToStore)
interactingTiles[newTileStartIndex+indexInWarp] = x;
for (int j = 0; j < tilesToStore; j++)
interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = buffer[indexInWarp+j*TILE_SIZE];
}
buffer[indexInWarp] = buffer[indexInWarp+TILE_SIZE*tilesToStore];
neighborsInBuffer -= TILE_SIZE*tilesToStore;
unsigned int tilesToStore = neighborsInBuffer/warpSize*tilesPerWarp;
unsigned int tileStartIndex = 0;
if (indexInWarp == 0)
tileStartIndex = atomicAdd(&interactionCount[0], tilesToStore);
unsigned int newTileStartIndex = __shfl(tileStartIndex, 0);
if (newTileStartIndex+tilesToStore <= maxTiles) {
if (indexInWarp < tilesToStore)
interactingTiles[newTileStartIndex+indexInWarp] = x;
for (int j = 0; j < tilesToStore/tilesPerWarp; j++)
interactingAtoms[newTileStartIndex*TILE_SIZE+j*warpSize+indexInWarp] = buffer[j*warpSize+indexInWarp];
}
if (indexInWarp+TILE_SIZE*tilesToStore < BUFFER_SIZE)
buffer[indexInWarp] = buffer[indexInWarp+TILE_SIZE*tilesToStore];
neighborsInBuffer -= TILE_SIZE*tilesToStore;
}
}
// Move not processed blocks to the head of block2Buffer.
if (indexInWarp < block2Count - block2ToProcess)
block2Buffer[indexInWarp] = block2Buffer[block2ToProcess + indexInWarp];
block2Count = block2Count - block2ToProcess;
}
// If we have a partially filled buffer, store it to memory.
#if MAX_BITS_FOR_PAIRS > 0
if (neighborsInBuffer > warpSize)
neighborsInBuffer = saveSinglePairs(x, buffer, flagsBuffer, neighborsInBuffer, maxSinglePairs, &interactionCount[1], singlePairs, sumBuffer+warpStart, pairStartIndex);
#endif
if (neighborsInBuffer > 0) {
int tilesToStore = (neighborsInBuffer+TILE_SIZE-1)/TILE_SIZE;
unsigned int tilesToStore = (neighborsInBuffer+TILE_SIZE-1)/TILE_SIZE;
unsigned int tileStartIndex = 0;
if (indexInWarp == 0)
tileStartIndex = atomicAdd(&interactionCount[0], tilesToStore);
int newTileStartIndex = tileStartIndex;
unsigned int newTileStartIndex = __shfl(tileStartIndex, 0);
if (newTileStartIndex+tilesToStore <= maxTiles) {
if (indexInWarp < tilesToStore)
interactingTiles[newTileStartIndex+indexInWarp] = x;
for (int j = 0; j < tilesToStore; j++)
interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = (indexInWarp+j*TILE_SIZE < neighborsInBuffer ? buffer[indexInWarp+j*TILE_SIZE] : NUM_ATOMS);
for (int j = 0; j <= tilesToStore/tilesPerWarp; j++) {
if (j*warpSize+indexInWarp < tilesToStore*TILE_SIZE)
interactingAtoms[newTileStartIndex*TILE_SIZE+j*warpSize+indexInWarp] = (j*warpSize+indexInWarp < neighborsInBuffer ? buffer[j*warpSize+indexInWarp] : PADDED_NUM_ATOMS);
}
}
}
}
// Record the positions the neighbor list is based on.
for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_ATOMS; i += blockDim.x*gridDim.x)
int i = threadIdx.x+blockIdx.x*GROUP_SIZE;
if (i < NUM_ATOMS) {
oldPositions[i] = posq[i];
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment