"example/mnist/mnist_ptq.py" did not exist on "1f6deed697dd19e11ebce619f51c70164da8e95e"
findInteractingBlocks.cu 18.3 KB
Newer Older
1
// Threads per thread block that findBlocksWithInteractions sizes its per-warp shared-memory arrays for
// (GROUP_SIZE/32 warps).  That kernel must not be launched with more threads per block than this.
#define GROUP_SIZE 256
// Number of (atom, interaction-flags) entries each warp can hold in its shared-memory neighbor buffer.
#define BUFFER_SIZE 256
3
4
5
6

/**
 * Find a bounding box for the atoms in each block.
 *
 * Each thread handles one block of TILE_SIZE consecutive atoms at a time, then advances by the total
 * number of launched threads.  For every block index it writes:
 *
 *   blockBoundingBox[index] - half-widths of the axis-aligned box enclosing the block's atoms
 *   blockCenter[index]      - center of that box in .xyz; .w is the radius of a sphere about that center
 *                             which contains every atom of the block (the largest atom-to-center distance)
 *   sortedBlocks[index]     - (sum of the half-widths, index), a sort key for the block
 *
 * Thread 0 of block 0 also clears rebuildNeighborList[0]; sortBoxData() sets it back to 1 when a rebuild
 * is required.
 */
extern "C" __global__ void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
        const real4* __restrict__ posq, real4* __restrict__ blockCenter, real4* __restrict__ blockBoundingBox, int* __restrict__ rebuildNeighborList,
        real2* __restrict__ sortedBlocks) {
    int index = blockIdx.x*blockDim.x+threadIdx.x;
    int base = index*TILE_SIZE;
    while (base < numAtoms) {
        real4 pos = posq[base];
#ifdef USE_PERIODIC
        APPLY_PERIODIC_TO_POS(pos)
#endif
        real4 minPos = pos;
        real4 maxPos = pos;
        int last = min(base+TILE_SIZE, numAtoms);
        for (int i = base+1; i < last; i++) {
            pos = posq[i];
#ifdef USE_PERIODIC
            // Wrap relative to the center of the box built so far (the macro is defined elsewhere;
            // presumably it picks the periodic image nearest that center).
            real4 center = 0.5f*(maxPos+minPos);
            APPLY_PERIODIC_TO_POS_WITH_CENTER(pos, center)
#endif
            minPos = make_real4(min(minPos.x,pos.x), min(minPos.y,pos.y), min(minPos.z,pos.z), 0);
            maxPos = make_real4(max(maxPos.x,pos.x), max(maxPos.y,pos.y), max(maxPos.z,pos.z), 0);
        }
        real4 blockSize = 0.5f*(maxPos-minPos);
        real4 center = 0.5f*(maxPos+minPos);

        // Bounding-sphere radius.  Every atom of the block must be considered, including the first one
        // (base).  The maximum squared distance is accumulated in center.w and then converted to a
        // distance: findBlocksWithInteractions adds .w directly to the cutoff, so it must be a radius.
        center.w = 0;
        for (int i = base; i < last; i++) {
            real4 delta = posq[i]-center;
#ifdef USE_PERIODIC
            APPLY_PERIODIC_TO_DELTA(delta)
#endif
            center.w = max(center.w, delta.x*delta.x+delta.y*delta.y+delta.z*delta.z);
        }
        center.w = sqrt(center.w);
        blockBoundingBox[index] = blockSize;
        blockCenter[index] = center;
        sortedBlocks[index] = make_real2(blockSize.x+blockSize.y+blockSize.z, index);
        index += blockDim.x*gridDim.x;
        base = index*TILE_SIZE;
    }
    if (blockIdx.x == 0 && threadIdx.x == 0)
        rebuildNeighborList[0] = 0;
}

/**
 * Sort the data about bounding boxes so it can be accessed more efficiently in the next kernel.
 *
 * For every sorted position p, sortedBlock[p].y holds the index of the block that belongs there.  The
 * kernel gathers that block's center and bounding box into sortedBlockCenter[p] / sortedBlockBoundingBox[p].
 *
 * It also decides whether the neighbor list must be rebuilt.  A rebuild is needed when forceRebuild is set,
 * or when any atom is more than PADDING/2 away from its position in oldPositions.  In that case it sets
 * rebuildNeighborList[0] = 1 and resets interactionCount[0] and singlePairCount[0] to 0.
 */
