// Needed for atom_add() on the global interactionCount used to reserve output tiles.
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
// Presumably needed for the per-component (16-bit) stores into the __local short2 atomCountBuffer.
#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable

// Capacity, in atom indices, of the per-warp neighbor buffer in findBlocksWithInteractions.
// Full tiles are flushed once more than BUFFER_SIZE-TILE_SIZE atoms are buffered.
#define BUFFER_SIZE 256

/**
 * Find a bounding box for the atoms in each block.
 */
8
9
10
__kernel void findBlockBounds(int numAtoms, real4 periodicBoxSize, real4 invPeriodicBoxSize, __global const real4* restrict posq,
        __global real4* restrict blockCenter, __global real4* restrict blockBoundingBox, __global int* restrict rebuildNeighborList,
        __global real2* restrict sortedBlocks) {
11
    int index = get_global_id(0);
12
    int base = index*TILE_SIZE;
13
    while (base < numAtoms) {
14
        real4 pos = posq[base];
15
#ifdef USE_PERIODIC
16
        pos.xyz -= floor(pos.xyz*invPeriodicBoxSize.xyz)*periodicBoxSize.xyz;
17
#endif
18
19
        real4 minPos = pos;
        real4 maxPos = pos;
20
        int last = min(base+TILE_SIZE, numAtoms);
21
22
23
        for (int i = base+1; i < last; i++) {
            pos = posq[i];
#ifdef USE_PERIODIC
24
25
            real4 center = 0.5f*(maxPos+minPos);
            pos.xyz -= floor((pos.xyz-center.xyz)*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
26
27
28
29
#endif
            minPos = min(minPos, pos);
            maxPos = max(maxPos, pos);
        }
30
31
        real4 blockSize = 0.5f*(maxPos-minPos);
        blockBoundingBox[index] = blockSize;
32
        blockCenter[index] = 0.5f*(maxPos+minPos);
33
        sortedBlocks[index] = (real2) (blockSize.x+blockSize.y+blockSize.z, index);
34
        index += get_global_size(0);
35
        base = index*TILE_SIZE;
36
    }
37
    if (get_global_id(0) == 0)
38
        rebuildNeighborList[0] = 0;
39
40
41
}

/**
42
 * Sort the data about bounding boxes so it can be accessed more efficiently in the next kernel.
43
 */
44
45
46
47
48
49
50
51
52
53
54
__kernel void sortBoxData(__global const real2* restrict sortedBlock, __global const real4* restrict blockCenter,
        __global const real4* restrict blockBoundingBox, __global real4* restrict sortedBlockCenter,
        __global real4* restrict sortedBlockBoundingBox, __global const real4* restrict posq, __global const real4* restrict oldPositions,
        __global unsigned int* restrict interactionCount, __global int* restrict rebuildNeighborList) {
    for (int i = get_global_id(0); i < NUM_BLOCKS; i += get_global_size(0)) {
        int index = (int) sortedBlock[i].y;
        sortedBlockCenter[i] = blockCenter[index];
        sortedBlockBoundingBox[i] = blockBoundingBox[index];
    }
    
    // Also check whether any atom has moved enough so that we really need to rebuild the neighbor list.
55

56
57
58
59
60
61
62
63
64
65
66
67
    bool rebuild = false;
    for (int i = get_global_id(0); i < NUM_ATOMS; i += get_global_size(0)) {
        real4 delta = oldPositions[i]-posq[i];
        if (delta.x*delta.x + delta.y*delta.y + delta.z*delta.z > 0.25f*PADDING*PADDING)
            rebuild = true;
    }
    if (rebuild) {
        rebuildNeighborList[0] = 1;
        interactionCount[0] = 0;
    }
}

68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
__kernel void findBlocksWithInteractions(real4 periodicBoxSize, real4 invPeriodicBoxSize, __global unsigned int* restrict interactionCount,
        __global int* restrict interactingTiles, __global unsigned int* restrict interactingAtoms, __global const real4* restrict posq, unsigned int maxTiles, unsigned int startBlockIndex,
        unsigned int numBlocks, __global real2* restrict sortedBlocks, __global const real4* restrict sortedBlockCenter, __global const real4* restrict sortedBlockBoundingBox,
        __global const unsigned int* restrict exclusionIndices, __global const unsigned int* restrict exclusionRowIndices, __global real4* restrict oldPositions,
        __global const int* restrict rebuildNeighborList) {

    if (rebuildNeighborList[0] == 0)
        return; // The neighbor list doesn't need to be rebuilt.

    const int indexInWarp = get_local_id(0)%32;
    const int warpStart = get_local_id(0)-indexInWarp;
    const int totalWarps = get_global_size(0)/32;
    const int warpIndex = get_global_id(0)/32;
    const int warpMask = (1<<indexInWarp)-1;
    __local int workgroupBuffer[BUFFER_SIZE*(GROUP_SIZE/32)];
    __local int warpExclusions[MAX_EXCLUSIONS*(GROUP_SIZE/32)];
    __local real3 posBuffer[GROUP_SIZE];
    __local volatile int workgroupTileIndex[GROUP_SIZE/32];
    __local bool includeBlockFlags[GROUP_SIZE];
    __local short2 atomCountBuffer[GROUP_SIZE];
    __local int* buffer = workgroupBuffer+BUFFER_SIZE*(warpStart/32);
    __local int* exclusionsForX = warpExclusions+MAX_EXCLUSIONS*(warpStart/32);
    __local volatile int* tileStartIndex = workgroupTileIndex+(warpStart/32);

    // Loop over blocks.
    
    for (int block1 = startBlockIndex+warpIndex; block1 < startBlockIndex+numBlocks; block1 += totalWarps) {
        // Load data for this block.  Note that all threads in a warp are processing the same block.
        
        real2 sortedKey = sortedBlocks[block1];
        int x = (int) sortedKey.y;
        real4 blockCenterX = sortedBlockCenter[block1];
        real4 blockSizeX = sortedBlockBoundingBox[block1];
        int neighborsInBuffer = 0;
        real3 pos1 = posq[x*TILE_SIZE+indexInWarp].xyz;
#ifdef USE_PERIODIC
        const bool singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= PADDED_CUTOFF &&
                                         0.5f*periodicBoxSize.y-blockSizeX.y >= PADDED_CUTOFF &&
                                         0.5f*periodicBoxSize.z-blockSizeX.z >= PADDED_CUTOFF);
        if (singlePeriodicCopy) {
            // The box is small enough that we can just translate all the atoms into a single periodic
            // box, then skip having to apply periodic boundary conditions later.
            
            pos1.xyz -= floor((pos1.xyz-blockCenterX.xyz)*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
        }
#endif
        posBuffer[get_local_id(0)] = pos1;

        // Load exclusion data for block x.
        
        const int exclusionStart = exclusionRowIndices[x];
        const int exclusionEnd = exclusionRowIndices[x+1];
        const int numExclusions = exclusionEnd-exclusionStart;
        for (int j = indexInWarp; j < numExclusions; j += 32)
            exclusionsForX[j] = exclusionIndices[exclusionStart+j];
        if (MAX_EXCLUSIONS > 32)
            barrier(CLK_LOCAL_MEM_FENCE);
        else
            SYNC_WARPS;
        
        // Loop over atom blocks to search for neighbors.  The threads in a warp compare block1 against 32
        // other blocks in parallel.

        for (int block2Base = block1+1; block2Base < NUM_BLOCKS; block2Base += 32) {
            int block2 = block2Base+indexInWarp;
            bool includeBlock2 = (block2 < NUM_BLOCKS);
            if (includeBlock2) {
                real4 blockCenterY = (block2 < NUM_BLOCKS ? sortedBlockCenter[block2] : (real4) (0));
                real4 blockSizeY = (block2 < NUM_BLOCKS ? sortedBlockBoundingBox[block2] : (real4) (0));
                real4 blockDelta = blockCenterX-blockCenterY;
#ifdef USE_PERIODIC
                blockDelta.xyz -= floor(blockDelta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
#endif
                blockDelta.x = max(0.0f, fabs(blockDelta.x)-blockSizeX.x-blockSizeY.x);
                blockDelta.y = max(0.0f, fabs(blockDelta.y)-blockSizeX.y-blockSizeY.y);
                blockDelta.z = max(0.0f, fabs(blockDelta.z)-blockSizeX.z-blockSizeY.z);
                includeBlock2 &= (blockDelta.x*blockDelta.x+blockDelta.y*blockDelta.y+blockDelta.z*blockDelta.z < PADDED_CUTOFF_SQUARED);
                if (includeBlock2) {
                    unsigned short y = (unsigned short) sortedBlocks[block2].y;
                    for (int k = 0; k < numExclusions; k++)
                        includeBlock2 &= (exclusionsForX[k] != y);
                }
            }
            
            // Loop over any blocks we identified as potentially containing neighbors.
            
            includeBlockFlags[get_local_id(0)] = includeBlock2;
            SYNC_WARPS;
            for (int i = 0; i < TILE_SIZE; i++) {
157
158
159
                while (i < TILE_SIZE && !includeBlockFlags[warpStart+i])
                    i++;
                if (i < TILE_SIZE) {
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
                    unsigned short y = (unsigned short) sortedBlocks[block2Base+i].y;

                    // Check each atom in block Y for interactions.

                    int start = y*TILE_SIZE;
                    int atom2 = start+indexInWarp;
                    real3 pos2 = posq[atom2].xyz;
#ifdef USE_PERIODIC
                    if (singlePeriodicCopy)
                        pos2.xyz -= floor((pos2.xyz-blockCenterX.xyz)*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
#endif
                    bool interacts = false;
                    if (atom2 < NUM_ATOMS) {
#ifdef USE_PERIODIC
                        if (!singlePeriodicCopy) {
                            for (int j = 0; j < TILE_SIZE; j++) {
                                real3 delta = pos2-posBuffer[warpStart+j];
                                delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
                                interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED);
                            }
                        }
                        else {
#endif
                            for (int j = 0; j < TILE_SIZE; j++) {
                                real3 delta = pos2-posBuffer[warpStart+j];
                                interacts |= (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < PADDED_CUTOFF_SQUARED);
                            }
#ifdef USE_PERIODIC
                        }
#endif
                    }
                    
                    // Do a prefix sum to compact the list of atoms.

                    atomCountBuffer[get_local_id(0)].x = (interacts ? 1 : 0);
                    SYNC_WARPS;
                    int whichBuffer = 0;
                    for (int offset = 1; offset < TILE_SIZE; offset *= 2) {
                        if (whichBuffer == 0)
peastman's avatar
peastman committed
199
                            atomCountBuffer[get_local_id(0)].y = (indexInWarp < offset ? atomCountBuffer[get_local_id(0)].x : atomCountBuffer[get_local_id(0)].x+atomCountBuffer[get_local_id(0)-offset].x);
200
                        else
peastman's avatar
peastman committed
201
                            atomCountBuffer[get_local_id(0)].x = (indexInWarp < offset ? atomCountBuffer[get_local_id(0)].y : atomCountBuffer[get_local_id(0)].y+atomCountBuffer[get_local_id(0)-offset].y);
202
203
204
205
206
207
208
                        whichBuffer = 1-whichBuffer;
                        SYNC_WARPS;
                    }
                    
                    // Add any interacting atoms to the buffer.

                    if (interacts)
peastman's avatar
peastman committed
209
210
                        buffer[neighborsInBuffer+atomCountBuffer[get_local_id(0)].y-1] = atom2;
                    neighborsInBuffer += atomCountBuffer[warpStart+TILE_SIZE-1].y;
211
212
213
214
215
216
                    if (neighborsInBuffer > BUFFER_SIZE-TILE_SIZE) {
                        // Store the new tiles to memory.

                        int tilesToStore = neighborsInBuffer/TILE_SIZE;
                        if (indexInWarp == 0)
                            *tileStartIndex = atom_add(interactionCount, tilesToStore);
peastman's avatar
peastman committed
217
                        SYNC_WARPS;
218
219
220
221
222
223
224
225
226
                        int newTileStartIndex = *tileStartIndex;
                        if (newTileStartIndex+tilesToStore < maxTiles) {
                            if (indexInWarp < tilesToStore)
                                interactingTiles[newTileStartIndex+indexInWarp] = x;
                            for (int j = 0; j < tilesToStore; j++)
                                interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = buffer[indexInWarp+j*TILE_SIZE];
                        }
                        buffer[indexInWarp] = buffer[indexInWarp+TILE_SIZE*tilesToStore];
                        neighborsInBuffer -= TILE_SIZE*tilesToStore;
peastman's avatar
peastman committed
227
                   }
228
229
230
231
232
233
234
235
236
237
                }
            }
        }
        
        // If we have a partially filled buffer,  store it to memory.
        
        if (neighborsInBuffer > 0) {
            int tilesToStore = (neighborsInBuffer+TILE_SIZE-1)/TILE_SIZE;
            if (indexInWarp == 0)
                *tileStartIndex = atom_add(interactionCount, tilesToStore);
peastman's avatar
peastman committed
238
            SYNC_WARPS;
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
            int newTileStartIndex = *tileStartIndex;
            if (newTileStartIndex+tilesToStore < maxTiles) {
                if (indexInWarp < tilesToStore)
                    interactingTiles[newTileStartIndex+indexInWarp] = x;
                for (int j = 0; j < tilesToStore; j++)
                    interactingAtoms[(newTileStartIndex+j)*TILE_SIZE+indexInWarp] = (indexInWarp+j*TILE_SIZE < neighborsInBuffer ? buffer[indexInWarp+j*TILE_SIZE] : NUM_ATOMS);
            }
        }
    }
    
    // Record the positions the neighbor list is based on.
    
    for (int i = get_global_id(0); i < NUM_ATOMS; i += get_global_size(0))
        oldPositions[i] = posq[i];
}