// findInteractingBlocks.cl
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable

// Number of consecutive atoms that make up one block.
#define TILE_SIZE 32

// Stride used by the per-work-group loops in this file.
// NOTE(review): the kernels appear to assume the work-group size equals GROUP_SIZE - confirm against the host code.
#define GROUP_SIZE 64

// Number of candidate tiles a single work-item may queue before the buffer is flushed.
#define BUFFER_GROUPS 4

// Total number of slots in the local tile buffer.  Parenthesized so the macro
// expands safely inside larger expressions (e.g. n/BUFFER_SIZE).
#define BUFFER_SIZE (BUFFER_GROUPS*GROUP_SIZE)

/**
 * Find a bounding box for the atoms in each block.
 *
 * Block b covers atoms [b*TILE_SIZE, min((b+1)*TILE_SIZE, numAtoms)).  Each work-item
 * processes blocks get_global_id(0), get_global_id(0)+get_global_size(0), ... and writes
 * for each one:
 *   blockCenter[b]      = midpoint of the component-wise min and max position
 *   blockBoundingBox[b] = half of the component-wise extent
 * The same arithmetic is applied to all four components of posq, so the w component of
 * both outputs is derived from posq[].w as well.
 *
 * When USE_PERIODIC is defined, the first atom of a block is wrapped by subtracting
 * floor(pos/boxSize)*boxSize, and every other atom of the block is shifted by a whole
 * number of box lengths so that it lies within half a box length of that first atom.
 *
 * Work-item 0 also resets interactionCount[0] to 0.
 *
 * @param numAtoms            number of entries of posq to process
 * @param periodicBoxSize     periodic box dimensions (x, y, z used)
 * @param invPeriodicBoxSize  multiplied with positions in place of dividing by periodicBoxSize
 * @param posq                atom data; x, y, z are used as the position
 * @param blockCenter         output, one entry per block
 * @param blockBoundingBox    output, one entry per block
 * @param interactionCount    counter whose first element is zeroed
 */
__kernel void findBlockBounds(int numAtoms, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global const float4* restrict posq, __global float4* restrict blockCenter, __global float4* restrict blockBoundingBox, __global unsigned int* restrict interactionCount) {
    int index = get_global_id(0);
    int base = index*TILE_SIZE;
    while (base < numAtoms) {
        // Start the bounds from the first atom of the block.

        float4 pos = posq[base];
#ifdef USE_PERIODIC
        pos.x -= floor(pos.x*invPeriodicBoxSize.x)*periodicBoxSize.x;
        pos.y -= floor(pos.y*invPeriodicBoxSize.y)*periodicBoxSize.y;
        pos.z -= floor(pos.z*invPeriodicBoxSize.z)*periodicBoxSize.z;
        float4 firstPoint = pos;
#endif
        float4 minPos = pos;
        float4 maxPos = pos;

        // Grow the bounds with the remaining atoms; the last block may be partial.

        int last = min(base+TILE_SIZE, numAtoms);
        for (int i = base+1; i < last; i++) {
            pos = posq[i];
#ifdef USE_PERIODIC
            // Use the periodic image closest to the first atom so the box stays compact.

            pos.x -= floor((pos.x-firstPoint.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            pos.y -= floor((pos.y-firstPoint.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            pos.z -= floor((pos.z-firstPoint.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            minPos = min(minPos, pos);
            maxPos = max(maxPos, pos);
        }
        blockBoundingBox[index] = 0.5f*(maxPos-minPos);
        blockCenter[index] = 0.5f*(maxPos+minPos);

        // Advance to the next block assigned to this work-item.

        index += get_global_size(0);
        base = index*TILE_SIZE;
    }
    if (get_global_id(0) == 0)
        interactionCount[0] = 0;
}

/**
 * This is called by findBlocksWithInteractions().  It compacts the list of buffered tiles,
 * optionally filters it, and appends the surviving tiles to the list in global memory.
 *
 * The function contains barriers, so it must be called by all work-items of a work-group
 * together.  NOTE(review): the loops stride by GROUP_SIZE and the filter uses exactly two
 * flag entries per tile (get_local_id(0)/TILE_SIZE), so this appears to assume a
 * work-group of GROUP_SIZE == 2*TILE_SIZE work-items - confirm against the host code.
 *
 * @param buffer            BUFFER_SIZE candidate (x, y) block pairs.  Overwritten: after
 *                          compaction the entries are reused as per-tile flag storage.
 * @param valid             BUFFER_SIZE entries, nonzero where buffer holds a candidate.
 *                          All set entries are cleared.
 * @param sum               BUFFER_SIZE entries of scratch for the inclusive prefix sum of valid
 * @param temp              BUFFER_SIZE entries of scratch; receives the compacted tile list
 * @param baseIndex         one local int used to broadcast the output offset to the work-group
 * @param interactionCount  global tile counter; atomically increased by the number of tiles
 *                          kept, whether or not they fit into interactingTiles
 * @param interactingTiles  output list; written only if all kept tiles fit below maxTiles
 * @param cutoffSquared     squared cutoff distance used by the filter
 * @param periodicBoxSize, invPeriodicBoxSize  periodic box data (used if USE_PERIODIC is defined)
 * @param posq              atom positions (x, y, z)
 * @param blockCenter, blockBoundingBox  per-block center and half-extent
 * @param maxTiles          capacity of interactingTiles
 */
void storeInteractionData(__local ushort2* buffer, __local int* valid, __local short* sum, __local ushort2* temp, __local int* baseIndex,
            __global unsigned int* interactionCount, __global ushort2* interactingTiles, float cutoffSquared, float4 periodicBoxSize,
            float4 invPeriodicBoxSize, __global const float4* posq, __global const float4* blockCenter, __global const float4* blockBoundingBox, unsigned int maxTiles) {
    // The buffer is full, so we need to compact it and write out results.  Start by doing a parallel prefix sum.
    // The x and y components of temp are used as two alternating buffers for the scan.

    for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
        temp[i].x = (valid[i] ? 1 : 0);
    barrier(CLK_LOCAL_MEM_FENCE);
    int whichBuffer = 0;
    for (int offset = 1; offset < BUFFER_SIZE; offset *= 2) {
        if (whichBuffer == 0)
            for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
                temp[i].y = (i < offset ? temp[i].x : temp[i].x+temp[i-offset].x);
        else
            for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
                temp[i].x = (i < offset ? temp[i].y : temp[i].y+temp[i-offset].y);
        whichBuffer = 1-whichBuffer;
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if (whichBuffer == 0)
        for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
            sum[i] = temp[i].x;
    else
        for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
            sum[i] = temp[i].y;
    barrier(CLK_LOCAL_MEM_FENCE);

    // The last element of the inclusive scan is the total number of valid entries.
    // The barrier keeps every work-item's read ahead of the writes to sum below.

    int numValid = sum[BUFFER_SIZE-1];
    barrier(CLK_LOCAL_MEM_FENCE);

    // Compact the buffer.  Each work-item touches only its own slots of buffer, so it can
    // safely turn them into flag storage right after reading them.  Every slot is set to
    // (1, 1), not just the ones that held a tile, because the filter below indexes the
    // flags by position in the compacted list rather than by original buffer position.

    for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE) {
        if (valid[i]) {
            temp[sum[i]-1] = buffer[i];
            sum[i] = valid[i]; // NOTE(review): not read again in this file before sum is recomputed.
            valid[i] = false;
        }
        buffer[i] = (ushort2) 1;
    }
    barrier(CLK_LOCAL_MEM_FENCE);

#ifndef WARPS_ARE_ATOMIC
    // Filter the list of tiles by comparing the distance from each atom to the other bounding box.
    // We only do this if we aren't already optimizing the computation using flags.

    int index = get_local_id(0)&(TILE_SIZE-1);
    int group = get_local_id(0)/TILE_SIZE;
    float4 center, boxSize, pos;

    // Each pass of the loop reduces numValid-tile by exactly one, so the loop runs once per
    // compacted tile and flagSlot never exceeds the initial numValid (<= BUFFER_SIZE).
    // Taking a fresh, still-initialized slot on every pass means a tile swapped into a
    // position never inherits flags that were cleared for the tile it replaced.

    int flagSlot = 0;
    for (int tile = 0; tile < numValid; tile++) {
        int x = temp[tile].x;
        int y = temp[tile].y;
        if (x == y)
            continue;

        // Load an atom position and the bounding box the other block.
        // Group 0 tests the atoms of block y against the box of block x; group 1 does the reverse.

        int box = (group == 0 ? x : y);
        int atom = (group == 0 ? y : x)*TILE_SIZE+index;
#ifdef MAC_AMD_WORKAROUND
        __global float* bc = (__global float*) blockCenter;
        __global float* bb = (__global float*) blockBoundingBox;
        __global float* ps = (__global float*) posq;
        center = (float4) (bc[4*box], bc[4*box+1], bc[4*box+2], 0.0f);
        boxSize = (float4) (bb[4*box], bb[4*box+1], bb[4*box+2], 0.0f);
        pos = (float4) (ps[4*atom], ps[4*atom+1], ps[4*atom+2], 0.0f);
#else
        center = blockCenter[box];
        boxSize = blockBoundingBox[box];
        pos = posq[atom];
#endif

        // Find the distance of the atom from the bounding box.

        float4 delta = pos-center;
#ifdef USE_PERIODIC
        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
        delta = max((float4) 0.0f, fabs(delta)-boxSize);

        // A group clears its flag as soon as any of its atoms is within the cutoff of the other box.

        __local ushort* flag = (__local ushort*) &buffer[flagSlot++];
        if (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < cutoffSquared)
            flag[group] = false;
        barrier(CLK_LOCAL_MEM_FENCE);
        if (flag[0] || flag[1]) {
            // This tile contains no interactions.  Replace it with the last tile in the list
            // and revisit the same position on the next pass.

            numValid--;
            if (get_local_id(0) == 0)
                temp[tile] = temp[numValid];
            tile--;
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
#endif

    // Store it to global memory.  One work-item reserves the output range and broadcasts
    // its start through *baseIndex; the range is written only if it fits entirely.

    if (get_local_id(0) == 0)
        *baseIndex = atom_add(interactionCount, numValid);
    barrier(CLK_LOCAL_MEM_FENCE);
    if (*baseIndex+numValid <= maxTiles)
        for (int i = get_local_id(0); i < numValid; i += GROUP_SIZE)
            interactingTiles[*baseIndex+i] = temp[i];
    barrier(CLK_LOCAL_MEM_FENCE);
}

/**
 * Compare the bounding boxes for each pair of blocks.  If they are sufficiently far apart,
 * mark them as non-interacting.
 *
 * Tile indices in [startTileIndex, endTileIndex) are decoded into block pairs (x, y) with
 * x >= y.  A pair whose bounding boxes are closer than the cutoff is queued in a local
 * buffer; whenever any work-item has queued BUFFER_GROUPS pairs, and once more at the end,
 * the work-group calls storeInteractionData() to append the queued pairs to interactingTiles.
 *
 * NUM_BLOCKS is not defined in this file; it is presumably supplied when the program is built.
 * The interactionFlags argument is not used by this kernel.
 *
 * @param cutoffSquared       squared cutoff distance
 * @param periodicBoxSize, invPeriodicBoxSize  periodic box data (used if USE_PERIODIC is defined)
 * @param blockCenter, blockBoundingBox        per-block center and half-extent
 * @param interactionCount    global tile counter, updated by storeInteractionData()
 * @param interactingTiles    output tile list, written by storeInteractionData()
 * @param interactionFlags    unused here
 * @param posq                atom positions, forwarded to storeInteractionData()
 * @param maxTiles            capacity of interactingTiles
 * @param startTileIndex      first tile index to examine
 * @param endTileIndex        one past the last tile index to examine
 */
__kernel void findBlocksWithInteractions(float cutoffSquared, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global const float4* restrict blockCenter,
        __global const float4* restrict blockBoundingBox, __global unsigned int* restrict interactionCount, __global ushort2* restrict interactingTiles,
        __global unsigned int* restrict interactionFlags, __global const float4* restrict posq, unsigned int maxTiles, unsigned int startTileIndex,
        unsigned int endTileIndex) {
    __local ushort2 buffer[BUFFER_SIZE];
    __local int valid[BUFFER_SIZE];
    __local short sum[BUFFER_SIZE];
    __local ushort2 temp[BUFFER_SIZE];
    __local int bufferFull;
    __local int globalIndex;
#ifdef AMD_ATOMIC_WORK_AROUND
    // Do a byte write to force all memory accesses to interactionCount to use the complete path.
    // This avoids the atomic access from causing all word accesses to other buffers from using the slow complete path.
    // The IF actually causes the write to never be executed, its presence is all that is needed.
    // AMD APP SDK 2.4 has this problem.
    if (get_global_id(0) == get_local_id(0)+1)
        ((__global char*)interactionCount)[sizeof(unsigned int)+1] = 0;
#endif

    // Start with an empty buffer.  Work-item t owns slots t, GROUP_SIZE+t, 2*GROUP_SIZE+t, ...

    int valuesInBuffer = 0;
    if (get_local_id(0) == 0)
        bufferFull = false;
    for (int i = 0; i < BUFFER_GROUPS; ++i)
        valid[i*GROUP_SIZE+get_local_id(0)] = false;
    barrier(CLK_LOCAL_MEM_FENCE);
    for (int baseIndex = startTileIndex+get_group_id(0)*get_local_size(0); baseIndex < endTileIndex; baseIndex += get_global_size(0)) {
        // Identify the pair of blocks to compare.

        int index = baseIndex+get_local_id(0);
        if (index < endTileIndex) {
            // Invert the triangular numbering index = y*NUM_BLOCKS - y*(y+1)/2 + x.

            unsigned int y = (unsigned int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*index));
            unsigned int x = (index-y*NUM_BLOCKS+y*(y+1)/2);
            if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                // NOTE(review): x is unsigned, so an x that would be negative wraps to a large
                // value and takes the "+1" branch - confirm that is the intended correction.
                y += (x < y ? -1 : 1);
                x = (index-y*NUM_BLOCKS+y*(y+1)/2);
            }

            // Find the distance between the bounding boxes of the two cells.

#ifdef MAC_AMD_WORKAROUND
            __global float* bc = (__global float*) blockCenter;
            __global float* bb = (__global float*) blockBoundingBox;
            float4 bcx = (float4) (bc[4*x], bc[4*x+1], bc[4*x+2], 0.0f);
            float4 bcy = (float4) (bc[4*y], bc[4*y+1], bc[4*y+2], 0.0f);
            float4 delta = bcx-bcy;
            float4 boxSizea = (float4) (bb[4*x], bb[4*x+1], bb[4*x+2], 0.0f);
            float4 boxSizeb = (float4) (bb[4*y], bb[4*y+1], bb[4*y+2], 0.0f);
#else
            float4 delta = blockCenter[x]-blockCenter[y];
            float4 boxSizea = blockBoundingBox[x];
            float4 boxSizeb = blockBoundingBox[y];
#endif
#ifdef USE_PERIODIC
            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            // Per-axis gap between the two boxes; zero where they overlap along that axis.

            delta.x = max(0.0f, fabs(delta.x)-boxSizea.x-boxSizeb.x);
            delta.y = max(0.0f, fabs(delta.y)-boxSizea.y-boxSizeb.y);
            delta.z = max(0.0f, fabs(delta.z)-boxSizea.z-boxSizeb.z);
            if (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < cutoffSquared) {
                // Add this tile to the buffer.

                int bufferIndex = valuesInBuffer*GROUP_SIZE+get_local_id(0);
                valid[bufferIndex] = true;
                buffer[bufferIndex] = (ushort2) (x, y);
                valuesInBuffer++;
                if (!bufferFull && valuesInBuffer == BUFFER_GROUPS)
                    bufferFull = true;
            }
        }
        barrier(CLK_LOCAL_MEM_FENCE);

        // Flush as soon as any work-item has run out of slots.
        // NOTE(review): when the buffer is not full there is no barrier between this read of
        // bufferFull and a possible write of it by a faster work-item in its next iteration -
        // confirm that work-items cannot disagree about entering storeInteractionData().

        if (bufferFull) {
            storeInteractionData(buffer, valid, sum, temp, &globalIndex, interactionCount, interactingTiles, cutoffSquared, periodicBoxSize, invPeriodicBoxSize, posq, blockCenter, blockBoundingBox, maxTiles);
            valuesInBuffer = 0;
            if (get_local_id(0) == 0)
                bufferFull = false;
            barrier(CLK_LOCAL_MEM_FENCE);
        }
    }

    // Write out whatever is still queued.

    storeInteractionData(buffer, valid, sum, temp, &globalIndex, interactionCount, interactingTiles, cutoffSquared, periodicBoxSize, invPeriodicBoxSize, posq, blockCenter, blockBoundingBox, maxTiles);
}

/**
 * Compare each atom in one block to the bounding box of another block, and set
 * flags for which ones are interacting.
 *
 * The work-items are split into groups of TILE_SIZE ("warps").  The first interactionCount[0]
 * entries of tiles are divided evenly between the warps.  For each tile (x, y) a warp writes
 * one word to interactionFlags:
 *   - 0xFFFFFFFF if x == y;
 *   - otherwise a mask with bit i set when atom i of block y is within the cutoff of the
 *     bounding box of block x, replaced by 0xFFFFFFFF if more than 12 bits are set.
 * Nothing is written if interactionCount[0] exceeds maxTiles.
 *
 * @param cutoffSquared       squared cutoff distance
 * @param periodicBoxSize, invPeriodicBoxSize  periodic box data (used if USE_PERIODIC is defined)
 * @param posq                atom positions (x, y, z)
 * @param tiles               list of (x, y) block pairs
 * @param blockCenter, blockBoundingBox        per-block center and half-extent
 * @param interactionFlags    output, one word per tile
 * @param interactionCount    first element holds the number of tiles
 * @param flags               local scratch indexed by get_local_id(0), so it needs at least
 *                            one entry per work-item of the work-group
 * @param maxTiles            largest tile count that is still processed
 */
__kernel void findInteractionsWithinBlocks(float cutoffSquared, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global const float4* restrict posq, __global const ushort2* restrict tiles, __global const float4* restrict blockCenter,
            __global const float4* restrict blockBoundingBox, __global unsigned int* restrict interactionFlags, __global const unsigned int* restrict interactionCount, __local volatile unsigned int* restrict flags, unsigned int maxTiles) {
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;
    unsigned int numTiles = interactionCount[0];

    // Split the tiles between warps.  The product is formed in 64 bits because
    // warp*numTiles can exceed the range of a 32-bit integer before the division.

    unsigned int pos = (unsigned int) ((ulong) warp*numTiles/totalWarps);
    unsigned int end = (unsigned int) ((ulong) (warp+1)*numTiles/totalWarps);
    unsigned int index = get_local_id(0) & (TILE_SIZE - 1);

    if (numTiles > maxTiles)
        return;
    unsigned int lasty = 0xFFFFFFFF;
    float4 apos;
    while (pos < end) {
        // Extract the coordinates of this tile

        ushort2 tileIndices = tiles[pos];
        unsigned int x = tileIndices.x;
        unsigned int y = tileIndices.y;
        if (x == y) {
            // A block paired with itself always has every flag set.

            if (index == 0)
                interactionFlags[pos] = 0xFFFFFFFF;
        }
        else {
            // Load the bounding box for x and the atom positions for y.
            // The atom position is kept from the previous tile when y has not changed.

            float4 center = blockCenter[x];
            float4 boxSize = blockBoundingBox[x];
            if (y != lasty)
                apos = posq[y*TILE_SIZE+index];

            // Find the distance of the atom from the bounding box.

            float4 delta = apos-center;
#ifdef USE_PERIODIC
            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            delta = max((float4) 0.0f, fabs(delta)-boxSize);
            int thread = get_local_id(0);

            // Each work-item contributes its own bit.  An unsigned literal is used so that
            // shifting by 31 does not shift into the sign bit of a signed int.

            flags[thread] = (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z > cutoffSquared ? 0 : (1u << index));

            // Sum the flags.  Every work-item owns a distinct bit, so adding the words
            // combines them without carries.
            // NOTE(review): when WARPS_ARE_ATOMIC is not defined, the barriers below sit inside
            // a branch and a loop whose trip count depends on the warp - confirm that all
            // work-items of a work-group always reach them together.

#ifdef WARPS_ARE_ATOMIC
            if (index % 4 == 0)
                flags[thread] += flags[thread+1]+flags[thread+2]+flags[thread+3];
#else
            barrier(CLK_LOCAL_MEM_FENCE);
            if (index % 4 == 0)
                flags[thread] += flags[thread+1]+flags[thread+2]+flags[thread+3];
            barrier(CLK_LOCAL_MEM_FENCE);
#endif
            if (index == 0) {
                unsigned int allFlags = flags[thread]+flags[thread+4]+flags[thread+8]+flags[thread+12]+flags[thread+16]+flags[thread+20]+flags[thread+24]+flags[thread+28];

                // Count how many flags are set, and based on that decide whether to compute all interactions
                // or only a fraction of them.

                unsigned int bits = (allFlags&0x55555555) + ((allFlags>>1)&0x55555555);
                bits = (bits&0x33333333) + ((bits>>2)&0x33333333);
                bits = (bits&0x0F0F0F0F) + ((bits>>4)&0x0F0F0F0F);
                bits = (bits&0x00FF00FF) + ((bits>>8)&0x00FF00FF);
                bits = (bits&0x0000FFFF) + ((bits>>16)&0x0000FFFF);
                interactionFlags[pos] = (bits > 12 ? 0xFFFFFFFF : allFlags);
            }
            lasty = y;
        }
        pos++;
    }
}