extern "C" __global__ void sortBoxData(const real2* __restrict__ sortedBlock, const real4* __restrict__ blockCenter,
        const real4* __restrict__ blockBoundingBox, real4* __restrict__ sortedBlockCenter,
        real4* __restrict__ sortedBlockBoundingBox, const real4* __restrict__ posq, const real4* __restrict__ oldPositions,
        unsigned int* __restrict__ interactionCount, unsigned int* __restrict__ singlePairCount, int* __restrict__ rebuildNeighborList, bool forceRebuild) {
    const int firstThread = threadIdx.x+blockIdx.x*blockDim.x;
    const int threadStride = blockDim.x*gridDim.x;

    // Gather the per-block data into sorted order.
    for (int slot = firstThread; slot < NUM_BLOCKS; slot += threadStride) {
        const int source = (int) sortedBlock[slot].y;
        sortedBlockCenter[slot] = blockCenter[source];
        sortedBlockBoundingBox[slot] = blockBoundingBox[source];
    }

    // Has any atom moved far enough (more than half the padding; compared squared) to require a rebuild?
    bool rebuild = forceRebuild;
    for (int atom = firstThread; atom < NUM_ATOMS; atom += threadStride) {
        const real4 moved = oldPositions[atom]-posq[atom];
        const bool movedTooFar = (moved.x*moved.x + moved.y*moved.y + moved.z*moved.z > 0.25f*PADDING*PADDING);
        rebuild = rebuild || movedTooFar;
    }
    if (!rebuild)
        return;
    rebuildNeighborList[0] = 1;
    interactionCount[0] = 0;
    singlePairCount[0] = 0;
}
77

78
79
80
/**
 * Record interactions that should be computed as single pairs rather than in blocks, then compact the
 * remaining entries to the front of the buffers.
 *
 * Must be called by all 32 lanes of a warp (full-warp masks are used below).  Entry k (k < length) of
 * atoms/flags describes one atom: bit b of flags[k] is set when that atom interacts with atom
 * x*TILE_SIZE+b.  Entries with at most maxBitsForPairs bits set are written to singlePairs as
 * (atoms[k], x*TILE_SIZE+b) pairs and dropped from the buffers.  All other entries are compacted in place.
 *
 * x               - block whose atoms the flag bits refer to
 * atoms, flags    - per-warp buffers of length >= length (modified in place)
 * length          - number of valid entries in atoms/flags
 * maxSinglePairs  - capacity of singlePairs
 * singlePairCount - global counter, incremented by the number of pairs found, even when some could
 *                   not be stored because singlePairs is full
 * singlePairs     - output array of pairs
 * sumBuffer       - 32 ints of per-warp scratch space
 * pairStartIndex  - per-warp scratch used to broadcast lane 0's atomicAdd result
 * Returns the number of entries left in atoms/flags after compaction.
 */
__device__ int saveSinglePairs(int x, int* atoms, int* flags, int length, unsigned int maxSinglePairs, unsigned int* singlePairCount, int2* singlePairs, int* sumBuffer, volatile int& pairStartIndex) {
    const unsigned int fullWarp = 0xffffffff;
    const int indexInWarp = threadIdx.x%32;
    const int maxBitsForPairs = 2;

    // The buffer entries were written by other lanes of the warp; make them visible before reading.
    __syncwarp(fullWarp);

    // Count the pairs this lane will store, then compute an inclusive prefix sum across the warp so each
    // lane knows where its pairs go.  Warp barriers are required: lanes are not implicitly synchronous.
    int sum = 0;
    for (int i = indexInWarp; i < length; i += 32) {
        int count = __popc(flags[i]);
        sum += (count <= maxBitsForPairs ? count : 0);
    }
    sumBuffer[indexInWarp] = sum;
    __syncwarp(fullWarp);
    for (int step = 1; step < 32; step *= 2) {
        int add = (indexInWarp >= step ? sumBuffer[indexInWarp-step] : 0);
        __syncwarp(fullWarp); // every lane reads before any lane overwrites its entry
        sumBuffer[indexInWarp] += add;
        __syncwarp(fullWarp);
    }
    int pairsToStore = sumBuffer[31];
    if (indexInWarp == 0)
        pairStartIndex = atomicAdd(singlePairCount, pairsToStore);
    __syncwarp(fullWarp); // broadcast lane 0's result
    int pairIndex = pairStartIndex + (indexInWarp > 0 ? sumBuffer[indexInWarp-1] : 0);
    for (int i = indexInWarp; i < length; i += 32) {
        int f = flags[i];
        int count = __popc(f);
        if (count <= maxBitsForPairs) {
            // This entry owns slots pairIndex .. pairIndex+count-1, so it fits when pairIndex+count <= maxSinglePairs.
            if (pairIndex+count <= maxSinglePairs) {
                for (int p = 0; f != 0; p++, f &= f-1)
                    singlePairs[pairIndex+p] = make_int2(atoms[i], x*TILE_SIZE+__ffs(f)-1);
            }
            pairIndex += count; // keep this lane's later entries in their own reserved slots
        }
    }

    // Compact the remaining interactions.

    const int warpMask = (1<<indexInWarp)-1;
    int numCompacted = 0;
    for (int start = 0; start < length; start += 32) {
        int i = start+indexInWarp;
        bool inRange = (i < length);
        // Never read past the valid entries (they may lie beyond the end of this warp's buffer).
        int atom = (inRange ? atoms[i] : 0);
        int flag = (inRange ? flags[i] : 0);
        bool include = (inRange && __popc(flag) > maxBitsForPairs);
        int includeFlags = __ballot_sync(fullWarp, include);
        __syncwarp(fullWarp); // all lanes have read their entry before any lane overwrites one
        if (include) {
            int index = numCompacted+__popc(includeFlags&warpMask);
            atoms[index] = atom;
            flags[index] = flag;
        }
        numCompacted += __popc(includeFlags);
    }
    __syncwarp(fullWarp); // publish the compacted entries to the whole warp
    return numCompacted;
}

129
/**
 * Compare the bounding boxes for each pair of atom blocks (TILE_SIZE atoms each), forming a tile.  If the two
 * atom blocks are sufficiently far apart, mark them as non-interacting.  There are two stages in the algorithm.
 *
 * STAGE 1:
 *
 * A coarse-grained list of atom blocks that may interact with each atom block is constructed.
 *
 * Each warp first loads in some block X of interest.  Each thread within the warp then loads
 * in a different atom block Y (only blocks after X in sorted order are examined).  If Y has exclusions
 * with X, then Y is not processed.  If the bounding spheres and the bounding boxes of the two atom blocks
 * are both within the padded cutoff, then the two atom blocks are considered to be interacting.
 *
 * STAGE 2:
 *
 * A fine-grained list of atoms that interact with each atom block is constructed.
 *
 * The warp loops over atom blocks Y that were found to (possibly) interact with atom block X.  A ballot first
 * marks which atoms of X lie within the padded cutoff of Y's bounding sphere.  Each thread then takes one atom
 * of Y and compares its position with the atoms of X in that marked range.  If it finds one closer than the
 * padded cutoff, the atom is added to the warp's buffer, together with a bit mask of the X atoms it interacts
 * with.  When the buffer fills up, atoms interacting with only a few X atoms are written out as single pairs
 * (saveSinglePairs).  The rest are written to global memory as tiles.
 *
 * Launch requirements: threads per block must be a multiple of 32 and must not exceed GROUP_SIZE
 * (the shared-memory arrays are sized by GROUP_SIZE).
 *
 * [in] periodicBoxSize        - size of the rectangular periodic box
 * [in] invPeriodicBoxSize     - inverse of the periodic box
 * [in] periodicBoxVecX/Y/Z    - periodic box vectors; not referenced directly here (presumably consumed by
 *                               the APPLY_PERIODIC_* macros — TODO confirm)
 * [out] interactionCount      - total number of tiles found; may exceed maxTiles, in which case only tiles
 *                               that fit are stored
 * [out] interactingTiles      - for each stored tile, the index of block X
 * [out] interactingAtoms      - TILE_SIZE atom indices per stored tile; unused slots are set to NUM_ATOMS
 * [out] singlePairCount       - total number of single pairs found; may exceed maxSinglePairs
 * [out] singlePairs           - interactions stored as individual atom pairs
 * [in] posq                   - x,y,z coordinates of each atom and charge q
 * [in] maxTiles               - capacity of interactingTiles, in tiles
 * [in] maxSinglePairs         - capacity of singlePairs
 * [in] startBlockIndex        - first sorted block to process as block X (presumably for multi-GPU splits)
 * [in] numBlocks              - number of sorted blocks to process as block X, starting at startBlockIndex
 * [in] sortedBlocks           - (sort key, block index) for each sorted position
 * [in] sortedBlockCenter      - block centers in sorted order; .w is the bounding-sphere radius
 * [in] sortedBlockBoundingBox - bounding-box half-widths in sorted order
 * [in] exclusionIndices       - flat list of the blocks that have exclusions with each block
 * [in] exclusionRowIndices    - exclusionRowIndices[x] .. exclusionRowIndices[x+1]-1 is the range of
 *                               exclusionIndices belonging to block x
 *           e.g.: block 0 has exclusions with blocks 3,5,6
 *                 block 1 has exclusions with blocks 3,4
 *                 block 2 has exclusions with blocks 1,3,5,6
 *              exclusionRowIndices = [0 3 5 9]
 *              exclusionIndices    = [3 5 6 3 4 1 3 5 6]
 *                           index     0 1 2 3 4 5 6 7 8
 * [out] oldPositions          - stores the positions of the atoms this neighbor list was built from;
 *                               used to decide when to rebuild the neighbor list
 * [in] rebuildNeighborList    - if rebuildNeighborList[0] == 0 the kernel returns immediately
 */
extern "C" __global__ void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ,
        unsigned int* __restrict__ interactionCount, int* __restrict__ interactingTiles, unsigned int* __restrict__ interactingAtoms,
        unsigned int* __restrict__ singlePairCount, int2* __restrict__ singlePairs, const real4* __restrict__ posq, unsigned int maxTiles,
        unsigned int maxSinglePairs, unsigned int startBlockIndex, unsigned int numBlocks, real2* __restrict__ sortedBlocks, const real4* __restrict__ sortedBlockCenter,
        const real4* __restrict__ sortedBlockBoundingBox, const unsigned int* __restrict__ exclusionIndices, const unsigned int* __restrict__ exclusionRowIndices,
        real4* __restrict__ oldPositions, const int* __restrict__ rebuildNeighborList) {

    if (rebuildNeighborList[0] == 0)
        return; // The neighbor list doesn't need to be rebuilt.

    // All lanes of a warp follow the same control flow through every warp intrinsic and barrier below
    // (those branch conditions are warp-uniform), so a full-warp mask is valid.
    const unsigned int fullWarp = 0xffffffff;
    const int indexInWarp = threadIdx.x%32;
    const int warpStart = threadIdx.x-indexInWarp;
    const int totalWarps = blockDim.x*gridDim.x/32;
    const int warpIndex = (blockIdx.x*blockDim.x+threadIdx.x)/32;
    const int warpMask = (1<<indexInWarp)-1;
    __shared__ int workgroupBuffer[BUFFER_SIZE*(GROUP_SIZE/32)];
    __shared__ int workgroupFlagsBuffer[BUFFER_SIZE*(GROUP_SIZE/32)];
    __shared__ int warpExclusions[MAX_EXCLUSIONS*(GROUP_SIZE/32)];
    __shared__ real3 posBuffer[GROUP_SIZE];
    __shared__ volatile int workgroupTileIndex[GROUP_SIZE/32];
    __shared__ int sumBuffer[GROUP_SIZE];
    __shared__ int workgroupPairStartIndex[GROUP_SIZE/32];

    // Each warp owns its own slice of every shared array, so only warp-level barriers are ever needed.
    int* buffer = workgroupBuffer+BUFFER_SIZE*(warpStart/32);
    int* flagsBuffer = workgroupFlagsBuffer+BUFFER_SIZE*(warpStart/32);
    int* exclusionsForX = warpExclusions+MAX_EXCLUSIONS*(warpStart/32);
    volatile int& tileStartIndex = workgroupTileIndex[warpStart/32];
    volatile int& pairStartIndex = workgroupPairStartIndex[warpStart/32];

    // Loop over blocks.

    for (int block1 = startBlockIndex+warpIndex; block1 < startBlockIndex+numBlocks; block1 += totalWarps) {
        // Load data for this block.  Note that all threads in a warp are processing the same block.

        real2 sortedKey = sortedBlocks[block1];
        int x = (int) sortedKey.y;
        real4 blockCenterX = sortedBlockCenter[block1];
        real4 blockSizeX = sortedBlockBoundingBox[block1];
        int neighborsInBuffer = 0;
        real3 pos1 = trimTo3(posq[x*TILE_SIZE+indexInWarp]);
#ifdef USE_PERIODIC
        const bool singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= PADDED_CUTOFF &&
                                         0.5f*periodicBoxSize.y-blockSizeX.y >= PADDED_CUTOFF &&
                                         0.5f*periodicBoxSize.z-blockSizeX.z >= PADDED_CUTOFF);
        if (singlePeriodicCopy) {
            // The box is small enough that we can just translate all the atoms into a single periodic
            // box, then skip having to apply periodic boundary conditions later.

            APPLY_PERIODIC_TO_POS_WITH_CENTER(pos1, blockCenterX)
        }
#endif
        posBuffer[threadIdx.x] = pos1;

        // Load exclusion data for block x.

        const int exclusionStart = exclusionRowIndices[x];
        const int exclusionEnd = exclusionRowIndices[x+1];
        const int numExclusions = exclusionEnd-exclusionStart;
        for (int j = indexInWarp; j < numExclusions; j += 32)
            exclusionsForX[j] = exclusionIndices[exclusionStart+j];

        // posBuffer and exclusionsForX are private to this warp, so a warp barrier publishes them.  A
        // __syncthreads() would be invalid here: warps of the same thread block can run different numbers
        // of iterations of this loop.
        __syncwarp(fullWarp);

        // Loop over atom blocks to search for neighbors.  The threads in a warp compare block1 against 32
        // other blocks in parallel.

        for (int block2Base = block1+1; block2Base < NUM_BLOCKS; block2Base += 32) {
            int block2 = block2Base+indexInWarp;
            bool includeBlock2 = (block2 < NUM_BLOCKS);
            if (includeBlock2) {
                real4 blockCenterY = sortedBlockCenter[block2];
                real4 blockSizeY = sortedBlockBoundingBox[block2];
                real4 blockDelta = blockCenterX-blockCenterY;
#ifdef USE_PERIODIC
                APPLY_PERIODIC_TO_DELTA(blockDelta)
#endif
                // Bounding-sphere test (.w holds each block's sphere radius).
                includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < (PADDED_CUTOFF+blockCenterX.w+blockCenterY.w)*(PADDED_CUTOFF+blockCenterX.w+blockCenterY.w));

                // Bounding-box test: per-axis gap between the two boxes.
                blockDelta.x = max(0.0f, fabs(blockDelta.x)-blockSizeX.x-blockSizeY.x);
                blockDelta.y = max(0.0f, fabs(blockDelta.y)-blockSizeX.y-blockSizeY.y);
                blockDelta.z = max(0.0f, fabs(blockDelta.z)-blockSizeX.z-blockSizeY.z);
                includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < PADDED_CUTOFF_SQUARED);
                if (includeBlock2) {
                    // Skip blocks that have exclusions with X.  Use int, not unsigned short, so block
                    // indices above 65535 are not truncated.
                    int y = (int) sortedBlocks[block2].y;
                    for (int k = 0; k < numExclusions; k++)
                        includeBlock2 &= (exclusionsForX[k] != y);
                }
            }

            // Loop over any blocks we identified as potentially containing neighbors.

            int includeBlockFlags = __ballot_sync(fullWarp, includeBlock2);
            while (includeBlockFlags != 0) {
                int i = __ffs(includeBlockFlags)-1;
                includeBlockFlags &= includeBlockFlags-1;
                int y = (int) sortedBlocks[block2Base+i].y;

                // Check each atom in block Y for interactions.

                int atom2 = y*TILE_SIZE+indexInWarp;
                real3 pos2 = trimTo3(posq[atom2]);
#ifdef USE_PERIODIC
                if (singlePeriodicCopy) {
                    APPLY_PERIODIC_TO_POS_WITH_CENTER(pos2, blockCenterX)
                }
#endif
                // Mark the atoms of X that lie within the padded cutoff of Y's bounding sphere.
                real4 blockCenterY = sortedBlockCenter[block2Base+i];
                real3 atomDelta = posBuffer[warpStart+indexInWarp]-trimTo3(blockCenterY);
#ifdef USE_PERIODIC
                APPLY_PERIODIC_TO_DELTA(atomDelta)
#endif
                int atomFlags = __ballot_sync(fullWarp, atomDelta.x*atomDelta.x+atomDelta.y*atomDelta.y+atomDelta.z*atomDelta.z < (PADDED_CUTOFF+blockCenterY.w)*(PADDED_CUTOFF+blockCenterY.w));
                int interacts = 0;
                if (atom2 < NUM_ATOMS && atomFlags != 0) {
                    int first = __ffs(atomFlags)-1;
                    int last = 32-__clz(atomFlags);
#ifdef USE_PERIODIC
                    if (!singlePeriodicCopy) {
                        for (int j = first; j < last; j++) {
                            real3 delta = pos2-posBuffer[warpStart+j];
                            APPLY_PERIODIC_TO_DELTA(delta)
                            interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? 1<<j : 0);
                        }
                    }
                    else {
#endif
                        for (int j = first; j < last; j++) {
                            real3 delta = pos2-posBuffer[warpStart+j];
                            interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED ? 1<<j : 0);
                        }
#ifdef USE_PERIODIC
                    }
#endif
                }

                // Add any interacting atoms to the buffer.

                int includeAtomFlags = __ballot_sync(fullWarp, interacts);
                if (interacts) {
                    int index = neighborsInBuffer+__popc(includeAtomFlags&warpMask);
                    buffer[index] = atom2;
                    flagsBuffer[index] = interacts;
                }
                neighborsInBuffer += __popc(includeAtomFlags);
                if (neighborsInBuffer > BUFFER_SIZE-TILE_SIZE) {
                    // Store the new tiles to memory.

                    __syncwarp(fullWarp); // make every lane's buffer entries visible to the whole warp
                    neighborsInBuffer = saveSinglePairs(x, buffer, flagsBuffer, neighborsInBuffer, maxSinglePairs, singlePairCount, singlePairs, sumBuffer+warpStart, pairStartIndex);
                    int tilesToStore = neighborsInBuffer/TILE_SIZE;
                    if (tilesToStore > 0) {
                        if (indexInWarp == 0)
                            tileStartIndex = atomicAdd(interactionCount, tilesToStore);
                        __syncwarp(fullWarp); // broadcast lane 0's result
                        int newTileStartIndex = tileStartIndex;
                        if (newTileStartIndex+tilesToStore <= maxTiles) {
                            if (indexInWarp < tilesToStore)
                                interactingTiles[newTileStartIndex+indexInWarp] = x;
                            for (int j = 0; j < tilesToStore; j++)
                                interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = buffer[indexInWarp+j*TILE_SIZE];
                        }

                        // Move the leftover entries (fewer than TILE_SIZE) and their flags to the front of the
                        // buffer.  Only lanes with a leftover entry read, so nothing past the buffer is touched.
                        int remaining = neighborsInBuffer-TILE_SIZE*tilesToStore;
                        if (indexInWarp < remaining) {
                            buffer[indexInWarp] = buffer[indexInWarp+TILE_SIZE*tilesToStore];
                            flagsBuffer[indexInWarp] = flagsBuffer[indexInWarp+TILE_SIZE*tilesToStore];
                        }
                        neighborsInBuffer = remaining;
                    }
                }
            }
        }

        // If we have a partially filled buffer, store it to memory.

        __syncwarp(fullWarp); // make every lane's buffer entries visible to the whole warp
        if (neighborsInBuffer > TILE_SIZE)
            neighborsInBuffer = saveSinglePairs(x, buffer, flagsBuffer, neighborsInBuffer, maxSinglePairs, singlePairCount, singlePairs, sumBuffer+warpStart, pairStartIndex);
        if (neighborsInBuffer > 0) {
            int tilesToStore = (neighborsInBuffer+TILE_SIZE-1)/TILE_SIZE;
            if (indexInWarp == 0)
                tileStartIndex = atomicAdd(interactionCount, tilesToStore);
            __syncwarp(fullWarp); // broadcast lane 0's result
            int newTileStartIndex = tileStartIndex;
            if (newTileStartIndex+tilesToStore <= maxTiles) {
                if (indexInWarp < tilesToStore)
                    interactingTiles[newTileStartIndex+indexInWarp] = x;
                for (int j = 0; j < tilesToStore; j++)
                    interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = (indexInWarp+j*TILE_SIZE < neighborsInBuffer ? buffer[indexInWarp+j*TILE_SIZE] : NUM_ATOMS);
            }
        }
    }

    // Record the positions the neighbor list is based on.

    for (int i = threadIdx.x+blockIdx.x*blockDim.x; i < NUM_ATOMS; i += blockDim.x*gridDim.x)
        oldPositions[i] = posq[i];
}