// findInteractingBlocks.cu

// Threads per block that findBlocksWithInteractions sizes its shared-memory arrays for
// (GROUP_SIZE/32 per-warp slices, plus per-thread posBuffer/sumBuffer entries indexed by threadIdx.x).
// NOTE(review): presumably the host launches that kernel with exactly this many threads -- confirm.
#define GROUP_SIZE 256
// Capacity, in entries, of each warp's shared-memory neighbor buffer in findBlocksWithInteractions.
#define BUFFER_SIZE 256

/**
 * Find a bounding box for the atoms in each block.
 *
 * Each block of TILE_SIZE consecutive atoms is processed by one thread. Threads advance through the
 * blocks with a stride of blockDim.x*gridDim.x, so the kernel is correct for any grid size.
 *
 * For each block index i this writes:
 *   blockBoundingBox[i] - half-widths (xyz) of the axis-aligned box enclosing the block's atoms
 *   blockCenter[i]      - center (xyz) of that box; w is the distance from that center to the
 *                         farthest atom of the block
 *   sortedBlocks[i]     - (sum of the three half-widths, i), a key/index pair used for sorting blocks
 * Thread 0 of block 0 also clears rebuildNeighborList[0].
 *
 * With USE_PERIODIC, the first atom is wrapped with APPLY_PERIODIC_TO_POS and every later atom is
 * wrapped relative to the current box center (presumably selecting the periodic image nearest it).
 */
extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
        const real4* __restrict__ posq, real4* __restrict__ blockCenter, real4* __restrict__ blockBoundingBox, int* __restrict__ rebuildNeighborList,
        real2* __restrict__ sortedBlocks) {
    int index = blockIdx.x*blockDim.x+threadIdx.x;
    int base = index*TILE_SIZE;
    while (base < numAtoms) {
        // Grow the axis-aligned box one atom at a time.

        real4 pos = posq[base];
#ifdef USE_PERIODIC
        APPLY_PERIODIC_TO_POS(pos)
#endif
        real4 minPos = pos;
        real4 maxPos = pos;
        int last = min(base+TILE_SIZE, numAtoms);
        for (int i = base+1; i < last; i++) {
            pos = posq[i];
#ifdef USE_PERIODIC
            real4 center = 0.5f*(maxPos+minPos);
            APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)
#endif
            minPos = make_real4(min(minPos.x,pos.x), min(minPos.y,pos.y), min(minPos.z,pos.z), 0);
            maxPos = make_real4(max(maxPos.x,pos.x), max(maxPos.y,pos.y), max(maxPos.z,pos.z), 0);
        }
        real4 blockSize = 0.5f*(maxPos-minPos);
        real4 center = 0.5f*(maxPos+minPos);

        // center.w accumulates the largest squared distance from the center to any atom in the block,
        // and is then turned into that distance.

        center.w = 0;
        for (int i = base; i < last; i++) {
            real4 delta = posq[i]-center;
#ifdef USE_PERIODIC
            APPLY_PERIODIC_TO_DELTA(delta)
#endif
            center.w = max(center.w, delta.x*delta.x+delta.y*delta.y+delta.z*delta.z);
        }
        center.w = sqrt(center.w);
        blockBoundingBox[index] = blockSize;
        blockCenter[index] = center;
        sortedBlocks[index] = make_real2(blockSize.x+blockSize.y+blockSize.z, index);
        index += blockDim.x*gridDim.x;
        base = index*TILE_SIZE;
    }
    if (blockIdx.x == 0 && threadIdx.x == 0)
        rebuildNeighborList[0] = 0;
}

/**
 * Sort the data about bounding boxes so it can be accessed more efficiently in the next kernel.
 *
 * Entry i of sortedBlockCenter and sortedBlockBoundingBox receives the data of block (int) sortedBlock[i].y,
 * so both arrays follow the order of sortedBlock.
 *
 * The kernel also decides whether the neighbor list has to be rebuilt. This happens when forceRebuild is set,
 * or when any atom has moved more than PADDING/2 from its entry in oldPositions (presumably because two atoms
 * each moving PADDING/2 toward one another could use up the whole padding). When a rebuild is needed,
 * rebuildNeighborList[0] is set to 1 and interactionCount[0] and interactionCount[1] are both reset to 0.
 */
extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, const real4* __restrict__ blockCenter,
        const real4* __restrict__ blockBoundingBox, real4* __restrict__ sortedBlockCenter,
        real4* __restrict__ sortedBlockBoundingBox, const real4* __restrict__ posq, const real4* __restrict__ oldPositions,
        unsigned int* __restrict__ interactionCount, int* __restrict__ rebuildNeighborList, bool forceRebuild) {
    const int stride = blockDim.x*gridDim.x;
    for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_BLOCKS; i += stride) {
        int index = (int) sortedBlock[i].y;
        sortedBlockCenter[i] = blockCenter[index];
        sortedBlockBoundingBox[i] = blockBoundingBox[index];
    }

    // Also check whether any atom has moved enough so that we really need to rebuild the neighbor list.
    // After this thread finds one such atom (or when forceRebuild is set), it has nothing more to learn
    // from its remaining atoms, so it stops scanning.

    bool rebuild = forceRebuild;
    for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_ATOMS && !rebuild; i += stride) {
        real4 delta = oldPositions[i]-posq[i];
        if (delta.x*delta.x + delta.y*delta.y + delta.z*delta.z > 0.25f*PADDING*PADDING)
            rebuild = true;
    }
    if (rebuild) {
        rebuildNeighborList[0] = 1;
        interactionCount[0] = 0;
        interactionCount[1] = 0;
    }
}

/**
 * Move entries of a warp's neighbor buffer that interact with only a few atoms of block x into the
 * global single-pair list. Then compact the remaining entries to the front of the buffer.
 *
 * All 32 lanes of a warp must call this together with identical arguments. It performs a warp-wide
 * prefix sum and ballot, and relies on warp-synchronous execution.
 *
 * x               - index of the atom block whose neighbors are being processed
 * atoms           - per-warp buffer of neighbor atom indices (length entries); compacted in place
 * flags           - per-warp buffer of bit masks; bit j set means the entry interacts with atom
 *                   x*TILE_SIZE+j; compacted in place alongside atoms
 * length          - number of valid entries in atoms/flags
 * maxSinglePairs  - number of entries singlePairs can hold (pairs past it are not written)
 * singlePairCount - global counter; atomically increased by the number of pairs this call claims,
 *                   including any that did not fit
 * singlePairs     - output list of (neighbor atom, block-x atom) pairs
 * sumBuffer       - 32 ints of shared scratch memory private to this warp
 * pairStartIndex  - per-warp shared slot used to broadcast lane 0's atomicAdd result
 *
 * Returns the number of entries left in atoms/flags, i.e. those with more than MAX_BITS_FOR_PAIRS bits set.
 */
__device__ int saveSinglePairs(int x, int* atoms, int* flags, int length, unsigned int maxSinglePairs, unsigned int* singlePairCount, int2* singlePairs, int* sumBuffer, volatile int& pairStartIndex) {
    const int indexInWarp = threadIdx.x%32;

    // Record interactions that should be computed as single pairs rather than in blocks. First, each
    // lane counts the pairs contributed by the entries it owns (indexInWarp, indexInWarp+32, ...).

    int sum = 0;
    for (int i = indexInWarp; i < length; i += 32) {
        int count = __popc(flags[i]);
        sum += (count <= MAX_BITS_FOR_PAIRS ? count : 0);
    }

    // Inclusive prefix sum over the 32 lanes. Values written by other lanes are read back from shared
    // memory without a barrier (warp-synchronous code). Accessing them through a volatile pointer keeps
    // the compiler from caching them in registers.

    volatile int* scan = sumBuffer;
    scan[indexInWarp] = sum;
    for (int step = 1; step < 32; step *= 2) {
        int add = (indexInWarp >= step ? scan[indexInWarp-step] : 0);
        scan[indexInWarp] += add;
    }
    int pairsToStore = scan[31];
    if (indexInWarp == 0)
        pairStartIndex = atomicAdd(singlePairCount, pairsToStore);
    int pairIndex = pairStartIndex + (indexInWarp > 0 ? scan[indexInWarp-1] : 0);
    for (int i = indexInWarp; i < length; i += 32) {
        unsigned int f = (unsigned int) flags[i];
        int count = __popc(f);

        // This entry fills slots pairIndex .. pairIndex+count-1, so it fits whenever
        // pairIndex+count <= maxSinglePairs (the exact-fit case included).

        if (count <= MAX_BITS_FOR_PAIRS && pairIndex+count <= maxSinglePairs) {
            while (f != 0) {
                singlePairs[pairIndex] = make_int2(atoms[i], x*TILE_SIZE+__ffs((int) f)-1);
                f &= f-1; // Clear the lowest set bit (unsigned, so bit 31 cannot overflow).
                pairIndex++;
            }
        }
    }

    // Compact the remaining interactions to the front of the buffer.

    const unsigned int warpMask = (1u<<indexInWarp)-1; // Bits of all lanes below this one.
    int numCompacted = 0;
    for (int start = 0; start < length; start += 32) {
        int i = start+indexInWarp;
        // Lanes past the end of the buffer contribute nothing, so they do not read their stale slots.
        bool inRange = (i < length);
        int atom = (inRange ? atoms[i] : 0);
        int flag = (inRange ? flags[i] : 0);
        bool include = (inRange && __popc(flag) > MAX_BITS_FOR_PAIRS);
        int includeFlags = __ballot(include);
        if (include) {
            int index = numCompacted+__popc(includeFlags&warpMask);
            atoms[index] = atom;
            flags[index] = flag;
        }
        numCompacted += __popc(includeFlags);
    }
    return numCompacted;
}

/**
 * Compare the bounding boxes for each pair of atom blocks (TILE_SIZE = 32 atoms each), forming a tile. If the two
 * atom blocks are sufficiently far apart, they are treated as non-interacting. The algorithm has two stages.
 *
 * STAGE 1:
 *
 * A coarse-grained list of the atom blocks that may interact with each atom block is constructed.
 *
 * Each warp first loads some block X of interest. Each thread within the warp then loads
 * a different atom block Y. If Y has exclusions with X, then Y is not processed. If the bounding boxes
 * of the two atom blocks are within the cutoff distance, the two atom blocks are considered to be
 * interacting and Y is examined in stage 2.
 *
 * STAGE 2:
 *
 * A fine-grained list of the individual atoms interacting with each atom block is constructed.
 *
 * The warp loops over the atom blocks Y that were found to (possibly) interact with atom block X. Each thread
 * in the warp loops over the 32 atoms in X and compares their positions to one particular atom from block Y.
 * If it finds one closer than the cutoff distance, the atom is added to the list of atoms interacting with block X.
 * This continues until the buffer fills up, at which point the results are written to global memory.
 * When MAX_BITS_FOR_PAIRS > 0, atoms that interact with at most MAX_BITS_FOR_PAIRS atoms of X are written to
 * singlePairs instead of being packed into tiles.
 *
 * Shared memory is sized for GROUP_SIZE threads per block (posBuffer is indexed by threadIdx.x), so blockDim.x
 * must not exceed GROUP_SIZE.
 *
 * [in] periodicBoxSize        - size of the rectangular periodic box
 * [in] invPeriodicBoxSize     - inverse of the periodic box size
 * [in] periodicBoxVecX/Y/Z    - periodic box vectors (presumably consumed by the APPLY_PERIODIC_* macros)
 * [in,out] interactionCount   - [0]: number of tiles found; [1]: number of single pairs found
 * [out] interactingTiles      - block X of each stored tile
 * [out] interactingAtoms      - TILE_SIZE atom indices per stored tile; unused slots are set to NUM_ATOMS
 * [out] singlePairs           - atom pairs recorded individually rather than as part of a tile
 * [in] posq                   - x,y,z coordinates of each atom and charge q
 * [in] maxTiles               - capacity of interactingTiles in tiles; tiles beyond it are counted but not written
 * [in] maxSinglePairs         - capacity of singlePairs
 * [in] startBlockIndex        - first block to process, used for multi-GPUs
 * [in] numBlocks              - number of blocks, starting at startBlockIndex, to process
 * [in] sortedBlocks           - (key, block index) pairs in sorted order
 * [in] sortedBlockCenter      - block centers, in the same order as sortedBlocks
 * [in] sortedBlockBoundingBox - bounding box half-widths, in the same order as sortedBlocks
 * [in] exclusionIndices       - contiguous list of the blocks that have exclusions with each block
 * [in] exclusionRowIndices    - exclusionRowIndices[x] .. exclusionRowIndices[x+1] delimit the entries of
 *                               exclusionIndices that belong to block x
 *           eg: block 0 has exclusions with blocks 3,5,6
 *               block 1 has exclusions with blocks 3,4
 *               block 2 has exclusions with blocks 1,3,5,6
 *            exclusionRowIndices[0][3][5][9]
 *               exclusionIndices[3][5][6][3][4][1][3][5][6]
 *                         index  0  1  2  3  4  5  6  7  8
 * [out] oldPositions          - the positions this neighbor list was built from;
 *                               used to decide when to rebuild the neighbor list
 * [in] rebuildNeighborList    - if rebuildNeighborList[0] is 0 the kernel returns immediately
 */
extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
        unsigned int* __restrict__ interactionCount, int* __restrict__ interactingTiles, unsigned int* __restrict__ interactingAtoms,
        int2* __restrict__ singlePairs, const real4* __restrict__ posq, unsigned int maxTiles, unsigned int maxSinglePairs,
        unsigned int startBlockIndex, unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter,
        const real4* __restrict__ sortedBlockBoundingBox, const unsigned int* __restrict__ exclusionIndices, const unsigned int* __restrict__ exclusionRowIndices,
        real4* __restrict__ oldPositions, const int* __restrict__ rebuildNeighborList) {

    if (rebuildNeighborList[0] == 0)
        return; // The neighbor list doesn't need to be rebuilt.

    // NOTE(review): this kernel is written warp-synchronously (legacy mask-less __ballot, and shared memory
    // exchanged between lanes without explicit barriers). Porting it to sm_70+ presumably requires the
    // *_sync intrinsics.

    const int indexInWarp = threadIdx.x%32;
    const int warpStart = threadIdx.x-indexInWarp;
    const int totalWarps = blockDim.x*gridDim.x/32;
    const int warpIndex = (blockIdx.x*blockDim.x+threadIdx.x)/32;
    const unsigned int warpMask = (1u<<indexInWarp)-1; // Bits of all lanes below this one.
    __shared__ int workgroupBuffer[BUFFER_SIZE*(GROUP_SIZE/32)];
    __shared__ int workgroupFlagsBuffer[BUFFER_SIZE*(GROUP_SIZE/32)];
    __shared__ int warpExclusions[MAX_EXCLUSIONS*(GROUP_SIZE/32)];
    __shared__ real3 posBuffer[GROUP_SIZE];
    __shared__ volatile int workgroupTileIndex[GROUP_SIZE/32];
    __shared__ int sumBuffer[GROUP_SIZE];
    __shared__ int worksgroupPairStartIndex[GROUP_SIZE/32];

    // Per-warp slices of the shared arrays.
    int* buffer = workgroupBuffer+BUFFER_SIZE*(warpStart/32);
    int* flagsBuffer = workgroupFlagsBuffer+BUFFER_SIZE*(warpStart/32);
    int* exclusionsForX = warpExclusions+MAX_EXCLUSIONS*(warpStart/32);
    volatile int& tileStartIndex = workgroupTileIndex[warpStart/32];
    volatile int& pairStartIndex = worksgroupPairStartIndex[warpStart/32];

    // Loop over blocks. Each warp walks its own sequence of blocks, so the warps of one thread block
    // may run different numbers of iterations.

    for (int block1 = startBlockIndex+warpIndex; block1 < startBlockIndex+numBlocks; block1 += totalWarps) {
        // Load data for this block.  Note that all threads in a warp are processing the same block.

        real2 sortedKey = sortedBlocks[block1];
        int x = (int) sortedKey.y;
        real4 blockCenterX = sortedBlockCenter[block1];
        real4 blockSizeX = sortedBlockBoundingBox[block1];
        int neighborsInBuffer = 0;
        real3 pos1 = trimTo3(posq[x*TILE_SIZE+indexInWarp]);
#ifdef USE_PERIODIC
        const bool singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= PADDED_CUTOFF &&
                                         0.5f*periodicBoxSize.y-blockSizeX.y >= PADDED_CUTOFF &&
                                         0.5f*periodicBoxSize.z-blockSizeX.z >= PADDED_CUTOFF);
        if (singlePeriodicCopy) {
            // The block is small enough relative to the box that we can just translate all the atoms into
            // a single periodic box, then skip having to apply periodic boundary conditions later.

            APPLY_PERIODIC_TO_POS_WITH_CENTER(pos1, blockCenterX)
        }
#endif
        posBuffer[threadIdx.x] = pos1;

        // Load exclusion data for block x.

        const int exclusionStart = exclusionRowIndices[x];
        const int exclusionEnd = exclusionRowIndices[x+1];
        const int numExclusions = exclusionEnd-exclusionStart;
        for (int j = indexInWarp; j < numExclusions; j += 32)
            exclusionsForX[j] = exclusionIndices[exclusionStart+j];

        // exclusionsForX and this warp's part of posBuffer are read only by this warp, so a warp-level
        // barrier is enough. A block-wide __syncthreads() is not safe here: warps of one block may execute
        // different numbers of block1 iterations and would reach it a different number of times.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        __syncwarp();
#else
        __threadfence_block();
#endif

        // Loop over atom blocks to search for neighbors.  The threads in a warp compare block1 against 32
        // other blocks in parallel.

        for (int block2Base = block1+1; block2Base < NUM_BLOCKS; block2Base += 32) {
            int block2 = block2Base+indexInWarp;
            bool includeBlock2 = (block2 < NUM_BLOCKS);
            if (includeBlock2) {
                real4 blockCenterY = sortedBlockCenter[block2];
                real4 blockSizeY = sortedBlockBoundingBox[block2];
                real4 blockDelta = blockCenterX-blockCenterY;
#ifdef USE_PERIODIC
                APPLY_PERIODIC_TO_DELTA(blockDelta)
#endif
                includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < (PADDED_CUTOFF+blockCenterX.w+blockCenterY.w)*(PADDED_CUTOFF+blockCenterX.w+blockCenterY.w));
                blockDelta.x = max(0.0f, fabs(blockDelta.x)-blockSizeX.x-blockSizeY.x);
                blockDelta.y = max(0.0f, fabs(blockDelta.y)-blockSizeX.y-blockSizeY.y);
                blockDelta.z = max(0.0f, fabs(blockDelta.z)-blockSizeX.z-blockSizeY.z);
                includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < PADDED_CUTOFF_SQUARED);
#ifdef TRICLINIC
                // The calculation to find the nearest periodic copy is only guaranteed to work if the nearest copy is less than half a box width away.
                // If there's any possibility we might have missed it, do a detailed check.

                if (periodicBoxSize.z/2-blockSizeX.z-blockSizeY.z < PADDED_CUTOFF || periodicBoxSize.y/2-blockSizeX.y-blockSizeY.y < PADDED_CUTOFF)
                    includeBlock2 = true;
#endif
                if (includeBlock2) {
                    // Block indices are kept as int: a 16-bit type would wrap once there are more than 65535 blocks.
                    int y = (int) sortedBlocks[block2].y;
                    for (int k = 0; k < numExclusions; k++)
                        includeBlock2 &= (exclusionsForX[k] != y);
                }
            }

            // Loop over any blocks we identified as potentially containing neighbors.

            int includeBlockFlags = __ballot(includeBlock2);
            while (includeBlockFlags != 0) {
                int i = __ffs(includeBlockFlags)-1;
                includeBlockFlags &= includeBlockFlags-1;
                int y = (int) sortedBlocks[block2Base+i].y;

                // Check each atom in block Y for interactions.

                int atom2 = y*TILE_SIZE+indexInWarp;
                real3 pos2 = trimTo3(posq[atom2]);
#ifdef USE_PERIODIC
                if (singlePeriodicCopy) {
                    APPLY_PERIODIC_TO_POS_WITH_CENTER(pos2, blockCenterX)
                }
#endif
                real4 blockCenterY = sortedBlockCenter[block2Base+i];
                real3 atomDelta = posBuffer[warpStart+indexInWarp]-trimTo3(blockCenterY);
#ifdef USE_PERIODIC
                APPLY_PERIODIC_TO_DELTA(atomDelta)
#endif
                int atomFlags = ballot(atomDelta.x*atomDelta.x+atomDelta.y*atomDelta.y+atomDelta.z*atomDelta.z < (PADDED_CUTOFF+blockCenterY.w)*(PADDED_CUTOFF+blockCenterY.w));
                int interacts = 0;
                if (atom2 < NUM_ATOMS && atomFlags != 0) {
                    int first = __ffs(atomFlags)-1;
                    int last = 32-__clz(atomFlags);
#ifdef USE_PERIODIC
                    if (!singlePeriodicCopy) {
                        for (int j = first; j < last; j++) {
                            real3 delta = pos2-posBuffer[warpStart+j];
                            APPLY_PERIODIC_TO_DELTA(delta)
                            interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? 1<<j : 0);
                        }
                    }
                    else {
#endif
                        for (int j = first; j < last; j++) {
                            real3 delta = pos2-posBuffer[warpStart+j];
                            interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? 1<<j : 0);
                        }
#ifdef USE_PERIODIC
                    }
#endif
                }

                // Add any interacting atoms to the buffer.

                int includeAtomFlags = __ballot(interacts);
                if (interacts) {
                    int index = neighborsInBuffer+__popc(includeAtomFlags&warpMask);
                    buffer[index] = atom2;
                    flagsBuffer[index] = interacts;
                }
                neighborsInBuffer += __popc(includeAtomFlags);
                if (neighborsInBuffer > BUFFER_SIZE-TILE_SIZE) {
                    // Store the new tiles to memory.

#if MAX_BITS_FOR_PAIRS > 0
                    neighborsInBuffer = saveSinglePairs(x, buffer, flagsBuffer, neighborsInBuffer, maxSinglePairs, &interactionCount[1], singlePairs, sumBuffer+warpStart, pairStartIndex);
#endif
                    int tilesToStore = neighborsInBuffer/TILE_SIZE;
                    if (tilesToStore > 0) {
                        if (indexInWarp == 0)
                            tileStartIndex = atomicAdd(&interactionCount[0], tilesToStore);
                        int newTileStartIndex = tileStartIndex;
                        if (newTileStartIndex+tilesToStore <= maxTiles) {
                            if (indexInWarp < tilesToStore)
                                interactingTiles[newTileStartIndex+indexInWarp] = x;
                            for (int j = 0; j < tilesToStore; j++)
                                interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = buffer[indexInWarp+j*TILE_SIZE];
                        }

                        // Move the leftover (< TILE_SIZE) entries to the front of the buffer. Their flags
                        // have to move with them, or a later saveSinglePairs() call would classify these
                        // atoms using flags that belong to an already-stored entry. Only lanes that own a
                        // leftover entry copy, so no lane reads past the end of the filled region.

                        int remaining = neighborsInBuffer-TILE_SIZE*tilesToStore;
                        if (indexInWarp < remaining) {
                            buffer[indexInWarp] = buffer[indexInWarp+TILE_SIZE*tilesToStore];
#if MAX_BITS_FOR_PAIRS > 0
                            flagsBuffer[indexInWarp] = flagsBuffer[indexInWarp+TILE_SIZE*tilesToStore];
#endif
                        }
                        neighborsInBuffer = remaining;
                    }
                }
            }
        }

        // If we have a partially filled buffer, store it to memory.

#if MAX_BITS_FOR_PAIRS > 0
        if (neighborsInBuffer > 32)
            neighborsInBuffer = saveSinglePairs(x, buffer, flagsBuffer, neighborsInBuffer, maxSinglePairs, &interactionCount[1], singlePairs, sumBuffer+warpStart, pairStartIndex);
#endif
        if (neighborsInBuffer > 0) {
            int tilesToStore = (neighborsInBuffer+TILE_SIZE-1)/TILE_SIZE;
            if (indexInWarp == 0)
                tileStartIndex = atomicAdd(&interactionCount[0], tilesToStore);
            int newTileStartIndex = tileStartIndex;
            if (newTileStartIndex+tilesToStore <= maxTiles) {
                if (indexInWarp < tilesToStore)
                    interactingTiles[newTileStartIndex+indexInWarp] = x;
                // Slots past the end of the buffer are padded with NUM_ATOMS.
                for (int j = 0; j < tilesToStore; j++)
                    interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = (indexInWarp+j*TILE_SIZE < neighborsInBuffer ? buffer[indexInWarp+j*TILE_SIZE] : NUM_ATOMS);
            }
        }
    }

    // Record the positions the neighbor list is based on.

    for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_ATOMS; i += blockDim.x*gridDim.x)
        oldPositions[i] = posq[i];
}