Commit 78f8d078 authored by Peter Eastman's avatar Peter Eastman
Browse files

Very minor optimizations and cleanup to building neighborlist

parent 299b07a0
...@@ -322,8 +322,6 @@ void CudaNonbondedUtilities::initialize(const System& system) { ...@@ -322,8 +322,6 @@ void CudaNonbondedUtilities::initialize(const System& system) {
findInteractingBlocksKernel = context.getKernel(interactingBlocksProgram, "findBlocksWithInteractions"); findInteractingBlocksKernel = context.getKernel(interactingBlocksProgram, "findBlocksWithInteractions");
findInteractingBlocksArgs.push_back(context.getPeriodicBoxSizePointer()); findInteractingBlocksArgs.push_back(context.getPeriodicBoxSizePointer());
findInteractingBlocksArgs.push_back(context.getInvPeriodicBoxSizePointer()); findInteractingBlocksArgs.push_back(context.getInvPeriodicBoxSizePointer());
findInteractingBlocksArgs.push_back(&blockCenter->getDevicePointer());
findInteractingBlocksArgs.push_back(&blockBoundingBox->getDevicePointer());
findInteractingBlocksArgs.push_back(&interactionCount->getDevicePointer()); findInteractingBlocksArgs.push_back(&interactionCount->getDevicePointer());
findInteractingBlocksArgs.push_back(&interactingTiles->getDevicePointer()); findInteractingBlocksArgs.push_back(&interactingTiles->getDevicePointer());
findInteractingBlocksArgs.push_back(&interactingAtoms->getDevicePointer()); findInteractingBlocksArgs.push_back(&interactingAtoms->getDevicePointer());
...@@ -390,10 +388,10 @@ void CudaNonbondedUtilities::updateNeighborListSize() { ...@@ -390,10 +388,10 @@ void CudaNonbondedUtilities::updateNeighborListSize() {
interactingAtoms = CudaArray::create<int>(context, CudaContext::TileSize*maxTiles, "interactingAtoms"); interactingAtoms = CudaArray::create<int>(context, CudaContext::TileSize*maxTiles, "interactingAtoms");
if (forceArgs.size() > 0) if (forceArgs.size() > 0)
forceArgs[7] = &interactingTiles->getDevicePointer(); forceArgs[7] = &interactingTiles->getDevicePointer();
findInteractingBlocksArgs[5] = &interactingTiles->getDevicePointer(); findInteractingBlocksArgs[3] = &interactingTiles->getDevicePointer();
if (forceArgs.size() > 0) if (forceArgs.size() > 0)
forceArgs[14] = &interactingAtoms->getDevicePointer(); forceArgs[14] = &interactingAtoms->getDevicePointer();
findInteractingBlocksArgs[6] = &interactingAtoms->getDevicePointer(); findInteractingBlocksArgs[4] = &interactingAtoms->getDevicePointer();
if (context.getUseDoublePrecision()) { if (context.getUseDoublePrecision()) {
vector<double4> oldPositionsVec(numAtoms, make_double4(1e30, 1e30, 1e30, 0)); vector<double4> oldPositionsVec(numAtoms, make_double4(1e30, 1e30, 1e30, 0));
oldPositions->upload(oldPositionsVec); oldPositions->upload(oldPositionsVec);
......
...@@ -17,7 +17,6 @@ extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize, ...@@ -17,7 +17,6 @@ extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize,
pos.x -= floor(pos.x*invPeriodicBoxSize.x)*periodicBoxSize.x; pos.x -= floor(pos.x*invPeriodicBoxSize.x)*periodicBoxSize.x;
pos.y -= floor(pos.y*invPeriodicBoxSize.y)*periodicBoxSize.y; pos.y -= floor(pos.y*invPeriodicBoxSize.y)*periodicBoxSize.y;
pos.z -= floor(pos.z*invPeriodicBoxSize.z)*periodicBoxSize.z; pos.z -= floor(pos.z*invPeriodicBoxSize.z)*periodicBoxSize.z;
real4 firstPoint = pos;
#endif #endif
real4 minPos = pos; real4 minPos = pos;
real4 maxPos = pos; real4 maxPos = pos;
...@@ -25,9 +24,10 @@ extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize, ...@@ -25,9 +24,10 @@ extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize,
for (int i = base+1; i < last; i++) { for (int i = base+1; i < last; i++) {
pos = posq[i]; pos = posq[i];
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
pos.x -= floor((pos.x-firstPoint.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; real4 center = 0.5f*(maxPos+minPos);
pos.y -= floor((pos.y-firstPoint.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y; pos.x -= floor((pos.x-center.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
pos.z -= floor((pos.z-firstPoint.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z; pos.y -= floor((pos.y-center.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
pos.z -= floor((pos.z-center.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif #endif
minPos = make_real4(min(minPos.x,pos.x), min(minPos.y,pos.y), min(minPos.z,pos.z), 0); minPos = make_real4(min(minPos.x,pos.x), min(minPos.y,pos.y), min(minPos.z,pos.z), 0);
maxPos = make_real4(max(maxPos.x,pos.x), max(maxPos.y,pos.y), max(maxPos.z,pos.z), 0); maxPos = make_real4(max(maxPos.x,pos.x), max(maxPos.y,pos.y), max(maxPos.z,pos.z), 0);
...@@ -264,9 +264,8 @@ __device__ void storeInteractionData(unsigned short x, unsigned short* buffer, s ...@@ -264,9 +264,8 @@ __device__ void storeInteractionData(unsigned short x, unsigned short* buffer, s
* Compare the bounding boxes for each pair of blocks. If they are sufficiently far apart, * Compare the bounding boxes for each pair of blocks. If they are sufficiently far apart,
* mark them as non-interacting. * mark them as non-interacting.
*/ */
extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, const real4* __restrict__ blockCenter, extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, unsigned int* __restrict__ interactionCount,
const real4* __restrict__ blockBoundingBox, unsigned int* __restrict__ interactionCount, ushort2* __restrict__ interactingTiles, ushort2* __restrict__ interactingTiles, unsigned int* __restrict__ interactingAtoms, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int startBlockIndex,
unsigned int* __restrict__ interactingAtoms, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int startBlockIndex,
unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter, const real4* __restrict__ sortedBlockBoundingBox, unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter, const real4* __restrict__ sortedBlockBoundingBox,
const unsigned int* __restrict__ exclusionIndices, const unsigned int* __restrict__ exclusionRowIndices, real4* __restrict__ oldPositions, const unsigned int* __restrict__ exclusionIndices, const unsigned int* __restrict__ exclusionRowIndices, real4* __restrict__ oldPositions,
const int* __restrict__ rebuildNeighborList) { const int* __restrict__ rebuildNeighborList) {
...@@ -297,8 +296,8 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea ...@@ -297,8 +296,8 @@ extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, rea
numAtoms = 0; numAtoms = 0;
real2 sortedKey = sortedBlocks[i]; real2 sortedKey = sortedBlocks[i];
unsigned short x = (unsigned short) sortedKey.y; unsigned short x = (unsigned short) sortedKey.y;
real4 blockCenterX = blockCenter[x]; real4 blockCenterX = sortedBlockCenter[i];
real4 blockSizeX = blockBoundingBox[x]; real4 blockSizeX = sortedBlockBoundingBox[i];
// Load exclusion data for block x. // Load exclusion data for block x.
......
...@@ -339,22 +339,20 @@ void OpenCLNonbondedUtilities::initialize(const System& system) { ...@@ -339,22 +339,20 @@ void OpenCLNonbondedUtilities::initialize(const System& system) {
sortBoxDataKernel.setArg<cl::Buffer>(7, interactionCount->getDeviceBuffer()); sortBoxDataKernel.setArg<cl::Buffer>(7, interactionCount->getDeviceBuffer());
sortBoxDataKernel.setArg<cl::Buffer>(8, rebuildNeighborList->getDeviceBuffer()); sortBoxDataKernel.setArg<cl::Buffer>(8, rebuildNeighborList->getDeviceBuffer());
findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions"); findInteractingBlocksKernel = cl::Kernel(interactingBlocksProgram, "findBlocksWithInteractions");
findInteractingBlocksKernel.setArg<cl::Buffer>(2, blockCenter->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(2, interactionCount->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(3, blockBoundingBox->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(3, interactingTiles->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(4, interactionCount->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(4, interactingAtoms->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactingTiles->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(5, context.getPosq().getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingAtoms->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl_uint>(6, interactingTiles->getSize());
findInteractingBlocksKernel.setArg<cl::Buffer>(7, context.getPosq().getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl_uint>(7, startBlockIndex);
findInteractingBlocksKernel.setArg<cl_uint>(8, interactingTiles->getSize()); findInteractingBlocksKernel.setArg<cl_uint>(8, numBlocks);
findInteractingBlocksKernel.setArg<cl_uint>(9, startBlockIndex); findInteractingBlocksKernel.setArg<cl::Buffer>(9, sortedBlocks->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl_uint>(10, numBlocks); findInteractingBlocksKernel.setArg<cl::Buffer>(10, sortedBlockCenter->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(11, sortedBlocks->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(11, sortedBlockBoundingBox->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(12, sortedBlockCenter->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(12, exclusionIndices->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(13, sortedBlockBoundingBox->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(13, exclusionRowIndices->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(14, exclusionIndices->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(14, oldPositions->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(15, exclusionRowIndices->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(15, rebuildNeighborList->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(16, oldPositions->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(17, rebuildNeighborList->getDeviceBuffer());
} }
} }
...@@ -430,9 +428,9 @@ void OpenCLNonbondedUtilities::updateNeighborListSize() { ...@@ -430,9 +428,9 @@ void OpenCLNonbondedUtilities::updateNeighborListSize() {
forceKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer()); forceKernel.setArg<cl::Buffer>(7, interactingTiles->getDeviceBuffer());
forceKernel.setArg<cl_uint>(11, maxTiles); forceKernel.setArg<cl_uint>(11, maxTiles);
forceKernel.setArg<cl::Buffer>(14, interactingAtoms->getDeviceBuffer()); forceKernel.setArg<cl::Buffer>(14, interactingAtoms->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(5, interactingTiles->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(3, interactingTiles->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl::Buffer>(6, interactingAtoms->getDeviceBuffer()); findInteractingBlocksKernel.setArg<cl::Buffer>(4, interactingAtoms->getDeviceBuffer());
findInteractingBlocksKernel.setArg<cl_uint>(8, maxTiles); findInteractingBlocksKernel.setArg<cl_uint>(6, maxTiles);
int numAtoms = context.getNumAtoms(); int numAtoms = context.getNumAtoms();
if (context.getUseDoublePrecision()) { if (context.getUseDoublePrecision()) {
vector<mm_double4> oldPositionsVec(numAtoms, mm_double4(1e30, 1e30, 1e30, 0)); vector<mm_double4> oldPositionsVec(numAtoms, mm_double4(1e30, 1e30, 1e30, 0));
...@@ -460,8 +458,8 @@ void OpenCLNonbondedUtilities::setAtomBlockRange(double startFraction, double en ...@@ -460,8 +458,8 @@ void OpenCLNonbondedUtilities::setAtomBlockRange(double startFraction, double en
forceKernel.setArg<cl_uint>(5, startTileIndex); forceKernel.setArg<cl_uint>(5, startTileIndex);
forceKernel.setArg<cl_uint>(6, numTiles); forceKernel.setArg<cl_uint>(6, numTiles);
findInteractingBlocksKernel.setArg<cl_uint>(9, startBlockIndex); findInteractingBlocksKernel.setArg<cl_uint>(7, startBlockIndex);
findInteractingBlocksKernel.setArg<cl_uint>(10, numBlocks); findInteractingBlocksKernel.setArg<cl_uint>(8, numBlocks);
} }
} }
......
...@@ -16,7 +16,6 @@ __kernel void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeri ...@@ -16,7 +16,6 @@ __kernel void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeri
real4 pos = posq[base]; real4 pos = posq[base];
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
pos.xyz -= floor(pos.xyz*invPeriodicBoxSize.xyz)*periodicBoxSize.xyz; pos.xyz -= floor(pos.xyz*invPeriodicBoxSize.xyz)*periodicBoxSize.xyz;
real4 firstPoint = pos;
#endif #endif
real4 minPos = pos; real4 minPos = pos;
real4 maxPos = pos; real4 maxPos = pos;
...@@ -24,7 +23,8 @@ __kernel void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeri ...@@ -24,7 +23,8 @@ __kernel void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeri
for (int i = base+1; i < last; i++) { for (int i = base+1; i < last; i++) {
pos = posq[i]; pos = posq[i];
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
pos.xyz -= floor((pos.xyz-firstPoint.xyz)*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz; real4 center = 0.5f*(maxPos+minPos);
pos.xyz -= floor((pos.xyz-center.xyz)*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
#endif #endif
minPos = min(minPos, pos); minPos = min(minPos, pos);
maxPos = max(maxPos, pos); maxPos = max(maxPos, pos);
...@@ -233,9 +233,8 @@ void storeInteractionData(unsigned short x, __local unsigned short* buffer, __lo ...@@ -233,9 +233,8 @@ void storeInteractionData(unsigned short x, __local unsigned short* buffer, __lo
* Compare the bounding boxes for each pair of blocks. If they are sufficiently far apart, * Compare the bounding boxes for each pair of blocks. If they are sufficiently far apart,
* mark them as non-interacting. * mark them as non-interacting.
*/ */
__kernel void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, __global const real4* restrict blockCenter, __kernel void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, __global unsigned int* restrict interactionCount,
__global const real4* restrict blockBoundingBox, __global unsigned int* restrict interactionCount, __global ushort2* restrict interactingTiles, __global ushort2* restrict interactingTiles, __global unsigned int* restrict interactingAtoms, __global const real4* restrict posq, unsigned int maxTiles, unsigned int startBlockIndex,
__global unsigned int* restrict interactingAtoms, __global const real4* restrict posq, unsigned int maxTiles, unsigned int startBlockIndex,
unsigned int numBlocks, __global real2* restrict sortedBlocks, __global const real4* restrict sortedBlockCenter, __global const real4* restrict sortedBlockBoundingBox, unsigned int numBlocks, __global real2* restrict sortedBlocks, __global const real4* restrict sortedBlockCenter, __global const real4* restrict sortedBlockBoundingBox,
__global const unsigned int* restrict exclusionIndices, __global const unsigned int* restrict exclusionRowIndices, __global real4* restrict oldPositions, __global const unsigned int* restrict exclusionIndices, __global const unsigned int* restrict exclusionRowIndices, __global real4* restrict oldPositions,
__global const int* restrict rebuildNeighborList) { __global const int* restrict rebuildNeighborList) {
...@@ -274,8 +273,8 @@ __kernel void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodi ...@@ -274,8 +273,8 @@ __kernel void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodi
numAtoms = 0; numAtoms = 0;
real2 sortedKey = sortedBlocks[i]; real2 sortedKey = sortedBlocks[i];
unsigned short x = (unsigned short) sortedKey.y; unsigned short x = (unsigned short) sortedKey.y;
real4 blockCenterX = blockCenter[x]; real4 blockCenterX = sortedBlockCenter[i];
real4 blockSizeX = blockBoundingBox[x]; real4 blockSizeX = sortedBlockBoundingBox[i];
// Load exclusion data for block x. // Load exclusion data for block x.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment