Unverified Commit 2ae50f9d authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Store bounding box sizes in half precision (#4066)

* Store bounding box sizes in half precision

* Work correctly in double precision mode
parent ed049af1
#define GROUP_SIZE 256
#define BUFFER_SIZE 256
/**
* To use half precision, we're supposed to include cuda_fp16.h. Unfortunately,
* it isn't included in the search path automatically, and there's no reliable
* way to find where it's located on disk. Instead we provide our own definitions
* for the few symbols we need.
*/
struct __align__(2) __half {
unsigned short x;
};
__device__ __half __float2half_ru(const float f) {
__half h;
asm("{cvt.rp.f16.f32 %0, %1;}" : "=h"(*reinterpret_cast<unsigned short *>(&h)) : "f"(f));
return h;
}
__device__ float __half2float(const __half h) {
float f;
asm("{cvt.f32.f16 %0, %1;}" : "=f"(f) : "h"(*reinterpret_cast<const unsigned short *>(&h)));
return f;
}
struct half3 {
__device__ half3(real3 f) {
// Round up so we'll err on the side of making the box a little too large.
// This ensures interactions will never be missed.
v[0] = __float2half_ru((float) f.x);
v[1] = __float2half_ru((float) f.y);
v[2] = __float2half_ru((float) f.z);
}
__device__ real3 toReal3() const {
return make_real3(__half2float(v[0]), __half2float(v[1]), __half2float(v[2]));
}
private:
__half v[3];
};
/**
* Find a bounding box for the atoms in each block.
*/
......@@ -53,12 +87,12 @@ extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize,
*/
extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, const real4* __restrict__ blockCenter,
const real4* __restrict__ blockBoundingBox, real4* __restrict__ sortedBlockCenter,
real4* __restrict__ sortedBlockBoundingBox, const real4* __restrict__ posq, const real4* __restrict__ oldPositions,
half3* __restrict__ sortedBlockBoundingBox, const real4* __restrict__ posq, const real4* __restrict__ oldPositions,
unsigned int* __restrict__ interactionCount, int* __restrict__ rebuildNeighborList, bool forceRebuild) {
for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_BLOCKS; i += blockDim.x*gridDim.x) {
int index = (int) sortedBlock[i].y;
sortedBlockCenter[i] = blockCenter[index];
sortedBlockBoundingBox[i] = blockBoundingBox[index];
sortedBlockBoundingBox[i] = half3(trimTo3(blockBoundingBox[index]));
}
// Also check whether any atom has moved enough so that we really need to rebuild the neighbor list.
......@@ -181,7 +215,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,3) void findBlocksWithInterac
unsigned int* __restrict__ interactionCount, int* __restrict__ interactingTiles, unsigned int* __restrict__ interactingAtoms,
int2* __restrict__ singlePairs, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int maxSinglePairs,
unsigned int startBlockIndex, unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter,
const real4* __restrict__ sortedBlockBoundingBox, const unsigned int* __restrict__ exclusionIndices, const unsigned int* __restrict__ exclusionRowIndices,
const half3* __restrict__ sortedBlockBoundingBox, const unsigned int* __restrict__ exclusionIndices, const unsigned int* __restrict__ exclusionRowIndices,
real4* __restrict__ oldPositions, const int* __restrict__ rebuildNeighborList) {
if (rebuildNeighborList[0] == 0)
......@@ -213,7 +247,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,3) void findBlocksWithInterac
real2 sortedKey = sortedBlocks[block1];
int x = (int) sortedKey.y;
real4 blockCenterX = sortedBlockCenter[block1];
real4 blockSizeX = sortedBlockBoundingBox[block1];
real3 blockSizeX = sortedBlockBoundingBox[block1].toReal3();
int neighborsInBuffer = 0;
real4 pos1 = posq[x*TILE_SIZE+indexInWarp];
#ifdef USE_PERIODIC
......@@ -250,7 +284,7 @@ extern "C" __global__ __launch_bounds__(GROUP_SIZE,3) void findBlocksWithInterac
bool forceInclude = false;
if (includeBlock2) {
real4 blockCenterY = sortedBlockCenter[block2];
real4 blockSizeY = sortedBlockBoundingBox[block2];
real3 blockSizeY = sortedBlockBoundingBox[block2].toReal3();
real4 blockDelta = blockCenterX-blockCenterY;
#ifdef USE_PERIODIC
APPLY_PERIODIC_TO_DELTA(blockDelta)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment