Unverified commit ffcabcf6, authored by David Clark, committed by GitHub
Browse files

Implements findBlocksWithInteractions with matrix multiplication (#2989)



* Frames distance calculation as matrix multiplication

* Adds comment explaining distance calculation

* Tunes launch bound for cuda11.2

* Simplifies the effective matrix multiplication
Co-authored-by: David Clark <daclark@nvidia.com>
parent 909125ff
......@@ -81,6 +81,7 @@ __device__ int saveSinglePairs(int x, int* atoms, int* flags, int length, unsign
const int indexInWarp = threadIdx.x%32;
int sum = 0;
#pragma unroll 8 // (GROUP_SIZE / TILE_SIZE)
for (int i = indexInWarp; i < length; i += 32) {
int count = __popc(flags[i]);
sum += (count <= MAX_BITS_FOR_PAIRS ? count : 0);
......@@ -176,7 +177,7 @@ __device__ int saveSinglePairs(int x, int* atoms, int* flags, int length, unsign
* [in] rebuildNeighbourList - whether or not to execute this kernel
*
*/
extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
extern "C" __global__ __launch_bounds__(GROUP_SIZE,3) void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
unsigned int* __restrict__ interactionCount, int* __restrict__ interactingTiles, unsigned int* __restrict__ interactingAtoms,
int2* __restrict__ singlePairs, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int maxSinglePairs,
unsigned int startBlockIndex, unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter,
......@@ -194,7 +195,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
__shared__ int workgroupBuffer[BUFFER_SIZE*(GROUP_SIZE/32)];
__shared__ int workgroupFlagsBuffer[BUFFER_SIZE*(GROUP_SIZE/32)];
__shared__ int warpExclusions[MAX_EXCLUSIONS*(GROUP_SIZE/32)];
__shared__ real3 posBuffer[GROUP_SIZE];
__shared__ real4 posBuffer[GROUP_SIZE];
__shared__ volatile int workgroupTileIndex[GROUP_SIZE/32];
__shared__ int worksgroupPairStartIndex[GROUP_SIZE/32];
int* sumBuffer = (int*) posBuffer; // Reuse the same buffer to save memory
......@@ -214,7 +215,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
real4 blockCenterX = sortedBlockCenter[block1];
real4 blockSizeX = sortedBlockBoundingBox[block1];
int neighborsInBuffer = 0;
real3 pos1 = trimTo3(posq[x*TILE_SIZE+indexInWarp]);
real4 pos1 = posq[x*TILE_SIZE+indexInWarp];
#ifdef USE_PERIODIC
const bool singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= PADDED_CUTOFF &&
0.5f*periodicBoxSize.y-blockSizeX.y >= PADDED_CUTOFF &&
......@@ -226,6 +227,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
APPLY_PERIODIC_TO_POS_WITH_CENTER(pos1, blockCenterX)
}
#endif
pos1.w = 0.5f * (pos1.x * pos1.x + pos1.y * pos1.y + pos1.z * pos1.z);
posBuffer[threadIdx.x] = pos1;
// Load exclusion data for block x.
......@@ -233,6 +235,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
const int exclusionStart = exclusionRowIndices[x];
const int exclusionEnd = exclusionRowIndices[x+1];
const int numExclusions = exclusionEnd-exclusionStart;
#pragma unroll 4 // (MAX_EXCLUSIONS)
for (int j = indexInWarp; j < numExclusions; j += 32)
exclusionsForX[j] = exclusionIndices[exclusionStart+j];
if (MAX_EXCLUSIONS > 32)
......@@ -266,6 +269,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
#endif
if (includeBlock2) {
int y = (int) sortedBlocks[block2].y;
#pragma unroll 4 // (MAX_EXCLUSIONS)
for (int k = 0; k < numExclusions; k++)
includeBlock2 &= (exclusionsForX[k] != y);
}
......@@ -284,14 +288,16 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
// Check each atom in block Y for interactions.
int atom2 = y*TILE_SIZE+indexInWarp;
real3 pos2 = trimTo3(posq[atom2]);
real4 pos2 = posq[atom2];
#ifdef USE_PERIODIC
if (singlePeriodicCopy) {
APPLY_PERIODIC_TO_POS_WITH_CENTER(pos2, blockCenterX)
}
#endif
pos2.w = 0.5f * (pos2.x * pos2.x + pos2.y * pos2.y + pos2.z * pos2.z);
real4 blockCenterY = sortedBlockCenter[block2Base+i];
real3 atomDelta = posBuffer[warpStart+indexInWarp]-trimTo3(blockCenterY);
real3 atomDelta = trimTo3(posBuffer[warpStart+indexInWarp])-trimTo3(blockCenterY);
#ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(atomDelta)
#endif
......@@ -303,16 +309,18 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
#ifdef USE_PERIODIC
if (!singlePeriodicCopy) {
for (int j = first; j < last; j++) {
real3 delta = pos2-posBuffer[warpStart+j];
real3 delta = trimTo3(pos2)-trimTo3(posBuffer[warpStart+j]);
APPLY_PERIODIC_TO_DELTA(delta)
interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? 1<<j : 0);
}
}
else {
#endif
for (int j = first; j < last; j++) {
real3 delta = pos2-posBuffer[warpStart+j];
interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? 1<<j : 0);
#pragma unroll
for (int j = 0; j < 32; j++) {
real4 posj = posBuffer[warpStart+j];
real halfDist2 = posj.w + pos2.w - posj.x*pos2.x - posj.y*pos2.y - posj.z*pos2.z;
interacts |= (halfDist2 < 0.5f * PADDED_CUTOFF_SQUARED ? 1<<j : 0);
}
#ifdef USE_PERIODIC
}
......@@ -342,6 +350,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
if (newTileStartIndex+tilesToStore <= maxTiles) {
if (indexInWarp < tilesToStore)
interactingTiles[newTileStartIndex+indexInWarp] = x;
#pragma unroll 8 // (GROUP_SIZE / TILE_SIZE)
for (int j = 0; j < tilesToStore; j++)
interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = buffer[indexInWarp+j*TILE_SIZE];
}
......@@ -367,6 +376,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
if (newTileStartIndex+tilesToStore <= maxTiles) {
if (indexInWarp < tilesToStore)
interactingTiles[newTileStartIndex+indexInWarp] = x;
#pragma unroll 8 // (GROUP_SIZE / TILE_SIZE)
for (int j = 0; j < tilesToStore; j++)
interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = (indexInWarp+j*TILE_SIZE < neighborsInBuffer ? buffer[indexInWarp+j*TILE_SIZE] : NUM_ATOMS);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment