findInteractingBlocks.cu 16.4 KB
Newer Older
1
2
// Number of threads the staging buffer is indexed by in findBlocksWithInteractions()
// (entries are addressed as group*GROUP_SIZE+threadIdx.x).
#define GROUP_SIZE 256
// Number of GROUP_SIZE-wide groups of entries in the staging buffer.
#define BUFFER_GROUPS 2
// Total number of staging-buffer entries.  Parenthesized so that the macro expands
// correctly inside larger expressions such as N/BUFFER_SIZE.
#define BUFFER_SIZE (BUFFER_GROUPS*GROUP_SIZE)
// Number of threads in a warp.
#define WARP_SIZE 32
// Sentinel stored in staging-buffer slots that do not hold a block index.
#define INVALID 0xFFFF
6
7
8
9

/**
 * Find a bounding box for the atoms in each block.
 *
 * A block is a run of TILE_SIZE consecutive atoms (the last block may be shorter).
 * Each thread processes whole blocks, advancing by the total number of launched
 * threads until every block has been handled.
 *
 * @param numAtoms             number of atoms in posq
 * @param periodicBoxSize      periodic box dimensions (only used when USE_PERIODIC is defined)
 * @param invPeriodicBoxSize   multiplied with positions to count box lengths; presumably
 *                             the reciprocal of periodicBoxSize (only used when USE_PERIODIC is defined)
 * @param posq                 atom positions (x, y, z are read)
 * @param blockCenter          output: midpoint of each block's bounding box
 * @param blockBoundingBox     output: half-extent of each block's bounding box along each axis
 * @param rebuildNeighborList  output: element 0 is reset to 0 by a single thread
 * @param sortedBlocks         output: for each block, (sum of the three half-extents, block index)
 */
extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeriodicBoxSize, const real4* __restrict__ posq,
        real4* __restrict__ blockCenter, real4* __restrict__ blockBoundingBox, int* __restrict__ rebuildNeighborList, real2* __restrict__ sortedBlocks) {
    int index = blockIdx.x*blockDim.x+threadIdx.x;
    int base = index*TILE_SIZE;
    while (base < numAtoms) {
        real4 pos = posq[base];
#ifdef USE_PERIODIC
        // Wrap the first atom of the block into the primary periodic box and use it
        // as the reference point for the remaining atoms.
        pos.x -= floor(pos.x*invPeriodicBoxSize.x)*periodicBoxSize.x;
        pos.y -= floor(pos.y*invPeriodicBoxSize.y)*periodicBoxSize.y;
        pos.z -= floor(pos.z*invPeriodicBoxSize.z)*periodicBoxSize.z;
        real4 firstPoint = pos;
#endif
        real4 minPos = pos;
        real4 maxPos = pos;
        // Clamp so that a short final block does not read past the end of posq.
        int last = min(base+TILE_SIZE, numAtoms);
        for (int i = base+1; i < last; i++) {
            pos = posq[i];
#ifdef USE_PERIODIC
            // Shift this atom to the periodic image closest to the first atom so the
            // box stays compact even when the block straddles a box boundary.
            pos.x -= floor((pos.x-firstPoint.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            pos.y -= floor((pos.y-firstPoint.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            pos.z -= floor((pos.z-firstPoint.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            minPos = make_real4(min(minPos.x,pos.x), min(minPos.y,pos.y), min(minPos.z,pos.z), 0);
            maxPos = make_real4(max(maxPos.x,pos.x), max(maxPos.y,pos.y), max(maxPos.z,pos.z), 0);
        }
        real4 blockSize = 0.5f*(maxPos-minPos);
        blockBoundingBox[index] = blockSize;
        blockCenter[index] = 0.5f*(maxPos+minPos);
        // Key each block by the sum of its half-extents, paired with its index.
        sortedBlocks[index] = make_real2(blockSize.x+blockSize.y+blockSize.z, index);
        index += blockDim.x*gridDim.x;
        base = index*TILE_SIZE;
    }
    // Exactly one thread in the grid clears the rebuild flag.
    if (blockIdx.x == 0 && threadIdx.x == 0)
        rebuildNeighborList[0] = 0;
}

/**
 * Sort the data about bounding boxes so it can be accessed more efficiently in the next kernel.
 */
extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, const real4* __restrict__ blockCenter,
        const real4* __restrict__ blockBoundingBox, real4* __restrict__ sortedBlockCenter,
        real4* __restrict__ sortedBlockBoundingBox, const real4* __restrict__ posq, const real4* __restrict__ oldPositions,
        unsigned int* __restrict__ interactionCount, int* __restrict__ rebuildNeighborList) {
    // Gather pass: entry `rank` of each sorted array receives the data of the block
    // whose index is stored in sortedBlock[rank].y.
    int rank = threadIdx.x+blockIdx.x*blockDim.x;
    while (rank < NUM_BLOCKS) {
        const int source = (int) sortedBlock[rank].y;
        sortedBlockBoundingBox[rank] = blockBoundingBox[source];
        sortedBlockCenter[rank] = blockCenter[source];
        rank += blockDim.x*gridDim.x;
    }

    // Displacement pass: flag a rebuild if any atom examined by this thread has moved
    // farther than half of PADDING from its recorded old position.
    bool moved = false;
    int atom = threadIdx.x+blockIdx.x*blockDim.x;
    while (atom < NUM_ATOMS) {
        const real4 shift = oldPositions[atom]-posq[atom];
        moved = moved || (shift.x*shift.x + shift.y*shift.y + shift.z*shift.z > 0.25f*PADDING*PADDING);
        atom += blockDim.x*gridDim.x;
    }
    if (!moved)
        return;

    // Request a rebuild and restart the interaction counter from zero.
    rebuildNeighborList[0] = 1;
    interactionCount[0] = 0;
}
72

73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/**
 * Perform a parallel prefix sum over an array.  The input values are all assumed to be 0 or 1.
 *
 * On return, sum[i] holds the inclusive sum of the original sum[0..i] for every i in
 * [0, BUFFER_SIZE).  Must be called by every thread of the block, since it contains
 * block-wide barriers.
 *
 * NOTE(review): the warp-intrinsic path passes a full-warp mask to the ballot and indexes
 * temp[base+threadIdx.x] without a bounds check, so it appears to assume that blockDim.x
 * is a multiple of WARP_SIZE that divides BUFFER_SIZE, and that BUFFER_SIZE/WARP_SIZE
 * does not exceed WARP_SIZE - confirm against the launch configuration.
 *
 * @param sum   in: BUFFER_SIZE values, each 0 or 1; out: their inclusive prefix sums
 * @param temp  scratch space with BUFFER_SIZE elements; contents are overwritten
 */
__device__ void prefixSum(short* sum, ushort2* temp) {
#if __CUDA_ARCH__ >= 300
    const int indexInWarp = threadIdx.x%WARP_SIZE;
    // Bits 0..indexInWarp set.  Unsigned arithmetic so that lane 31 wraps cleanly to
    // 0xFFFFFFFF instead of overflowing a signed int.
    const unsigned int warpMask = (2u<<indexInWarp)-1u;
    for (int base = 0; base < BUFFER_SIZE; base += blockDim.x) {
        // Inclusive scan within each warp: count the set flags at or below this lane.
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
        const unsigned int votes = __ballot_sync(0xFFFFFFFFu, sum[base+threadIdx.x]);
#else
        const unsigned int votes = __ballot(sum[base+threadIdx.x]);
#endif
        temp[base+threadIdx.x].x = __popc(votes&warpMask);
    }
    __syncthreads();
    if (threadIdx.x < BUFFER_SIZE/WARP_SIZE) {
        // Scan the per-warp totals (the last entry of each WARP_SIZE-wide segment).
        int multiWarpSum = temp[(threadIdx.x+1)*WARP_SIZE-1].x;
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
        // Only the lanes that entered this branch may be named in the shuffle mask.
        const unsigned int scanMask = 0xFFFFFFFFu >> (WARP_SIZE-BUFFER_SIZE/WARP_SIZE);
#endif
        for (int offset = 1; offset < BUFFER_SIZE/WARP_SIZE; offset *= 2) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
            short n = __shfl_up_sync(scanMask, multiWarpSum, offset, WARP_SIZE);
#else
            short n = __shfl_up(multiWarpSum, offset, WARP_SIZE);
#endif
            if (indexInWarp >= offset)
                multiWarpSum += n;
        }
        temp[threadIdx.x].y = multiWarpSum;
    }
    __syncthreads();
    // Combine each element's in-warp count with the total of all preceding segments.
    for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
        sum[i] = temp[i].x+(i < WARP_SIZE ? 0 : temp[i/WARP_SIZE-1].y);
    __syncthreads();
#else
    // Fallback without warp intrinsics: a doubling-offset scan that ping-pongs between
    // the .x and .y halves of temp.
    for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
        temp[i].x = sum[i];
    __syncthreads();
    int whichBuffer = 0;
    for (int offset = 1; offset < BUFFER_SIZE; offset *= 2) {
        if (whichBuffer == 0)
            for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
                temp[i].y = (i < offset ? temp[i].x : temp[i].x+temp[i-offset].x);
        else
            for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
                temp[i].x = (i < offset ? temp[i].y : temp[i].y+temp[i-offset].y);
        whichBuffer = 1-whichBuffer;
        __syncthreads();
    }
    // Copy the result out of whichever half received the final pass.
    if (whichBuffer == 0)
        for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
            sum[i] = temp[i].x;
    else
        for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
            sum[i] = temp[i].y;
    __syncthreads();
#endif
}
120

121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
/**
 * This is called by findBlocksWithInteractions().  It compacts the list of blocks, identifies interactions
 * in them, and writes the result to global memory.
 *
 * Must be called by every thread of the block, since it contains block-wide barriers.
 *
 * @param x                   index of the block whose neighbors are being collected
 * @param buffer              BUFFER_SIZE candidate block indices; unused slots hold INVALID.
 *                            Every slot is reset to INVALID before returning.
 * @param sum                 scratch array of BUFFER_SIZE elements
 * @param temp                scratch array of BUFFER_SIZE elements
 * @param atoms               staging area for interacting atom indices that have not yet been written out
 * @param numAtoms            number of staged entries in atoms carried over between calls; read and
 *                            updated here (the caller passes a __shared__ variable)
 * @param baseIndex           scratch variable used to share the result of the atomicAdd on
 *                            interactionCount among the threads
 * @param interactionCount    global counter of tiles; advanced by the number of tiles reserved
 * @param interactingTiles    output: one (x, singlePeriodicCopy) entry per stored tile
 * @param interactingAtoms    output: TILE_SIZE atom indices per stored tile, padded with NUM_ATOMS
 * @param periodicBoxSize     periodic box dimensions
 * @param invPeriodicBoxSize  multiplied with distances to count box lengths; presumably the
 *                            reciprocal of periodicBoxSize
 * @param posq                atom positions
 * @param posBuffer           scratch array of TILE_SIZE positions; receives the atoms of block x
 * @param blockCenterX        center of block x's bounding box
 * @param blockSizeX          half-extents of block x's bounding box
 * @param maxTiles            capacity of the output arrays, in tiles.  If a reservation would run past it
 *                            nothing is written, although interactionCount has already been advanced.
 * @param finish              if true, also write out a final, partially filled tile
 */
__device__ void storeInteractionData(unsigned short x, unsigned short* buffer, short* sum, ushort2* temp, int* atoms, int& numAtoms,
            int& baseIndex, unsigned int* interactionCount, ushort2* interactingTiles, unsigned int* interactingAtoms, real4 periodicBoxSize,
            real4 invPeriodicBoxSize, const real4* posq, real3* posBuffer, real4 blockCenterX, real4 blockSizeX, unsigned int maxTiles, bool finish) {
    const bool singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= PADDED_CUTOFF &&
                                     0.5f*periodicBoxSize.y-blockSizeX.y >= PADDED_CUTOFF &&
                                     0.5f*periodicBoxSize.z-blockSizeX.z >= PADDED_CUTOFF);
    // Stage the positions of the atoms in block x so every warp can compare against them.
    if (threadIdx.x < TILE_SIZE) {
        real3 pos = trimTo3(posq[x*TILE_SIZE+threadIdx.x]);
        posBuffer[threadIdx.x] = pos;
#ifdef USE_PERIODIC
        if (singlePeriodicCopy) {
            // The box is small enough that we can just translate all the atoms into a single periodic
            // box, then skip having to apply periodic boundary conditions later.
            
            pos.x -= floor((pos.x-blockCenterX.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            pos.y -= floor((pos.y-blockCenterX.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            pos.z -= floor((pos.z-blockCenterX.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
            posBuffer[threadIdx.x] = pos;
        }
#endif
    }
    
    // The buffer is full, so we need to compact it and write out results.  Start by doing a parallel prefix sum.

    for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
        sum[i] = (buffer[i] == INVALID ? 0 : 1);
    __syncthreads();
    prefixSum(sum, temp);
    int numValid = sum[BUFFER_SIZE-1];

    // Compact the buffer.

    for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
        if (buffer[i] != INVALID)
            temp[sum[i]-1].x = buffer[i];
    __syncthreads();
    for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
        buffer[i] = temp[i].x;
    __syncthreads();

    // Loop over the tiles and find specific interactions in them.

    const int indexInWarp = threadIdx.x%WARP_SIZE;
    for (int base = 0; base < numValid; base += BUFFER_SIZE/WARP_SIZE) {
        // Each warp takes one candidate block at a time; each lane takes one of its atoms.
        for (int i = threadIdx.x/WARP_SIZE; i < BUFFER_SIZE/WARP_SIZE && base+i < numValid; i += GROUP_SIZE/WARP_SIZE) {
            // Check each atom in block Y for interactions.
            
            real3 pos = trimTo3(posq[buffer[base+i]*TILE_SIZE+indexInWarp]);
#ifdef USE_PERIODIC
            if (singlePeriodicCopy) {
                pos.x -= floor((pos.x-blockCenterX.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                pos.y -= floor((pos.y-blockCenterX.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                pos.z -= floor((pos.z-blockCenterX.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
            }
#endif
            bool interacts = false;
#ifdef USE_PERIODIC
            if (!singlePeriodicCopy) {
                for (int j = 0; j < TILE_SIZE; j++) {
                    real3 delta = pos-posBuffer[j];
                    delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                    delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                    delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
                    interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED);
                }
            }
            else {
#endif
                for (int j = 0; j < TILE_SIZE; j++) {
                    real3 delta = pos-posBuffer[j];
                    interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED);
                }
#ifdef USE_PERIODIC
            }
#endif
            sum[i*WARP_SIZE+indexInWarp] = (interacts ? 1 : 0);
        }
        // Clear the flags of the slots past the last valid block in this pass.
        for (int i = numValid-base+threadIdx.x/WARP_SIZE; i < BUFFER_SIZE/WARP_SIZE; i += GROUP_SIZE/WARP_SIZE)
            sum[i*WARP_SIZE+indexInWarp] = 0;

        // Compact the list of atoms.

        __syncthreads();
        prefixSum(sum, temp);
        // A step in the prefix sum marks an interacting atom; append it after the staged ones.
        for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
            if (sum[i] != (i == 0 ? 0 : sum[i-1]))
                atoms[numAtoms+sum[i]-1] = buffer[base+i/WARP_SIZE]*TILE_SIZE+indexInWarp;

        // Store them to global memory.

        int atomsToStore = numAtoms+sum[BUFFER_SIZE-1];
        // Only the last pass of a finishing call rounds up to flush a partially filled tile.
        bool storePartialTile = (finish && base >= numValid-BUFFER_SIZE/WARP_SIZE);
        int tilesToStore = (storePartialTile ? (atomsToStore+TILE_SIZE-1)/TILE_SIZE : atomsToStore/TILE_SIZE);
        if (tilesToStore > 0) {
            // One thread reserves output space; the barrier publishes baseIndex to the others.
            if (threadIdx.x == 0)
                baseIndex = atomicAdd(interactionCount, tilesToStore);
            __syncthreads();
            if (threadIdx.x == 0)
                numAtoms = atomsToStore-tilesToStore*TILE_SIZE;
            if (baseIndex+tilesToStore <= maxTiles) {
                if (threadIdx.x < tilesToStore)
                    interactingTiles[baseIndex+threadIdx.x] = make_ushort2(x, singlePeriodicCopy);
                for (int i = threadIdx.x; i < tilesToStore*TILE_SIZE; i += blockDim.x)
                    interactingAtoms[baseIndex*TILE_SIZE+i] = (i < atomsToStore ? atoms[i] : NUM_ATOMS);
            }
        }
        else {
            __syncthreads();
            if (threadIdx.x == 0)
                numAtoms += sum[BUFFER_SIZE-1];
        }
        __syncthreads();
        // Move the atoms that did not fill a tile to the front of the staging area.
        if (threadIdx.x < numAtoms && !storePartialTile)
            atoms[threadIdx.x] = atoms[tilesToStore*TILE_SIZE+threadIdx.x];
    }

    if (numValid == 0 && numAtoms > 0 && finish) {
        // We didn't have any more tiles to process, but there were some atoms left over from a
        // previous call to this function.  Save them now.

        if (threadIdx.x == 0)
            baseIndex = atomicAdd(interactionCount, 1);
        __syncthreads();
        if (baseIndex < maxTiles) {
            if (threadIdx.x == 0)
                interactingTiles[baseIndex] = make_ushort2(x, singlePeriodicCopy);
            if (threadIdx.x < TILE_SIZE)
                interactingAtoms[baseIndex*TILE_SIZE+threadIdx.x] = (threadIdx.x < numAtoms ? atoms[threadIdx.x] : NUM_ATOMS);
        }
    }

    // Reset the buffer for processing more tiles.

    for (int i = threadIdx.x; i < BUFFER_SIZE; i += blockDim.x)
        buffer[i] = INVALID;

    // Do not let any thread return while others may still be reading numAtoms above:
    // the caller resets numAtoms for the next block right after this call.
    __syncthreads();
}

/**
 * Compare the bounding boxes for each pair of blocks.  If they are sufficiently far apart,
 * mark them as non-interacting.
 *
 * Each thread block takes blocks x from the size-sorted list, tests every block that follows x
 * in sorted order against it, buffers the candidates that pass the bounding box test, and hands
 * them to storeInteractionData() to be refined and written out.  Returns immediately if
 * rebuildNeighborList[0] is 0.
 *
 * NOTE(review): the staging buffer is addressed as group*GROUP_SIZE+threadIdx.x, so this kernel
 * appears to expect a launch with GROUP_SIZE threads per block - confirm against the host code.
 *
 * @param periodicBoxSize         periodic box dimensions
 * @param invPeriodicBoxSize      multiplied with distances to count box lengths; presumably the
 *                                reciprocal of periodicBoxSize
 * @param blockCenter             bounding box centers, indexed by block
 * @param blockBoundingBox        bounding box half-extents, indexed by block
 * @param interactionCount        global counter of stored tiles
 * @param interactingTiles        output list of tiles
 * @param interactingAtoms        output list of atom indices for the stored tiles
 * @param posq                    atom positions
 * @param maxTiles                capacity of the output lists, in tiles
 * @param startBlockIndex         first position in the sorted list to process
 * @param numBlocks               number of positions in the sorted list to process
 * @param sortedBlocks            sorted list; the .y component of each entry is a block index
 * @param sortedBlockCenter       bounding box centers, in sorted order
 * @param sortedBlockBoundingBox  bounding box half-extents, in sorted order
 * @param exclusionIndices        blocks excluded from pairing, concatenated for all blocks
 * @param exclusionRowIndices     for block x, its exclusions occupy
 *                                exclusionIndices[exclusionRowIndices[x] .. exclusionRowIndices[x+1])
 * @param oldPositions            output: copy of posq taken at the end of the kernel
 * @param rebuildNeighborList     element 0 is nonzero if the neighbor list must be rebuilt
 */
extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, const real4* __restrict__ blockCenter,
        const real4* __restrict__ blockBoundingBox, unsigned int* __restrict__ interactionCount, ushort2* __restrict__ interactingTiles,
        unsigned int* __restrict__ interactingAtoms, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int startBlockIndex,
        unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter, const real4* __restrict__ sortedBlockBoundingBox,
        const unsigned int* __restrict__ exclusionIndices, const unsigned int* __restrict__ exclusionRowIndices, real4* __restrict__ oldPositions,
        const int* __restrict__ rebuildNeighborList) {
    __shared__ unsigned short buffer[BUFFER_SIZE];
    __shared__ short sum[BUFFER_SIZE];
    __shared__ ushort2 temp[BUFFER_SIZE];
    __shared__ int atoms[BUFFER_SIZE+TILE_SIZE];
    __shared__ real3 posBuffer[TILE_SIZE];
    __shared__ int exclusionsForX[MAX_EXCLUSIONS];
    __shared__ int bufferFull;
    __shared__ int globalIndex;
    __shared__ int numAtoms;
    
    if (rebuildNeighborList[0] == 0)
        return; // The neighbor list doesn't need to be rebuilt.
    
    // Number of candidates this thread has placed in its column of the buffer.
    int valuesInBuffer = 0;
    if (threadIdx.x == 0)
        bufferFull = false;
    for (int i = 0; i < BUFFER_GROUPS; ++i)
        buffer[i*GROUP_SIZE+threadIdx.x] = INVALID;
    __syncthreads();
    
    // Loop over blocks sorted by size.
    
    for (int i = startBlockIndex+blockIdx.x; i < startBlockIndex+numBlocks; i += gridDim.x) {
        // Start each block x with no staged atoms.
        if (threadIdx.x == blockDim.x-1)
            numAtoms = 0;
        real2 sortedKey = sortedBlocks[i];
        unsigned short x = (unsigned short) sortedKey.y;
        real4 blockCenterX = blockCenter[x];
        real4 blockSizeX = blockBoundingBox[x];

        // Load exclusion data for block x.
        
        const int exclusionStart = exclusionRowIndices[x];
        const int exclusionEnd = exclusionRowIndices[x+1];
        const int numExclusions = exclusionEnd-exclusionStart;
        for (int j = threadIdx.x; j < numExclusions; j += blockDim.x)
            exclusionsForX[j] = exclusionIndices[exclusionStart+j];
        __syncthreads();
        
        // Compare it to other blocks after this one in sorted order.
        
        for (int base = i+1; base < NUM_BLOCKS; base += blockDim.x) {
            int j = base+threadIdx.x;
            // Threads past the end of the list load zeros and are rejected by the j < NUM_BLOCKS test below.
            real2 sortedKey2 = (j < NUM_BLOCKS ? sortedBlocks[j] : make_real2(0));
            real4 blockCenterY = (j < NUM_BLOCKS ? sortedBlockCenter[j] : make_real4(0));
            real4 blockSizeY = (j < NUM_BLOCKS ? sortedBlockBoundingBox[j] : make_real4(0));
            unsigned short y = (unsigned short) sortedKey2.y;
            real4 delta = blockCenterX-blockCenterY;
#ifdef USE_PERIODIC
            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            // Per-axis gap between the two boxes, clamped to zero where they overlap.
            delta.x = max(0.0f, fabs(delta.x)-blockSizeX.x-blockSizeY.x);
            delta.y = max(0.0f, fabs(delta.y)-blockSizeX.y-blockSizeY.y);
            delta.z = max(0.0f, fabs(delta.z)-blockSizeX.z-blockSizeY.z);
            bool hasExclusions = false;
            for (int k = 0; k < numExclusions; k++)
                hasExclusions |= (exclusionsForX[k] == y);
            if (j < NUM_BLOCKS && delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED && !hasExclusions) {
                // Add this tile to the buffer.

                int bufferIndex = valuesInBuffer*GROUP_SIZE+threadIdx.x;
                buffer[bufferIndex] = y;
                valuesInBuffer++;
                // Any thread whose column is full forces a flush for the whole thread block.
                if (!bufferFull && valuesInBuffer == BUFFER_GROUPS)
                    bufferFull = true;
            }
            __syncthreads();
            if (bufferFull) {
                storeInteractionData(x, buffer, sum, temp, atoms, numAtoms, globalIndex, interactionCount, interactingTiles, interactingAtoms, periodicBoxSize, invPeriodicBoxSize, posq, posBuffer, blockCenterX, blockSizeX, maxTiles, false);
                valuesInBuffer = 0;
                if (threadIdx.x == 0)
                    bufferFull = false;
                __syncthreads();
            }
        }
        // Flush whatever is left for block x, including a partially filled final tile.
        storeInteractionData(x, buffer, sum, temp, atoms, numAtoms, globalIndex, interactionCount, interactingTiles, interactingAtoms, periodicBoxSize, invPeriodicBoxSize, posq, posBuffer, blockCenterX, blockSizeX, maxTiles, true);
    }
    
    // Record the positions the neighbor list is based on.
    
    for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_ATOMS; i += blockDim.x*gridDim.x)
        oldPositions[i] = posq[i];
}