Unverified commit ffcabcf6, authored by David Clark, committed by GitHub
Browse files

Implements findBlocksWithInteractions with matrix multiplication (#2989)



* Frames distance calculation as matrix multiplication

* Adds comment explaining distance calculation

* Tunes launch bound for cuda11.2

* Simplifies the effective matrix multiplication
Co-authored-by: David Clark <daclark@nvidia.com>
parent 909125ff
...@@ -81,6 +81,7 @@ __device__ int saveSinglePairs(int x, int* atoms, int* flags, int length, unsign ...@@ -81,6 +81,7 @@ __device__ int saveSinglePairs(int x, int* atoms, int* flags, int length, unsign
const int indexInWarp = threadIdx.x%32; const int indexInWarp = threadIdx.x%32;
int sum = 0; int sum = 0;
#pragma unroll 8 // (GROUP_SIZE / TILE_SIZE)
for (int i = indexInWarp; i < length; i += 32) { for (int i = indexInWarp; i < length; i += 32) {
int count = __popc(flags[i]); int count = __popc(flags[i]);
sum += (count <= MAX_BITS_FOR_PAIRS ? count : 0); sum += (count <= MAX_BITS_FOR_PAIRS ? count : 0);
...@@ -176,7 +177,7 @@ __device__ int saveSinglePairs(int x, int* atoms, int* flags, int length, unsign ...@@ -176,7 +177,7 @@ __device__ int saveSinglePairs(int x, int* atoms, int* flags, int length, unsign
* [in] rebuildNeighbourList - whether or not to execute this kernel * [in] rebuildNeighbourList - whether or not to execute this kernel
* *
*/ */
extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ, extern "C" __global__ __launch_bounds__(GROUP_SIZE,3) void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
unsigned int* __restrict__ interactionCount, int* __restrict__ interactingTiles, unsigned int* __restrict__ interactingAtoms, unsigned int* __restrict__ interactionCount, int* __restrict__ interactingTiles, unsigned int* __restrict__ interactingAtoms,
int2* __restrict__ singlePairs, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int maxSinglePairs, int2* __restrict__ singlePairs, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int maxSinglePairs,
unsigned int startBlockIndex, unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter, unsigned int startBlockIndex, unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter,
...@@ -194,7 +195,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac ...@@ -194,7 +195,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
__shared__ int workgroupBuffer[BUFFER_SIZE*(GROUP_SIZE/32)]; __shared__ int workgroupBuffer[BUFFER_SIZE*(GROUP_SIZE/32)];
__shared__ int workgroupFlagsBuffer[BUFFER_SIZE*(GROUP_SIZE/32)]; __shared__ int workgroupFlagsBuffer[BUFFER_SIZE*(GROUP_SIZE/32)];
__shared__ int warpExclusions[MAX_EXCLUSIONS*(GROUP_SIZE/32)]; __shared__ int warpExclusions[MAX_EXCLUSIONS*(GROUP_SIZE/32)];
__shared__ real3 posBuffer[GROUP_SIZE]; __shared__ real4 posBuffer[GROUP_SIZE];
__shared__ volatile int workgroupTileIndex[GROUP_SIZE/32]; __shared__ volatile int workgroupTileIndex[GROUP_SIZE/32];
__shared__ int worksgroupPairStartIndex[GROUP_SIZE/32]; __shared__ int worksgroupPairStartIndex[GROUP_SIZE/32];
int* sumBuffer = (int*) posBuffer; // Reuse the same buffer to save memory int* sumBuffer = (int*) posBuffer; // Reuse the same buffer to save memory
...@@ -214,7 +215,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac ...@@ -214,7 +215,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
real4 blockCenterX = sortedBlockCenter[block1]; real4 blockCenterX = sortedBlockCenter[block1];
real4 blockSizeX = sortedBlockBoundingBox[block1]; real4 blockSizeX = sortedBlockBoundingBox[block1];
int neighborsInBuffer = 0; int neighborsInBuffer = 0;
real3 pos1 = trimTo3(posq[x*TILE_SIZE+indexInWarp]); real4 pos1 = posq[x*TILE_SIZE+indexInWarp];
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
const bool singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= PADDED_CUTOFF && const bool singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= PADDED_CUTOFF &&
0.5f*periodicBoxSize.y-blockSizeX.y >= PADDED_CUTOFF && 0.5f*periodicBoxSize.y-blockSizeX.y >= PADDED_CUTOFF &&
...@@ -226,6 +227,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac ...@@ -226,6 +227,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
APPLY_PERIODIC_TO_POS_WITH_CENTER(pos1, blockCenterX) APPLY_PERIODIC_TO_POS_WITH_CENTER(pos1, blockCenterX)
} }
#endif #endif
pos1.w = 0.5f * (pos1.x * pos1.x + pos1.y * pos1.y + pos1.z * pos1.z);
posBuffer[threadIdx.x] = pos1; posBuffer[threadIdx.x] = pos1;
// Load exclusion data for block x. // Load exclusion data for block x.
...@@ -233,6 +235,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac ...@@ -233,6 +235,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
const int exclusionStart = exclusionRowIndices[x]; const int exclusionStart = exclusionRowIndices[x];
const int exclusionEnd = exclusionRowIndices[x+1]; const int exclusionEnd = exclusionRowIndices[x+1];
const int numExclusions = exclusionEnd-exclusionStart; const int numExclusions = exclusionEnd-exclusionStart;
#pragma unroll 4 // (MAX_EXCLUSIONS)
for (int j = indexInWarp; j < numExclusions; j += 32) for (int j = indexInWarp; j < numExclusions; j += 32)
exclusionsForX[j] = exclusionIndices[exclusionStart+j]; exclusionsForX[j] = exclusionIndices[exclusionStart+j];
if (MAX_EXCLUSIONS > 32) if (MAX_EXCLUSIONS > 32)
...@@ -266,6 +269,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac ...@@ -266,6 +269,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
#endif #endif
if (includeBlock2) { if (includeBlock2) {
int y = (int) sortedBlocks[block2].y; int y = (int) sortedBlocks[block2].y;
#pragma unroll 4 // (MAX_EXCLUSIONS)
for (int k = 0; k < numExclusions; k++) for (int k = 0; k < numExclusions; k++)
includeBlock2 &= (exclusionsForX[k] != y); includeBlock2 &= (exclusionsForX[k] != y);
} }
...@@ -284,14 +288,16 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac ...@@ -284,14 +288,16 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
// Check each atom in block Y for interactions. // Check each atom in block Y for interactions.
int atom2 = y*TILE_SIZE+indexInWarp; int atom2 = y*TILE_SIZE+indexInWarp;
real3 pos2 = trimTo3(posq[atom2]); real4 pos2 = posq[atom2];
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
if (singlePeriodicCopy) { if (singlePeriodicCopy) {
APPLY_PERIODIC_TO_POS_WITH_CENTER(pos2, blockCenterX) APPLY_PERIODIC_TO_POS_WITH_CENTER(pos2, blockCenterX)
} }
#endif #endif
pos2.w = 0.5f * (pos2.x * pos2.x + pos2.y * pos2.y + pos2.z * pos2.z);
real4 blockCenterY = sortedBlockCenter[block2Base+i]; real4 blockCenterY = sortedBlockCenter[block2Base+i];
real3 atomDelta = posBuffer[warpStart+indexInWarp]-trimTo3(blockCenterY); real3 atomDelta = trimTo3(posBuffer[warpStart+indexInWarp])-trimTo3(blockCenterY);
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(atomDelta) APPLY_PERIODIC_TO_DELTA(atomDelta)
#endif #endif
...@@ -303,16 +309,18 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac ...@@ -303,16 +309,18 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
if (!singlePeriodicCopy) { if (!singlePeriodicCopy) {
for (int j = first; j < last; j++) { for (int j = first; j < last; j++) {
real3 delta = pos2-posBuffer[warpStart+j]; real3 delta = trimTo3(pos2)-trimTo3(posBuffer[warpStart+j]);
APPLY_PERIODIC_TO_DELTA(delta) APPLY_PERIODIC_TO_DELTA(delta)
interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? 1<<j : 0); interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? 1<<j : 0);
} }
} }
else { else {
#endif #endif
for (int j = first; j < last; j++) { #pragma unroll
real3 delta = pos2-posBuffer[warpStart+j]; for (int j = 0; j < 32; j++) {
interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? 1<<j : 0); real4 posj = posBuffer[warpStart+j];
real halfDist2 = posj.w + pos2.w - posj.x*pos2.x - posj.y*pos2.y - posj.z*pos2.z;
interacts |= (halfDist2 < 0.5f * PADDED_CUTOFF_SQUARED ? 1<<j : 0);
} }
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
} }
...@@ -342,6 +350,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac ...@@ -342,6 +350,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
if (newTileStartIndex+tilesToStore <= maxTiles) { if (newTileStartIndex+tilesToStore <= maxTiles) {
if (indexInWarp < tilesToStore) if (indexInWarp < tilesToStore)
interactingTiles[newTileStartIndex+indexInWarp] = x; interactingTiles[newTileStartIndex+indexInWarp] = x;
#pragma unroll 8 // (GROUP_SIZE / TILE_SIZE)
for (int j = 0; j < tilesToStore; j++) for (int j = 0; j < tilesToStore; j++)
interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = buffer[indexInWarp+j*TILE_SIZE]; interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = buffer[indexInWarp+j*TILE_SIZE];
} }
...@@ -367,6 +376,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac ...@@ -367,6 +376,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,1) void findBlocksWithInterac
if (newTileStartIndex+tilesToStore <= maxTiles) { if (newTileStartIndex+tilesToStore <= maxTiles) {
if (indexInWarp < tilesToStore) if (indexInWarp < tilesToStore)
interactingTiles[newTileStartIndex+indexInWarp] = x; interactingTiles[newTileStartIndex+indexInWarp] = x;
#pragma unroll 8 // (GROUP_SIZE / TILE_SIZE)
for (int j = 0; j < tilesToStore; j++) for (int j = 0; j < tilesToStore; j++)
interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = (indexInWarp+j*TILE_SIZE < neighborsInBuffer ? buffer[indexInWarp+j*TILE_SIZE] : NUM_ATOMS); interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = (indexInWarp+j*TILE_SIZE < neighborsInBuffer ? buffer[indexInWarp+j*TILE_SIZE] : NUM_ATOMS);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment