Unverified Commit a96534c1 authored by Anton Gorenko's avatar Anton Gorenko
Browse files

Optimize findInteractingBlocks

Optimize findBlocksWithInteractions

* Replace volatile shared mem accesses with shuffles;
* Add NUM_TILES_IN_BATCH for processing block1 by multiple warps
  (for small systems);
* Cherry-pick missing changes from .cu;
* Tune MAX_BITS_FOR_PAIRS depending on device and the system size;
* Store single pairs immediately (if there are any); this avoids storing
  flags in shared memory and avoids filtering buffer and flagsBuffer after
  saving single pairs;
* Use fma explicitly and sign bit for better device code;
* Use CDNA's MFMA with single/mixed precision;
* On CDNA the coarse grained stage processes warpSize blocks for
  one block1, the fine grained stage checks atoms of two block2 vs atoms
  of the same block1, singlePairs and interactingAtoms are also stored
  by warps, not half-warps;

Optimize findBlockBounds

* Use shuffles;
* Use executeKernelFlat;
* Process 2 tiles per warp 64 on CDNA;
* Use more uniformly distributed keys when sorting blocks;

Use compareInt2LargeSIMD when tile size < SIMD width

Fix exclusion tiles sorting on AMD CDNA (64 threads per wave)

    The nonbonded kernel uses USE_NEIGHBOR_LIST (useNeighborList)
    so host code also must check it instead of useCutoff.

    See also https://github.com/openmm/openmm/issues/3462
parent 67f5644d
...@@ -284,7 +284,17 @@ public: ...@@ -284,7 +284,17 @@ public:
* @param blockSize the size of each thread block to use * @param blockSize the size of each thread block to use
* @param sharedSize the amount of dynamic shared memory to allocated for the kernel, in bytes * @param sharedSize the amount of dynamic shared memory to allocated for the kernel, in bytes
*/ */
void executeKernel(hipFunction_t kernel, void** arguments, int workUnits, int blockSize = -1, unsigned int sharedSize = 0); void executeKernel(hipFunction_t kernel, void** arguments, int threads, int blockSize = -1, unsigned int sharedSize = 0);
/**
* Execute a kernel with full grid.
*
* NOTE(review): "full grid" presumably means the launch is sized so that each of the
* requested threads exists in the grid (no cap on the number of thread blocks), which
* lets kernels drop their striding loops -- confirm against the implementation.
*
* @param kernel the kernel to execute
* @param arguments an array of pointers to the kernel arguments
* @param threads the total number of threads that should be used
* @param blockSize the size of each thread block to use (defaults to -1; presumably this
*                  selects a context-chosen block size -- confirm against the implementation)
* @param sharedSize the amount of dynamic shared memory to allocate for the kernel, in bytes
*/
void executeKernelFlat(hipFunction_t kernel, void** arguments, int threads, int blockSize = -1, unsigned int sharedSize = 0);
/** /**
* Compute the largest thread block size that can be used for a kernel that requires a particular amount of * Compute the largest thread block size that can be used for a kernel that requires a particular amount of
* shared memory per thread. * shared memory per thread.
......
...@@ -349,7 +349,8 @@ private: ...@@ -349,7 +349,8 @@ private:
std::map<int, std::string> groupKernelSource; std::map<int, std::string> groupKernelSource;
double lastCutoff; double lastCutoff;
bool useCutoff, usePeriodic, anyExclusions, usePadding, forceRebuildNeighborList, canUsePairList; bool useCutoff, usePeriodic, anyExclusions, usePadding, forceRebuildNeighborList, canUsePairList;
int startTileIndex, startBlockIndex, numBlocks, maxTiles, maxSinglePairs, maxExclusions, numForceThreadBlocks, forceThreadBlockSize, numAtoms, groupFlags; int startTileIndex, startBlockIndex, numBlocks, maxTiles, maxSinglePairs, numTilesInBatch, maxExclusions;
int numForceThreadBlocks, forceThreadBlockSize, findInteractingBlocksThreadBlockSize, numAtoms, groupFlags;
long long numTiles; long long numTiles;
std::string kernelSource; std::string kernelSource;
}; };
......
...@@ -73,6 +73,7 @@ HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(cont ...@@ -73,6 +73,7 @@ HipNonbondedUtilities::HipNonbondedUtilities(HipContext& context) : context(cont
CHECK_RESULT(hipHostMalloc((void**) &pinnedCountBuffer, 2*sizeof(unsigned int), hipHostMallocPortable)); CHECK_RESULT(hipHostMalloc((void**) &pinnedCountBuffer, 2*sizeof(unsigned int), hipHostMallocPortable));
numForceThreadBlocks = 5*4*context.getMultiprocessors(); numForceThreadBlocks = 5*4*context.getMultiprocessors();
forceThreadBlockSize = 64; forceThreadBlockSize = 64;
findInteractingBlocksThreadBlockSize = context.getSIMDWidth();
setKernelSource(HipKernelSources::nonbonded); setKernelSource(HipKernelSources::nonbonded);
} }
...@@ -170,6 +171,20 @@ static bool compareInt2(int2 a, int2 b) { ...@@ -170,6 +171,20 @@ static bool compareInt2(int2 a, int2 b) {
return ((a.y < b.y) || (a.y == b.y && a.x < b.x)); return ((a.y < b.y) || (a.y == b.y && a.x < b.x));
} }
static bool compareInt2LargeSIMD(int2 a, int2 b) {
    // Ordering for devices whose SIMD width exceeds the tile size: every diagonal tile
    // (x == y) sorts ahead of every off-diagonal tile, which reduces thread divergence.
    const bool lhsDiagonal = (a.x == a.y);
    const bool rhsDiagonal = (b.x == b.y);

    // Exactly one of the two is diagonal: the diagonal one goes first.
    if (lhsDiagonal != rhsDiagonal)
        return lhsDiagonal;

    // Both diagonal: order by the (shared) block index.
    if (lhsDiagonal)
        return a.x < b.x;

    // Both off-diagonal: order by y first, then by x.
    if (a.y != b.y)
        return a.y < b.y;
    return a.x < b.x;
}
void HipNonbondedUtilities::initialize(const System& system) { void HipNonbondedUtilities::initialize(const System& system) {
string errorMessage = "Error initializing nonbonded utilities"; string errorMessage = "Error initializing nonbonded utilities";
if (atomExclusions.size() == 0) { if (atomExclusions.size() == 0) {
...@@ -201,7 +216,7 @@ void HipNonbondedUtilities::initialize(const System& system) { ...@@ -201,7 +216,7 @@ void HipNonbondedUtilities::initialize(const System& system) {
vector<int2> exclusionTilesVec; vector<int2> exclusionTilesVec;
for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter) for (set<pair<int, int> >::const_iterator iter = tilesWithExclusions.begin(); iter != tilesWithExclusions.end(); ++iter)
exclusionTilesVec.push_back(make_int2(iter->first, iter->second)); exclusionTilesVec.push_back(make_int2(iter->first, iter->second));
sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), compareInt2); sort(exclusionTilesVec.begin(), exclusionTilesVec.end(), context.getSIMDWidth() <= 32 || !useNeighborList ? compareInt2 : compareInt2LargeSIMD);
exclusionTiles.initialize<int2>(context, exclusionTilesVec.size(), "exclusionTiles"); exclusionTiles.initialize<int2>(context, exclusionTilesVec.size(), "exclusionTiles");
exclusionTiles.upload(exclusionTilesVec); exclusionTiles.upload(exclusionTilesVec);
map<pair<int, int>, int> exclusionTileMap; map<pair<int, int>, int> exclusionTileMap;
...@@ -266,6 +281,8 @@ void HipNonbondedUtilities::initialize(const System& system) { ...@@ -266,6 +281,8 @@ void HipNonbondedUtilities::initialize(const System& system) {
if (maxTiles < 1) if (maxTiles < 1)
maxTiles = 1; maxTiles = 1;
maxSinglePairs = 5*numAtoms; maxSinglePairs = 5*numAtoms;
// HIP-TODO: This may require tuning
numTilesInBatch = numAtomBlocks < 2000 ? 4 : 1;
interactingTiles.initialize<int>(context, maxTiles, "interactingTiles"); interactingTiles.initialize<int>(context, maxTiles, "interactingTiles");
interactingAtoms.initialize<int>(context, HipContext::TileSize*maxTiles, "interactingAtoms"); interactingAtoms.initialize<int>(context, HipContext::TileSize*maxTiles, "interactingAtoms");
interactionCount.initialize<unsigned int>(context, 2, "interactionCount"); interactionCount.initialize<unsigned int>(context, 2, "interactionCount");
...@@ -393,10 +410,10 @@ void HipNonbondedUtilities::prepareInteractions(int forceGroups) { ...@@ -393,10 +410,10 @@ void HipNonbondedUtilities::prepareInteractions(int forceGroups) {
if (lastCutoff != kernels.cutoffDistance) if (lastCutoff != kernels.cutoffDistance)
forceRebuildNeighborList = true; forceRebuildNeighborList = true;
context.executeKernel(kernels.findBlockBoundsKernel, &findBlockBoundsArgs[0], context.getNumAtoms()); context.executeKernelFlat(kernels.findBlockBoundsKernel, &findBlockBoundsArgs[0], context.getPaddedNumAtoms(), context.getSIMDWidth());
blockSorter->sort(sortedBlocks); blockSorter->sort(sortedBlocks);
context.executeKernel(kernels.sortBoxDataKernel, &sortBoxDataArgs[0], context.getNumAtoms()); context.executeKernelFlat(kernels.sortBoxDataKernel, &sortBoxDataArgs[0], context.getNumAtoms(), 64);
context.executeKernel(kernels.findInteractingBlocksKernel, &findInteractingBlocksArgs[0], context.getNumAtoms(), 256); context.executeKernelFlat(kernels.findInteractingBlocksKernel, &findInteractingBlocksArgs[0], context.getNumAtomBlocks() * context.getSIMDWidth() * numTilesInBatch, findInteractingBlocksThreadBlockSize);
forceRebuildNeighborList = false; forceRebuildNeighborList = false;
lastCutoff = kernels.cutoffDistance; lastCutoff = kernels.cutoffDistance;
interactionCount.download(pinnedCountBuffer, false); interactionCount.download(pinnedCountBuffer, false);
...@@ -488,6 +505,7 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) { ...@@ -488,6 +505,7 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) {
defines["TILE_SIZE"] = context.intToString(HipContext::TileSize); defines["TILE_SIZE"] = context.intToString(HipContext::TileSize);
defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks()); defines["NUM_BLOCKS"] = context.intToString(context.getNumAtomBlocks());
defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms()); defines["NUM_ATOMS"] = context.intToString(context.getNumAtoms());
defines["PADDED_NUM_ATOMS"] = context.intToString(context.getPaddedNumAtoms());
defines["PADDING"] = context.doubleToString(paddedCutoff-cutoff); defines["PADDING"] = context.doubleToString(paddedCutoff-cutoff);
defines["PADDED_CUTOFF"] = context.doubleToString(paddedCutoff); defines["PADDED_CUTOFF"] = context.doubleToString(paddedCutoff);
defines["PADDED_CUTOFF_SQUARED"] = context.doubleToString(paddedCutoff*paddedCutoff); defines["PADDED_CUTOFF_SQUARED"] = context.doubleToString(paddedCutoff*paddedCutoff);
...@@ -497,7 +515,33 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) { ...@@ -497,7 +515,33 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) {
if (context.getBoxIsTriclinic()) if (context.getBoxIsTriclinic())
defines["TRICLINIC"] = "1"; defines["TRICLINIC"] = "1";
defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions); defines["MAX_EXCLUSIONS"] = context.intToString(maxExclusions);
defines["MAX_BITS_FOR_PAIRS"] = (canUsePairList ? "2" : "0"); int maxBits = 0;
if (canUsePairList) {
if (context.getUseDoublePrecision()) {
maxBits = 4;
}
else {
if (context.getSIMDWidth() > 32) {
// CDNA
if (context.getNumAtoms() < 100000)
maxBits = 4;
else // Large systems
maxBits = 0;
}
else {
// RDNA
if (context.getNumAtoms() < 100000)
maxBits = 4;
else if (context.getNumAtoms() < 500000)
maxBits = 2;
else // Very large systems
maxBits = 0;
}
}
}
defines["MAX_BITS_FOR_PAIRS"] = context.intToString(maxBits);
defines["NUM_TILES_IN_BATCH"] = context.intToString(numTilesInBatch);
defines["GROUP_SIZE"] = context.intToString(findInteractingBlocksThreadBlockSize);
hipModule_t interactingBlocksProgram = context.createModule(HipKernelSources::vectorOps+HipKernelSources::findInteractingBlocks, defines); hipModule_t interactingBlocksProgram = context.createModule(HipKernelSources::vectorOps+HipKernelSources::findInteractingBlocks, defines);
kernels.findBlockBoundsKernel = context.getKernel(interactingBlocksProgram, "findBlockBounds"); kernels.findBlockBoundsKernel = context.getKernel(interactingBlocksProgram, "findBlockBounds");
kernels.sortBoxDataKernel = context.getKernel(interactingBlocksProgram, "sortBoxData"); kernels.sortBoxDataKernel = context.getKernel(interactingBlocksProgram, "sortBoxData");
......
#define GROUP_SIZE 256
#define BUFFER_SIZE 256 #define BUFFER_SIZE 256
__device__ inline int __TILEFLAGS_FFS(unsigned long long int x) { #if defined(AMD_RDNA)
return __ffsll(x);
} typedef unsigned int warpflags;
__device__ inline int __TILEFLAGS_CLZ(unsigned long long int x) {
return __clzll(x); __device__ inline int warpPopc(warpflags x) {
return __popc(x);
} }
__device__ inline int __TILEFLAGS_POPC(unsigned long long int x) {
#else
typedef unsigned long long warpflags;
__device__ inline int warpPopc(warpflags x) {
return __popcll(x); return __popcll(x);
} }
#endif
/** /**
* Find a bounding box for the atoms in each block. * Find a bounding box for the atoms in each block.
*/ */
extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ, extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
const real4* __restrict__ posq, real4* __restrict__ blockCenter, real4* __restrict__ blockBoundingBox, int* __restrict__ rebuildNeighborList, const real4* __restrict__ posq, real4* __restrict__ blockCenter, real4* __restrict__ blockBoundingBox, int* __restrict__ rebuildNeighborList,
real2* __restrict__ sortedBlocks) { real2* __restrict__ sortedBlocks) {
int index = blockIdx.x*blockDim.x+threadIdx.x; const int indexInTile = threadIdx.x%TILE_SIZE;
int base = index*TILE_SIZE; const int index = warpSize == TILE_SIZE ? blockIdx.x : (blockIdx.x*(warpSize/TILE_SIZE) + threadIdx.x/TILE_SIZE);
while (base < numAtoms) { const int base = index * TILE_SIZE;
real4 pos = posq[base]; if (base >= numAtoms)
#ifdef USE_PERIODIC return;
APPLY_PERIODIC_TO_POS(pos)
#endif real4 tPos = posq[base + indexInTile < numAtoms ? base + indexInTile : 0];
real4 minPos = pos;
real4 maxPos = pos;
int last = min(base+TILE_SIZE, numAtoms);
for (int i = base+1; i < last; i++) {
pos = posq[i];
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
real4 center = 0.5f*(maxPos+minPos); real4 pos;
APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center) pos.x = SHFL(tPos.x, 0);
#endif pos.y = SHFL(tPos.y, 0);
minPos = make_real4(min(minPos.x,pos.x), min(minPos.y,pos.y), min(minPos.z,pos.z), 0); pos.z = SHFL(tPos.z, 0);
maxPos = make_real4(max(maxPos.x,pos.x), max(maxPos.y,pos.y), max(maxPos.z,pos.z), 0); APPLY_PERIODIC_TO_POS(pos)
}
real4 blockSize = 0.5f*(maxPos-minPos); real4 minPos = pos;
real4 maxPos = pos;
for (int i = 1; i < TILE_SIZE; i++) {
pos.x = SHFL(tPos.x, i);
pos.y = SHFL(tPos.y, i);
pos.z = SHFL(tPos.z, i);
real4 center = 0.5f*(maxPos+minPos); real4 center = 0.5f*(maxPos+minPos);
center.w = 0; APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)
for (int i = base; i < last; i++) { minPos = make_real4(min(minPos.x,pos.x), min(minPos.y,pos.y), min(minPos.z,pos.z), 0);
pos = posq[i]; maxPos = make_real4(max(maxPos.x,pos.x), max(maxPos.y,pos.y), max(maxPos.z,pos.z), 0);
real4 delta = posq[i]-center; }
#else
real4 minPos = tPos;
real4 maxPos = tPos;
for (int i = TILE_SIZE >> 1; i > 0; i >>= 1) {
real4 tpos1, tpos2;
tpos1.x = __shfl_down(minPos.x, i, TILE_SIZE);
tpos1.y = __shfl_down(minPos.y, i, TILE_SIZE);
tpos1.z = __shfl_down(minPos.z, i, TILE_SIZE);
tpos2.x = __shfl_down(maxPos.x, i, TILE_SIZE);
tpos2.y = __shfl_down(maxPos.y, i, TILE_SIZE);
tpos2.z = __shfl_down(maxPos.z, i, TILE_SIZE);
minPos.x = min(minPos.x, tpos1.x);
minPos.y = min(minPos.y, tpos1.y);
minPos.z = min(minPos.z, tpos1.z);
maxPos.x = max(maxPos.x, tpos2.x);
maxPos.y = max(maxPos.y, tpos2.y);
maxPos.z = max(maxPos.z, tpos2.z);
}
minPos.x = SHFL(minPos.x, 0);
minPos.y = SHFL(minPos.y, 0);
minPos.z = SHFL(minPos.z, 0);
maxPos.x = SHFL(maxPos.x, 0);
maxPos.y = SHFL(maxPos.y, 0);
maxPos.z = SHFL(maxPos.z, 0);
#endif
real4 blockSize = 0.5f*(maxPos-minPos);
real4 center = 0.5f*(maxPos+minPos);
center.w = 0;
real4 delta = tPos - center;
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(delta) APPLY_PERIODIC_TO_DELTA(delta)
#endif #endif
center.w = max(center.w, delta.x*delta.x+delta.y*delta.y+delta.z*delta.z); real tdelta = delta.x*delta.x+delta.y*delta.y+delta.z*delta.z;
} real tcenter = max(center.w, tdelta);
center.w = sqrt(center.w); for (int i = TILE_SIZE >> 1; i > 0; i >>= 1) {
real t = __shfl_down(tcenter, i, TILE_SIZE);
tcenter = max(tcenter, t);
}
if (indexInTile == 0) {
center.w = SQRT(tcenter);
blockBoundingBox[index] = blockSize; blockBoundingBox[index] = blockSize;
blockCenter[index] = center; blockCenter[index] = center;
sortedBlocks[index] = make_real2(blockSize.x+blockSize.y+blockSize.z, index); // blockSize.x+blockSize.y+blockSize.z has a distribution that looks like a normal distribution.
index += blockDim.x*gridDim.x; // This causes HipSort's buckets to have very non-uniform sizes, so a few very long buckets are
base = index*TILE_SIZE; // sorted in global memory. -1/max(x, y, z) or -1/(x+y+z) have a "faster" distribution.
sortedBlocks[index] = make_real2(-RECIP(max(max(blockSize.x, blockSize.y), blockSize.z)), index);
} }
if (blockIdx.x == 0 && threadIdx.x == 0) if (blockIdx.x == 0 && threadIdx.x == 0)
rebuildNeighborList[0] = 0; rebuildNeighborList[0] = 0;
...@@ -65,7 +113,8 @@ extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, co ...@@ -65,7 +113,8 @@ extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, co
const real4* __restrict__ blockBoundingBox, real4* __restrict__ sortedBlockCenter, const real4* __restrict__ blockBoundingBox, real4* __restrict__ sortedBlockCenter,
real4* __restrict__ sortedBlockBoundingBox, const real4* __restrict__ posq, const real4* __restrict__ oldPositions, real4* __restrict__ sortedBlockBoundingBox, const real4* __restrict__ posq, const real4* __restrict__ oldPositions,
unsigned int* __restrict__ interactionCount, int* __restrict__ rebuildNeighborList, bool forceRebuild) { unsigned int* __restrict__ interactionCount, int* __restrict__ rebuildNeighborList, bool forceRebuild) {
for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_BLOCKS; i += blockDim.x*gridDim.x) { int i = threadIdx.x+blockIdx.x*blockDim.x;
if (i < NUM_BLOCKS) {
int index = (int) sortedBlock[i].y; int index = (int) sortedBlock[i].y;
sortedBlockCenter[i] = blockCenter[index]; sortedBlockCenter[i] = blockCenter[index];
sortedBlockBoundingBox[i] = blockBoundingBox[index]; sortedBlockBoundingBox[i] = blockBoundingBox[index];
...@@ -74,11 +123,12 @@ extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, co ...@@ -74,11 +123,12 @@ extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, co
// Also check whether any atom has moved enough so that we really need to rebuild the neighbor list. // Also check whether any atom has moved enough so that we really need to rebuild the neighbor list.
bool rebuild = forceRebuild; bool rebuild = forceRebuild;
for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_ATOMS; i += blockDim.x*gridDim.x) { if (i < NUM_ATOMS) {
real4 delta = oldPositions[i]-posq[i]; real4 delta = oldPositions[i]-posq[i];
if (delta.x*delta.x + delta.y*delta.y + delta.z*delta.z > 0.25f*PADDING*PADDING) if (delta.x*delta.x + delta.y*delta.y + delta.z*delta.z > 0.25f*PADDING*PADDING)
rebuild = true; rebuild = true;
} }
if (rebuild) { if (rebuild) {
rebuildNeighborList[0] = 1; rebuildNeighborList[0] = 1;
interactionCount[0] = 0; interactionCount[0] = 0;
...@@ -86,59 +136,59 @@ extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, co ...@@ -86,59 +136,59 @@ extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, co
} }
} }
__device__ int saveSinglePairs(int x, int* atoms, tileflags* flags, int length, unsigned int maxSinglePairs, unsigned int* singlePairCount, int2* singlePairs, int* sumBuffer, volatile int& pairStartIndex) { #if MAX_BITS_FOR_PAIRS > 0
// Record interactions that should be computed as single pairs rather than in blocks.
const int indexInWarp = threadIdx.x%warpSize; __device__ inline
int sum = 0; void collectInteractions(unsigned int& interacts, float d, int bit) {
for (int i = indexInWarp; i < length; i += warpSize) { interacts |= (__float_as_uint(d) >> 31) << bit;
int count = __TILEFLAGS_POPC(flags[i]); }
sum += (count <= MAX_BITS_FOR_PAIRS ? count : 0);
}
for (int i = 1; i < warpSize; i *= 2) {
int n = SHFL(sum, indexInWarp - i);
if (indexInWarp >= i)
sum += n;
}
if (indexInWarp == warpSize - 1)
pairStartIndex = atomicAdd(singlePairCount,(unsigned int) sum);
SYNC_WARPS;
int prevSum = SHFL(sum, indexInWarp - 1);
int pairIndex = pairStartIndex + (indexInWarp > 0 ? prevSum : 0);
for (int i = indexInWarp; i < length; i += warpSize) {
int count = __TILEFLAGS_POPC(flags[i]);
if (count <= MAX_BITS_FOR_PAIRS && pairIndex+count <= maxSinglePairs) {
tileflags f = flags[i];
while (f != 0) {
singlePairs[pairIndex] = make_int2(atoms[i], x*TILE_SIZE+__TILEFLAGS_FFS(f)-1);
f &= f-1;
pairIndex++;
}
}
}
// Compact the remaining interactions. __device__ inline
void collectInteractions(unsigned int& interacts, double d, int bit) {
const tileflags warpMask = (static_cast<tileflags>(1)<<indexInWarp)-1; interacts |= ((unsigned int)__double2hiint(d) >> 31) << bit;
int numCompacted = 0; }
for (int start = 0; start < length; start += warpSize) {
int i = start+indexInWarp; #else
int atom = atoms[i];
tileflags flag = flags[i]; // Simplified version that does not collect individual bits (they are not needed without single pairs),
bool include = (i < length && __TILEFLAGS_POPC(flags[i]) > MAX_BITS_FOR_PAIRS); // only sets the flag that there are any interactions.
tileflags includeFlags = BALLOT(include);
if (include) { __device__ inline
int index = numCompacted+__TILEFLAGS_POPC(includeFlags&warpMask); void collectInteractions(unsigned int& interacts, float d, int) {
atoms[index] = atom; interacts |= __float_as_uint(d) & (1 << 31);
flags[index] = flag; }
}
numCompacted += __TILEFLAGS_POPC(includeFlags); __device__ inline
void collectInteractions(unsigned int& interacts, double d, int) {
interacts |= (unsigned int)__double2hiint(d) & (1 << 31);
}
#endif
#if !defined(USE_DOUBLE_PRECISION) && __has_builtin(__builtin_amdgcn_mfma_f32_4x4x1f32)
#define USE_MFMA
using vfloat = __attribute__((__vector_size__(4 * sizeof(float)))) float;
template<int BlockId>
inline __device__
void mfma4x4(const float4& pos1, const float4& pos2, const vfloat& c, unsigned int& interacts) {
vfloat d;
d = __builtin_amdgcn_mfma_f32_4x4x1f32(pos1.x, -pos2.x, c, 4, BlockId, 0);
d = __builtin_amdgcn_mfma_f32_4x4x1f32(pos1.y, -pos2.y, d, 4, BlockId, 0);
d = __builtin_amdgcn_mfma_f32_4x4x1f32(pos1.z, -pos2.z, d, 4, BlockId, 0);
d = __builtin_amdgcn_mfma_f32_4x4x1f32(pos1.w, 1.0f, d, 4, BlockId, 0);
#pragma unroll
for (int i = 0; i < 4; i++) {
collectInteractions(interacts, d[i], BlockId * 4 + i);
} }
return numCompacted;
} }
#endif
/** /**
* Compare the bounding boxes for each pair of atom blocks (comprised of warpSize atoms each), forming a tile. If the two * Compare the bounding boxes for each pair of atom blocks (comprised of TILE_SIZE atoms each), forming a tile. If the two
* atom blocks are sufficiently far apart, mark them as non-interacting. There are two stages in the algorithm. * atom blocks are sufficiently far apart, mark them as non-interacting. There are two stages in the algorithm.
* *
* STAGE 1: * STAGE 1:
...@@ -155,7 +205,7 @@ __device__ int saveSinglePairs(int x, int* atoms, tileflags* flags, int length, ...@@ -155,7 +205,7 @@ __device__ int saveSinglePairs(int x, int* atoms, tileflags* flags, int length,
* A fine grained atom block against interacting atoms neighbour list is constructed. * A fine grained atom block against interacting atoms neighbour list is constructed.
* *
* The warp loops over atom blocks Y that were found to (possibly) interact with atom block X. Each thread * The warp loops over atom blocks Y that were found to (possibly) interact with atom block X. Each thread
* in the warp loops over the warpSize atoms in X and compares their positions to one particular atom from block Y. * in the warp loops over the TILE_SIZE atoms in X and compares their positions to one particular atom from block Y.
* If it finds one closer than the cutoff distance, the atom is added to the list of atoms interacting with block X. * If it finds one closer than the cutoff distance, the atom is added to the list of atoms interacting with block X.
* This continues until the buffer fills up, at which point the results are written to global memory. * This continues until the buffer fills up, at which point the results are written to global memory.
* *
...@@ -186,7 +236,7 @@ __device__ int saveSinglePairs(int x, int* atoms, tileflags* flags, int length, ...@@ -186,7 +236,7 @@ __device__ int saveSinglePairs(int x, int* atoms, tileflags* flags, int length,
* [in] rebuildNeighbourList - whether or not to execute this kernel * [in] rebuildNeighbourList - whether or not to execute this kernel
* *
*/ */
extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ, extern "C" __global__ __launch_bounds__(GROUP_SIZE) void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
unsigned int* __restrict__ interactionCount, int* __restrict__ interactingTiles, unsigned int* __restrict__ interactingAtoms, unsigned int* __restrict__ interactionCount, int* __restrict__ interactingTiles, unsigned int* __restrict__ interactingAtoms,
int2* __restrict__ singlePairs, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int maxSinglePairs, int2* __restrict__ singlePairs, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int maxSinglePairs,
unsigned int startBlockIndex, unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter, unsigned int startBlockIndex, unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter,
...@@ -196,27 +246,29 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea ...@@ -196,27 +246,29 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea
if (rebuildNeighborList[0] == 0) if (rebuildNeighborList[0] == 0)
return; // The neighbor list doesn't need to be rebuilt. return; // The neighbor list doesn't need to be rebuilt.
constexpr int tilesPerWarp = warpSize/TILE_SIZE;
constexpr int warpsPerBlock = GROUP_SIZE/warpSize;
const int indexInWarp = threadIdx.x%warpSize; const int indexInWarp = threadIdx.x%warpSize;
const int warpStart = threadIdx.x-indexInWarp; const int indexInTile = threadIdx.x%TILE_SIZE;
const int totalWarps = blockDim.x*gridDim.x/warpSize; const int tileInWarp = tilesPerWarp == 1 ? 0 : (threadIdx.x/TILE_SIZE)%tilesPerWarp;
const int warpIndex = (blockIdx.x*blockDim.x+threadIdx.x)/warpSize; const int warpInBlock = warpsPerBlock == 1 ? 0 : threadIdx.x/warpSize;
const tileflags warpMask = (static_cast<tileflags>(1)<<indexInWarp)-1; const int warpIndex = blockIdx.x*warpsPerBlock + (warpsPerBlock == 1 ? 0 : warpInBlock);
__shared__ int workgroupBuffer[BUFFER_SIZE*(GROUP_SIZE/warpSize)]; const warpflags warpMask = (static_cast<warpflags>(1)<<indexInWarp)-1;
__shared__ tileflags workgroupFlagsBuffer[BUFFER_SIZE*(GROUP_SIZE/warpSize)];
__shared__ int warpExclusions[MAX_EXCLUSIONS*(GROUP_SIZE/warpSize)]; __shared__ int workgroupBuffer[BUFFER_SIZE*warpsPerBlock];
__shared__ real4 posBuffer[GROUP_SIZE]; __shared__ real4 workgroupPosBuffer[TILE_SIZE*warpsPerBlock];
__shared__ volatile int workgroupTileIndex[GROUP_SIZE/warpSize]; __shared__ int workgroupExclusions[MAX_EXCLUSIONS*warpsPerBlock];
__shared__ int worksgroupPairStartIndex[GROUP_SIZE/warpSize]; __shared__ int workgroupBlock2Buffer[(warpSize+1)*warpsPerBlock];
int* sumBuffer = (int*) posBuffer; // Reuse the same buffer to save memory
int* buffer = workgroupBuffer+BUFFER_SIZE*(warpStart/warpSize); int* buffer = workgroupBuffer+BUFFER_SIZE*warpInBlock;
tileflags* flagsBuffer = workgroupFlagsBuffer+BUFFER_SIZE*(warpStart/warpSize); real4* posBuffer = workgroupPosBuffer+TILE_SIZE*warpInBlock;
int* exclusionsForX = warpExclusions+MAX_EXCLUSIONS*(warpStart/warpSize); int* exclusionsForX = workgroupExclusions+MAX_EXCLUSIONS*warpInBlock;
volatile int& tileStartIndex = workgroupTileIndex[warpStart/warpSize]; int* block2Buffer = workgroupBlock2Buffer+(warpSize+1)*warpInBlock;
volatile int& pairStartIndex = worksgroupPairStartIndex[warpStart/warpSize];
// Loop over blocks. // Loop over blocks.
for (int block1 = startBlockIndex+warpIndex; block1 < startBlockIndex+numBlocks; block1 += totalWarps) { int block1 = startBlockIndex+warpIndex/NUM_TILES_IN_BATCH;
if (block1 < startBlockIndex+numBlocks) {
// Load data for this block. Note that all threads in a warp are processing the same block. // Load data for this block. Note that all threads in a warp are processing the same block.
real2 sortedKey = sortedBlocks[block1]; real2 sortedKey = sortedBlocks[block1];
...@@ -224,7 +276,7 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea ...@@ -224,7 +276,7 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea
real4 blockCenterX = sortedBlockCenter[block1]; real4 blockCenterX = sortedBlockCenter[block1];
real4 blockSizeX = sortedBlockBoundingBox[block1]; real4 blockSizeX = sortedBlockBoundingBox[block1];
int neighborsInBuffer = 0; int neighborsInBuffer = 0;
real4 pos1 = posq[x*TILE_SIZE+indexInWarp]; real4 pos1 = posq[x*TILE_SIZE+indexInTile];
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
const bool singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= PADDED_CUTOFF && const bool singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= PADDED_CUTOFF &&
0.5f*periodicBoxSize.y-blockSizeX.y >= PADDED_CUTOFF && 0.5f*periodicBoxSize.y-blockSizeX.y >= PADDED_CUTOFF &&
...@@ -237,64 +289,86 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea ...@@ -237,64 +289,86 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea
} }
#endif #endif
pos1.w = 0.5f * (pos1.x * pos1.x + pos1.y * pos1.y + pos1.z * pos1.z); pos1.w = 0.5f * (pos1.x * pos1.x + pos1.y * pos1.y + pos1.z * pos1.z);
posBuffer[threadIdx.x] = pos1; if (tileInWarp == 0) {
posBuffer[indexInTile] = pos1;
}
// Load exclusion data for block x. // Load exclusion data for block x.
const int exclusionStart = exclusionRowIndices[x]; const int exclusionStart = exclusionRowIndices[x];
const int exclusionEnd = exclusionRowIndices[x+1]; const int exclusionEnd = exclusionRowIndices[x+1];
const int numExclusions = exclusionEnd-exclusionStart; const int numExclusions = exclusionEnd-exclusionStart;
#pragma unroll 1
for (int j = indexInWarp; j < numExclusions; j += warpSize) for (int j = indexInWarp; j < numExclusions; j += warpSize)
exclusionsForX[j] = exclusionIndices[exclusionStart+j]; exclusionsForX[j] = exclusionIndices[exclusionStart+j];
if (MAX_EXCLUSIONS > warpSize)
__syncthreads();
// Loop over atom blocks to search for neighbors. The threads in a warp compare block1 against warpSize // Loop over atom blocks to search for neighbors. The threads in a warp compare block1 against warpSize
// other blocks in parallel. // other blocks in parallel.
// For small systems multiple warps (NUM_TILES_IN_BATCH = 4, 2...) process one block1 reducing the overall
for (int block2Base = block1+1; block2Base < NUM_BLOCKS; block2Base += warpSize) { // duration of the kernel because first blocks block1 have to process more block2 blocks so most of compute
// units are idle at the end of the kernel (the kernel works on the upper triangle of
// the NUM_BLOCKS x NUM_BLOCKS matrix).
int block2Count = 0;
// Load blocks from addresses aligned by warpSize for faster loading from sortedBlockCenter and sortedBlockBoundingBox.
for (int block2Base = ((block1+1)/warpSize + warpIndex%NUM_TILES_IN_BATCH)*warpSize; block2Base < NUM_BLOCKS; block2Base += warpSize*NUM_TILES_IN_BATCH) {
const bool lastIteration = block2Base + warpSize*NUM_TILES_IN_BATCH >= NUM_BLOCKS;
int block2 = block2Base+indexInWarp; int block2 = block2Base+indexInWarp;
bool includeBlock2 = (block2 < NUM_BLOCKS); bool includeBlock2 = (block1 < block2 && block2 < NUM_BLOCKS);
block2 = includeBlock2 ? block2 : block1;
bool forceInclude = false; bool forceInclude = false;
if (includeBlock2) { real4 blockCenterY = sortedBlockCenter[block2];
real4 blockCenterY = sortedBlockCenter[block2]; real4 blockDelta = blockCenterX-blockCenterY;
real4 blockSizeY = sortedBlockBoundingBox[block2];
real4 blockDelta = blockCenterX-blockCenterY;
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(blockDelta) APPLY_PERIODIC_TO_DELTA(blockDelta)
#endif
includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < (PADDED_CUTOFF+blockCenterX.w+blockCenterY.w)*(PADDED_CUTOFF+blockCenterX.w+blockCenterY.w));
#ifndef TRICLINIC
if (!lastIteration && __ballot(includeBlock2) == 0)
continue;
#endif #endif
includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < (PADDED_CUTOFF+blockCenterX.w+blockCenterY.w)*(PADDED_CUTOFF+blockCenterX.w+blockCenterY.w)); real4 blockSizeY = sortedBlockBoundingBox[block2];
blockDelta.x = max(0.0f, fabs(blockDelta.x)-blockSizeX.x-blockSizeY.x); blockDelta.x = max(0.0f, fabs(blockDelta.x)-blockSizeX.x-blockSizeY.x);
blockDelta.y = max(0.0f, fabs(blockDelta.y)-blockSizeX.y-blockSizeY.y); blockDelta.y = max(0.0f, fabs(blockDelta.y)-blockSizeX.y-blockSizeY.y);
blockDelta.z = max(0.0f, fabs(blockDelta.z)-blockSizeX.z-blockSizeY.z); blockDelta.z = max(0.0f, fabs(blockDelta.z)-blockSizeX.z-blockSizeY.z);
includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < PADDED_CUTOFF_SQUARED); includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < PADDED_CUTOFF_SQUARED);
#ifdef TRICLINIC #ifdef TRICLINIC
// The calculation to find the nearest periodic copy is only guaranteed to work if the nearest copy is less than half a box width away. // The calculation to find the nearest periodic copy is only guaranteed to work if the nearest copy is less than half a box width away.
// If there's any possibility we might have missed it, do a detailed check. // If there's any possibility we might have missed it, do a detailed check.
if (periodicBoxSize.z/2-blockSizeX.z-blockSizeY.z < PADDED_CUTOFF || periodicBoxSize.y/2-blockSizeX.y-blockSizeY.y < PADDED_CUTOFF) if (periodicBoxSize.z/2-blockSizeX.z-blockSizeY.z < PADDED_CUTOFF || periodicBoxSize.y/2-blockSizeX.y-blockSizeY.y < PADDED_CUTOFF)
includeBlock2 = forceInclude = true; includeBlock2 = forceInclude = true;
#endif #endif
if (includeBlock2) {
int y = (int) sortedBlocks[block2].y; // Collect any blocks we identified as potentially containing neighbors.
for (int k = 0; k < numExclusions; k++)
includeBlock2 &= (exclusionsForX[k] != y); warpflags includeBlockFlags = __ballot(includeBlock2);
} if (includeBlock2) {
int index = block2Count + warpPopc(includeBlockFlags&warpMask);
block2Buffer[index] = (block2 << 1) | (forceInclude ? 1 : 0);
} }
block2Count += warpPopc(includeBlockFlags);
// Loop over the collected candidates (each warp processes 2 blocks on CDNA or 1 block on RDNA).
// Process even number of blocks on CDNA so both half-warps have work to do (except for
// the last iteration of the for-block2Base loop when the left-over must be processed).
// Loop over any blocks we identified as potentially containing neighbors. const int block2ToProcess = lastIteration ? block2Count : block2Count/tilesPerWarp*tilesPerWarp;
for (int block2Index = 0; block2Index < block2ToProcess; block2Index += tilesPerWarp) {
bool includeBlock2 = block2Index + tileInWarp < block2Count;
const int b = block2Buffer[min(block2Index + tileInWarp, block2Count - 1)];
const bool forceInclude = b & 1;
const int block2 = b >> 1;
int y = (int) sortedBlocks[block2].y;
tileflags includeBlockFlags = BALLOT(includeBlock2); #pragma unroll 1
tileflags forceIncludeFlags = BALLOT(forceInclude); for (int k = indexInTile; k < numExclusions; k += TILE_SIZE)
while (includeBlockFlags != 0) { includeBlock2 &= (exclusionsForX[k] != y);
int i = __TILEFLAGS_FFS(includeBlockFlags)-1; includeBlock2 = BALLOT(!includeBlock2) == 0;
includeBlockFlags &= includeBlockFlags-1;
forceInclude = (forceIncludeFlags>>i) & 1;
int y = (int) sortedBlocks[block2Base+i].y;
// Check each atom in block Y for interactions. // Check each atom in block Y for interactions.
int atom2 = y*TILE_SIZE+indexInWarp; int atom2 = y*TILE_SIZE+indexInTile;
real4 pos2 = posq[atom2]; real4 pos2 = posq[atom2];
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
if (singlePeriodicCopy) { if (singlePeriodicCopy) {
...@@ -302,92 +376,145 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea ...@@ -302,92 +376,145 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea
} }
#endif #endif
pos2.w = 0.5f * (pos2.x * pos2.x + pos2.y * pos2.y + pos2.z * pos2.z); pos2.w = 0.5f * (pos2.x * pos2.x + pos2.y * pos2.y + pos2.z * pos2.z);
real4 blockCenterY = sortedBlockCenter[block2Base+i]; real4 blockCenterY = sortedBlockCenter[block2];
real3 atomDelta = trimTo3(posBuffer[warpStart+indexInWarp])-trimTo3(blockCenterY); real3 atomDelta = trimTo3(pos1)-trimTo3(blockCenterY);
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(atomDelta) APPLY_PERIODIC_TO_DELTA(atomDelta)
#endif #endif
tileflags atomFlags = BALLOT(forceInclude || atomDelta.x*atomDelta.x+atomDelta.y*atomDelta.y+atomDelta.z*atomDelta.z < (PADDED_CUTOFF+blockCenterY.w)*(PADDED_CUTOFF+blockCenterY.w)); tileflags atomFlags = BALLOT(forceInclude || atomDelta.x*atomDelta.x+atomDelta.y*atomDelta.y+atomDelta.z*atomDelta.z < (PADDED_CUTOFF+blockCenterY.w)*(PADDED_CUTOFF+blockCenterY.w));
tileflags interacts = 0; tileflags interacts = 0;
if (atom2 < NUM_ATOMS && atomFlags != 0) { // The condition `posj.w + pos2.w - posj.x*pos2.x - posj.y*pos2.y - posj.z*pos2.z < 0.5f * PADDED_CUTOFF_SQUARED` is expressed as
int first = __TILEFLAGS_FFS(atomFlags)-1; // `posj.x*pos2.x - posj.y*pos2.y - posj.z*pos2.z - posj.w - 0.5f * PADDED_CUTOFF_SQUARED - pos2.w` and computed using fma
int last = warpSize-__TILEFLAGS_CLZ(atomFlags); // (it saves 1 instruction).
// Sign bit is used directly instead of `halfDist2 < 0.5f * PADDED_CUTOFF_SQUARED ? 1<<j : 0`.
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
if (!singlePeriodicCopy) { if (!singlePeriodicCopy) {
for (int j = first; j < last; j++) { while (atomFlags) {
real3 delta = trimTo3(pos2)-trimTo3(posBuffer[warpStart+j]); int j = __ffs(atomFlags)-1;
APPLY_PERIODIC_TO_DELTA(delta) atomFlags = atomFlags ^ (static_cast<tileflags>(1) << j);
interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? static_cast<tileflags>(1)<<j : 0); real3 delta = trimTo3(pos2)-trimTo3(posBuffer[j]);
} APPLY_PERIODIC_TO_DELTA(delta)
real d = delta.x*delta.x+delta.y*delta.y+delta.z*delta.z - PADDED_CUTOFF_SQUARED;
collectInteractions(interacts, d, j);
} }
else { }
else {
#endif #endif
#pragma unroll const real lim = 0.5f * PADDED_CUTOFF_SQUARED - pos2.w;
for (int j = 0; j < warpSize; j++) { #if defined(USE_MFMA)
real4 posj = posBuffer[warpStart+j]; const vfloat c = { -lim, -lim, -lim, -lim };
real halfDist2 = posj.w + pos2.w - posj.x*pos2.x - posj.y*pos2.y - posj.z*pos2.z; mfma4x4<0>(pos1, pos2, c, interacts);
interacts |= (halfDist2 < 0.5f * PADDED_CUTOFF_SQUARED ? static_cast<tileflags>(1)<<j : 0); mfma4x4<1>(pos1, pos2, c, interacts);
} mfma4x4<2>(pos1, pos2, c, interacts);
#ifdef USE_PERIODIC mfma4x4<3>(pos1, pos2, c, interacts);
mfma4x4<4>(pos1, pos2, c, interacts);
mfma4x4<5>(pos1, pos2, c, interacts);
mfma4x4<6>(pos1, pos2, c, interacts);
mfma4x4<7>(pos1, pos2, c, interacts);
#else
while (atomFlags) {
int j = __ffs(atomFlags)-1;
atomFlags = atomFlags ^ (static_cast<tileflags>(1) << j);
real4 posj = posBuffer[j];
real d = fma(-posj.x, pos2.x, fma(-posj.y, pos2.y, fma(-posj.z, pos2.z, posj.w - lim)));
collectInteractions(interacts, d, j);
} }
#endif #endif
#ifdef USE_PERIODIC
}
#endif
if (atom2 >= NUM_ATOMS || !includeBlock2) {
interacts = 0;
}
#if MAX_BITS_FOR_PAIRS > 0
const unsigned int interactCount = __popc(interacts);
// Record interactions that should be computed as single pairs rather than in blocks.
const bool storeAsSinglePair = interactCount > 0 && interactCount <= MAX_BITS_FOR_PAIRS;
if (__ballot(storeAsSinglePair)) {
unsigned int sum = 0;
unsigned int prevSum = 0;
for (int i = 1; i <= MAX_BITS_FOR_PAIRS; i++) {
warpflags b = __ballot(interactCount == i);
sum += warpPopc(b) * i;
prevSum += warpPopc(b&warpMask) * i;
}
unsigned int pairStartIndex = 0;
if (indexInWarp == 0)
pairStartIndex = atomicAdd(&interactionCount[1], sum);
pairStartIndex = __shfl(pairStartIndex, 0);
unsigned int pairIndex = pairStartIndex + prevSum;
if (storeAsSinglePair && pairIndex+interactCount <= maxSinglePairs) {
while (interacts != 0) {
int j = __ffs(interacts)-1;
singlePairs[pairIndex] = make_int2(atom2, x*TILE_SIZE+j);
interacts = interacts ^ (static_cast<tileflags>(1) << j);
pairIndex++;
}
}
} }
#else
const unsigned int interactCount = interacts;
#endif
// Add any interacting atoms to the buffer. // Add any interacting atoms to the buffer.
tileflags includeAtomFlags = BALLOT(interacts); warpflags includeAtomFlags = __ballot(interactCount > MAX_BITS_FOR_PAIRS);
if (interacts) { if (interactCount > MAX_BITS_FOR_PAIRS) {
int index = neighborsInBuffer+__TILEFLAGS_POPC(includeAtomFlags&warpMask); int index = neighborsInBuffer+warpPopc(includeAtomFlags&warpMask);
buffer[index] = atom2; buffer[index] = atom2;
flagsBuffer[index] = interacts;
} }
neighborsInBuffer += __TILEFLAGS_POPC(includeAtomFlags); neighborsInBuffer += warpPopc(includeAtomFlags);
if (neighborsInBuffer > BUFFER_SIZE-TILE_SIZE) { if (neighborsInBuffer > BUFFER_SIZE-warpSize) {
// Store the new tiles to memory. // Store the new tiles to memory.
#if MAX_BITS_FOR_PAIRS > 0 unsigned int tilesToStore = neighborsInBuffer/warpSize*tilesPerWarp;
neighborsInBuffer = saveSinglePairs(x, buffer, flagsBuffer, neighborsInBuffer, maxSinglePairs, &interactionCount[1], singlePairs, sumBuffer+warpStart, pairStartIndex); unsigned int tileStartIndex = 0;
#endif if (indexInWarp == 0)
int tilesToStore = neighborsInBuffer/TILE_SIZE; tileStartIndex = atomicAdd(&interactionCount[0], tilesToStore);
if (tilesToStore > 0) { unsigned int newTileStartIndex = __shfl(tileStartIndex, 0);
if (indexInWarp == 0) if (newTileStartIndex+tilesToStore <= maxTiles) {
tileStartIndex = atomicAdd(&interactionCount[0], tilesToStore); if (indexInWarp < tilesToStore)
int newTileStartIndex = tileStartIndex; interactingTiles[newTileStartIndex+indexInWarp] = x;
if (newTileStartIndex+tilesToStore <= maxTiles) { for (int j = 0; j < tilesToStore/tilesPerWarp; j++)
if (indexInWarp < tilesToStore) interactingAtoms[newTileStartIndex*TILE_SIZE+j*warpSize+indexInWarp] = buffer[j*warpSize+indexInWarp];
interactingTiles[newTileStartIndex+indexInWarp] = x;
for (int j = 0; j < tilesToStore; j++)
interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = buffer[indexInWarp+j*TILE_SIZE];
}
buffer[indexInWarp] = buffer[indexInWarp+TILE_SIZE*tilesToStore];
neighborsInBuffer -= TILE_SIZE*tilesToStore;
} }
if (indexInWarp+TILE_SIZE*tilesToStore < BUFFER_SIZE)
buffer[indexInWarp] = buffer[indexInWarp+TILE_SIZE*tilesToStore];
neighborsInBuffer -= TILE_SIZE*tilesToStore;
} }
} }
// Move not processed blocks to the head of block2Buffer.
if (indexInWarp < block2Count - block2ToProcess)
block2Buffer[indexInWarp] = block2Buffer[block2ToProcess + indexInWarp];
block2Count = block2Count - block2ToProcess;
} }
// If we have a partially filled buffer, store it to memory. // If we have a partially filled buffer, store it to memory.
#if MAX_BITS_FOR_PAIRS > 0
if (neighborsInBuffer > warpSize)
neighborsInBuffer = saveSinglePairs(x, buffer, flagsBuffer, neighborsInBuffer, maxSinglePairs, &interactionCount[1], singlePairs, sumBuffer+warpStart, pairStartIndex);
#endif
if (neighborsInBuffer > 0) { if (neighborsInBuffer > 0) {
int tilesToStore = (neighborsInBuffer+TILE_SIZE-1)/TILE_SIZE; unsigned int tilesToStore = (neighborsInBuffer+TILE_SIZE-1)/TILE_SIZE;
unsigned int tileStartIndex = 0;
if (indexInWarp == 0) if (indexInWarp == 0)
tileStartIndex = atomicAdd(&interactionCount[0], tilesToStore); tileStartIndex = atomicAdd(&interactionCount[0], tilesToStore);
int newTileStartIndex = tileStartIndex; unsigned int newTileStartIndex = __shfl(tileStartIndex, 0);
if (newTileStartIndex+tilesToStore <= maxTiles) { if (newTileStartIndex+tilesToStore <= maxTiles) {
if (indexInWarp < tilesToStore) if (indexInWarp < tilesToStore)
interactingTiles[newTileStartIndex+indexInWarp] = x; interactingTiles[newTileStartIndex+indexInWarp] = x;
for (int j = 0; j < tilesToStore; j++) for (int j = 0; j <= tilesToStore/tilesPerWarp; j++) {
interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = (indexInWarp+j*TILE_SIZE < neighborsInBuffer ? buffer[indexInWarp+j*TILE_SIZE] : NUM_ATOMS); if (j*warpSize+indexInWarp < tilesToStore*TILE_SIZE)
interactingAtoms[newTileStartIndex*TILE_SIZE+j*warpSize+indexInWarp] = (j*warpSize+indexInWarp < neighborsInBuffer ? buffer[j*warpSize+indexInWarp] : PADDED_NUM_ATOMS);
}
} }
} }
} }
// Record the positions the neighbor list is based on. // Record the positions the neighbor list is based on.
for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_ATOMS; i += blockDim.x*gridDim.x) int i = threadIdx.x+blockIdx.x*GROUP_SIZE;
if (i < NUM_ATOMS) {
oldPositions[i] = posq[i]; oldPositions[i] = posq[i];
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment