Commit 077a93c8 authored by Peter Eastman's avatar Peter Eastman
Browse files

Continuing to optimize nonbonded kernels for CPU

parent e7b18ca4
...@@ -215,7 +215,7 @@ void OpenCLNonbondedUtilities::initialize(const System& system) { ...@@ -215,7 +215,7 @@ void OpenCLNonbondedUtilities::initialize(const System& system) {
if (maxInteractingTiles > numTiles) if (maxInteractingTiles > numTiles)
maxInteractingTiles = numTiles; maxInteractingTiles = numTiles;
interactingTiles = new OpenCLArray<mm_ushort2>(context, maxInteractingTiles, "interactingTiles"); interactingTiles = new OpenCLArray<mm_ushort2>(context, maxInteractingTiles, "interactingTiles");
interactionFlags = new OpenCLArray<cl_uint>(context, context.getSIMDWidth() == 32 || deviceIsCpu ? maxInteractingTiles : 1, "interactionFlags"); interactionFlags = new OpenCLArray<cl_uint>(context, context.getSIMDWidth() == 32 ? maxInteractingTiles : (deviceIsCpu ? 2*maxInteractingTiles : 1), "interactionFlags");
interactionCount = new OpenCLArray<cl_uint>(context, 1, "interactionCount", true); interactionCount = new OpenCLArray<cl_uint>(context, 1, "interactionCount", true);
blockCenter = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockCenter"); blockCenter = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockCenter");
blockBoundingBox = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockBoundingBox"); blockBoundingBox = new OpenCLArray<mm_float4>(context, numAtomBlocks, "blockBoundingBox");
...@@ -459,7 +459,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -459,7 +459,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, exclusionIndices->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, exclusionIndices->getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, exclusionRowIndices->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, exclusionRowIndices->getDeviceBuffer());
kernel.setArg(index++, OpenCLContext::ThreadBlockSize*localDataSize, NULL); kernel.setArg(index++, (deviceIsCpu ? OpenCLContext::TileSize*localDataSize : OpenCLContext::ThreadBlockSize*localDataSize), NULL);
kernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float4), NULL); kernel.setArg(index++, OpenCLContext::ThreadBlockSize*sizeof(cl_float4), NULL);
if (useCutoff) { if (useCutoff) {
kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, interactingTiles->getDeviceBuffer());
......
...@@ -45,12 +45,15 @@ __kernel void findBlockBounds(int numAtoms, float4 periodicBoxSize, float4 invPe ...@@ -45,12 +45,15 @@ __kernel void findBlockBounds(int numAtoms, float4 periodicBoxSize, float4 invPe
* This is called by findBlocksWithInteractions(). It compacts the list of blocks and writes them * This is called by findBlocksWithInteractions(). It compacts the list of blocks and writes them
* to global memory. * to global memory.
*/ */
void storeInteractionData(__local ushort2* buffer, int numValid, __local unsigned int* flagsBuffer, __local float4* temp, void storeInteractionData(ushort2* buffer, int numValid, __global unsigned int* interactionCount, __global ushort2* interactingTiles,
__global unsigned int* interactionCount, __global ushort2* interactingTiles, __global unsigned int* interactionFlags, float cutoffSquared, float4 periodicBoxSize, __global unsigned int* interactionFlags, float cutoffSquared, float4 periodicBoxSize, float4 invPeriodicBoxSize,
float4 invPeriodicBoxSize, __global float4* posq, __global float4* blockCenter, __global float4* blockBoundingBox, unsigned int maxTiles) { __global float4* posq, __global float4* blockCenter, __global float4* blockBoundingBox, unsigned int maxTiles) {
// Filter the list of tiles by comparing the distance from each atom to the other bounding box. // Filter the list of tiles by comparing the distance from each atom to the other bounding box.
unsigned int flagsBuffer[2*BUFFER_SIZE];
float4 atomPositions[TILE_SIZE];
int lasty = -1; int lasty = -1;
float4 centery, boxSizey;
for (int tile = 0; tile < numValid; ) { for (int tile = 0; tile < numValid; ) {
int x = buffer[tile].x; int x = buffer[tile].x;
int y = buffer[tile].y; int y = buffer[tile].y;
...@@ -59,37 +62,46 @@ void storeInteractionData(__local ushort2* buffer, int numValid, __local unsigne ...@@ -59,37 +62,46 @@ void storeInteractionData(__local ushort2* buffer, int numValid, __local unsigne
continue; continue;
} }
// Load the atom positions and the bounding box of the other block. // Load the atom positions and bounding boxes.
float4 center = blockCenter[x]; float4 centerx = blockCenter[x];
float4 boxSize = blockBoundingBox[x]; float4 boxSizex = blockBoundingBox[x];
if (y != lasty) if (y != lasty) {
for (int atom = 0; atom < TILE_SIZE; atom++) for (int atom = 0; atom < TILE_SIZE; atom++)
temp[atom] = posq[y*TILE_SIZE+atom]; atomPositions[atom] = posq[y*TILE_SIZE+atom];
lasty = y; centery = blockCenter[y];
boxSizey = blockBoundingBox[y];
lasty = y;
}
// Find the distance of each atom from the bounding box. // Find the distance of each atom from the bounding box.
unsigned int flags = 0; unsigned int flags1 = 0, flags2 = 0;
for (int atom = 0; atom < TILE_SIZE; atom++) { for (int atom = 0; atom < TILE_SIZE; atom++) {
float4 delta = temp[atom]-center; float4 delta = atomPositions[atom]-centerx;
#ifdef USE_PERIODIC
delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
#endif
delta = max((float4) 0.0f, fabs(delta)-boxSizex);
if (dot(delta.xyz, delta.xyz) < cutoffSquared)
flags1 += 1 << atom;
delta = posq[x*TILE_SIZE+atom]-centery;
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif #endif
delta = max((float4) 0.0f, fabs(delta)-boxSize); delta = max((float4) 0.0f, fabs(delta)-boxSizey);
if (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < cutoffSquared) if (dot(delta.xyz, delta.xyz) < cutoffSquared)
flags += 1 << atom; flags2 += 1 << atom;
} }
if (flags == 0) { if (flags1 == 0 || flags2 == 0) {
// This tile contains no interactions. // This tile contains no interactions.
numValid--; numValid--;
buffer[tile] = buffer[numValid]; buffer[tile] = buffer[numValid];
} }
else { else {
flagsBuffer[tile] = flags; flagsBuffer[2*tile] = flags1;
flagsBuffer[2*tile+1] = flags2;
tile++; tile++;
} }
} }
...@@ -100,7 +112,8 @@ void storeInteractionData(__local ushort2* buffer, int numValid, __local unsigne ...@@ -100,7 +112,8 @@ void storeInteractionData(__local ushort2* buffer, int numValid, __local unsigne
if (baseIndex+numValid <= maxTiles) if (baseIndex+numValid <= maxTiles)
for (int i = 0; i < numValid; i++) { for (int i = 0; i < numValid; i++) {
interactingTiles[baseIndex+i] = buffer[i]; interactingTiles[baseIndex+i] = buffer[i];
interactionFlags[baseIndex+i] = flagsBuffer[i]; interactionFlags[2*(baseIndex+i)] = flagsBuffer[2*i];
interactionFlags[2*(baseIndex+i)+1] = flagsBuffer[2*i+1];
} }
} }
...@@ -111,9 +124,7 @@ void storeInteractionData(__local ushort2* buffer, int numValid, __local unsigne ...@@ -111,9 +124,7 @@ void storeInteractionData(__local ushort2* buffer, int numValid, __local unsigne
__kernel void findBlocksWithInteractions(float cutoffSquared, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global float4* blockCenter, __kernel void findBlocksWithInteractions(float cutoffSquared, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global float4* blockCenter,
__global float4* blockBoundingBox, __global unsigned int* interactionCount, __global ushort2* interactingTiles, __global float4* blockBoundingBox, __global unsigned int* interactionCount, __global ushort2* interactingTiles,
__global unsigned int* interactionFlags, __global float4* posq, unsigned int maxTiles) { __global unsigned int* interactionFlags, __global float4* posq, unsigned int maxTiles) {
__local ushort2 buffer[BUFFER_SIZE]; ushort2 buffer[BUFFER_SIZE];
__local unsigned int flagsBuffer[BUFFER_SIZE];
__local float4 temp[TILE_SIZE];
int valuesInBuffer = 0; int valuesInBuffer = 0;
const int numTiles = (NUM_BLOCKS*(NUM_BLOCKS+1))/2; const int numTiles = (NUM_BLOCKS*(NUM_BLOCKS+1))/2;
unsigned int start = get_group_id(0)*numTiles/get_num_groups(0); unsigned int start = get_group_id(0)*numTiles/get_num_groups(0);
...@@ -146,10 +157,10 @@ __kernel void findBlocksWithInteractions(float cutoffSquared, float4 periodicBox ...@@ -146,10 +157,10 @@ __kernel void findBlocksWithInteractions(float cutoffSquared, float4 periodicBox
buffer[valuesInBuffer++] = (ushort2) (x, y); buffer[valuesInBuffer++] = (ushort2) (x, y);
if (valuesInBuffer == BUFFER_SIZE) { if (valuesInBuffer == BUFFER_SIZE) {
storeInteractionData(buffer, valuesInBuffer, flagsBuffer, temp, interactionCount, interactingTiles, interactionFlags, cutoffSquared, periodicBoxSize, invPeriodicBoxSize, posq, blockCenter, blockBoundingBox, maxTiles); storeInteractionData(buffer, valuesInBuffer, interactionCount, interactingTiles, interactionFlags, cutoffSquared, periodicBoxSize, invPeriodicBoxSize, posq, blockCenter, blockBoundingBox, maxTiles);
valuesInBuffer = 0; valuesInBuffer = 0;
} }
} }
} }
storeInteractionData(buffer, valuesInBuffer, flagsBuffer, temp, interactionCount, interactingTiles, interactionFlags, cutoffSquared, periodicBoxSize, invPeriodicBoxSize, posq, blockCenter, blockBoundingBox, maxTiles); storeInteractionData(buffer, valuesInBuffer, interactionCount, interactingTiles, interactionFlags, cutoffSquared, periodicBoxSize, invPeriodicBoxSize, posq, blockCenter, blockBoundingBox, maxTiles);
} }
...@@ -49,7 +49,6 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en ...@@ -49,7 +49,6 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en
x = (pos-y*NUM_BLOCKS+y*(y+1)/2); x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
} }
} }
unsigned int tgx = get_local_id(0) & (TILE_SIZE-1);
// Locate the exclusion data for this tile. // Locate the exclusion data for this tile.
...@@ -92,15 +91,14 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en ...@@ -92,15 +91,14 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en
for (unsigned int j = 0; j < TILE_SIZE; j++) { for (unsigned int j = 0; j < TILE_SIZE; j++) {
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
bool isExcluded = !(excl & 0x1); bool isExcluded = !(excl & 0x1);
if (!isExcluded) {
#endif #endif
float4 posq2 = (float4) (localData[j].x, localData[j].y, localData[j].z, localData[j].q); float4 posq2 = (float4) (localData[j].x, localData[j].y, localData[j].z, localData[j].q);
float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f); float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif #endif
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = dot(delta.xyz, delta.xyz);
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
if (r2 < CUTOFF_SQUARED) { if (r2 < CUTOFF_SQUARED) {
#endif #endif
...@@ -125,6 +123,9 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en ...@@ -125,6 +123,9 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en
#endif #endif
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
} }
#endif
#ifdef USE_EXCLUSIONS
}
#endif #endif
excl >>= 1; excl >>= 1;
} }
...@@ -144,62 +145,54 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en ...@@ -144,62 +145,54 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en
localData[tgx].fz = 0.0f; localData[tgx].fz = 0.0f;
} }
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
unsigned int flags = (numTiles <= maxTiles ? interactionFlags[pos] : 0xFFFFFFFF); unsigned int flags1 = (numTiles <= maxTiles ? interactionFlags[2*pos] : 0xFFFFFFFF);
if (!hasExclusions && flags != 0xFFFFFFFF) { unsigned int flags2 = (numTiles <= maxTiles ? interactionFlags[2*pos+1] : 0xFFFFFFFF);
if (flags == 0) { if (!hasExclusions && (flags1 != 0xFFFFFFFF || flags2 != 0xFFFFFFFF)) {
// No interactions in this tile. // Compute only a subset of the interactions in this tile.
}
else {
// Compute only a subset of the interactions in this tile.
for (unsigned int tgx = 0; tgx < TILE_SIZE; tgx++) { for (unsigned int tgx = 0; tgx < TILE_SIZE; tgx++) {
if ((flags2&(1<<tgx)) != 0) {
unsigned int atom1 = x*TILE_SIZE+tgx; unsigned int atom1 = x*TILE_SIZE+tgx;
float4 force = 0.0f; float4 force = 0.0f;
float4 posq1 = posq[atom1]; float4 posq1 = posq[atom1];
LOAD_ATOM1_PARAMETERS LOAD_ATOM1_PARAMETERS
for (unsigned int j = 0; j < TILE_SIZE; j++) { for (unsigned int j = 0; j < TILE_SIZE; j++) {
if ((flags&(1<<j)) != 0) { if ((flags1&(1<<j)) != 0) {
bool isExcluded = false; bool isExcluded = false;
float4 posq2 = (float4) (localData[j].x, localData[j].y, localData[j].z, localData[j].q); float4 posq2 = (float4) (localData[j].x, localData[j].y, localData[j].z, localData[j].q);
float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f); float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif #endif
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = dot(delta.xyz, delta.xyz);
#ifdef USE_CUTOFF
if (r2 < CUTOFF_SQUARED) { if (r2 < CUTOFF_SQUARED) {
#endif float invR = RSQRT(r2);
float invR = RSQRT(r2); float r = RECIP(invR);
float r = RECIP(invR); unsigned int atom2 = j;
unsigned int atom2 = j; LOAD_ATOM2_PARAMETERS
LOAD_ATOM2_PARAMETERS atom2 = y*TILE_SIZE+j;
atom2 = y*TILE_SIZE+j;
#ifdef USE_SYMMETRIC #ifdef USE_SYMMETRIC
float dEdR = 0.0f; float dEdR = 0.0f;
#else #else
float4 dEdR1 = (float4) 0.0f; float4 dEdR1 = (float4) 0.0f;
float4 dEdR2 = (float4) 0.0f; float4 dEdR2 = (float4) 0.0f;
#endif #endif
float tempEnergy = 0.0f; float tempEnergy = 0.0f;
COMPUTE_INTERACTION COMPUTE_INTERACTION
energy += tempEnergy; energy += tempEnergy;
#ifdef USE_SYMMETRIC #ifdef USE_SYMMETRIC
delta.xyz *= dEdR; delta.xyz *= dEdR;
force.xyz -= delta.xyz; force.xyz -= delta.xyz;
localData[j].fx += delta.x; localData[j].fx += delta.x;
localData[j].fy += delta.y; localData[j].fy += delta.y;
localData[j].fz += delta.z; localData[j].fz += delta.z;
#else #else
force.xyz -= dEdR1.xyz; force.xyz -= dEdR1.xyz;
localData[j].fx += dEdR2.x; localData[j].fx += dEdR2.x;
localData[j].fy += dEdR2.y; localData[j].fy += dEdR2.y;
localData[j].fz += dEdR2.z; localData[j].fz += dEdR2.z;
#endif #endif
#ifdef USE_CUTOFF
} }
#endif
} }
} }
...@@ -226,15 +219,14 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en ...@@ -226,15 +219,14 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en
for (unsigned int j = 0; j < TILE_SIZE; j++) { for (unsigned int j = 0; j < TILE_SIZE; j++) {
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
bool isExcluded = !(excl & 0x1); bool isExcluded = !(excl & 0x1);
if (!isExcluded) {
#endif #endif
float4 posq2 = (float4) (localData[j].x, localData[j].y, localData[j].z, localData[j].q); float4 posq2 = (float4) (localData[j].x, localData[j].y, localData[j].z, localData[j].q);
float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f); float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
#ifdef USE_PERIODIC #ifdef USE_PERIODIC
delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x; delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif #endif
float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z; float r2 = dot(delta.xyz, delta.xyz);
#ifdef USE_CUTOFF #ifdef USE_CUTOFF
if (r2 < CUTOFF_SQUARED) { if (r2 < CUTOFF_SQUARED) {
#endif #endif
...@@ -268,6 +260,7 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en ...@@ -268,6 +260,7 @@ __kernel void computeNonbonded(__global float4* forceBuffers, __global float* en
} }
#endif #endif
#ifdef USE_EXCLUSIONS #ifdef USE_EXCLUSIONS
}
excl >>= 1; excl >>= 1;
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment