"platforms/vscode:/vscode.git/clone" did not exist on "4d827cd9ded61f0378a15e96ddc29eb6cc6012fe"
// findInteractingBlocks.cl -- kernels that compute per-block bounding boxes and identify interacting blocks.
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable

// Number of atoms per block; one tile pairs two blocks.
#define TILE_SIZE 32
// Stride used by the work-group loops in storeInteractionData() and by the
// buffer indexing in findBlocksWithInteractions().
#define GROUP_SIZE 64
// Number of GROUP_SIZE-sized slices in the local tile buffer.
#define BUFFER_GROUPS 4
// Total number of slots in the local tile buffer.  Parenthesized so the
// expansion stays a single operand inside larger expressions.
#define BUFFER_SIZE (BUFFER_GROUPS*GROUP_SIZE)

/**
 * Find a bounding box for the atoms in each block.
 *
 * A block is a run of TILE_SIZE consecutive entries of posq (the last block may be
 * shorter, see the clamp against numAtoms).  Each work-item processes blocks
 * get_global_id(0), get_global_id(0)+get_global_size(0), ... until the block's first
 * atom index reaches numAtoms.
 *
 * @param numAtoms            number of entries of posq to consider
 * @param periodicBoxSize     periodic box lengths (x, y, z); only read when USE_PERIODIC is defined
 * @param invPeriodicBoxSize  multiplied with coordinates/offsets to count box lengths; only read
 *                            when USE_PERIODIC is defined (presumably 1/periodicBoxSize -- confirm
 *                            against the host code)
 * @param posq                per-atom float4 values; x, y, z are treated as the position
 * @param blockCenter         output: per block, 0.5*(max+min) of the atom values
 * @param blockBoundingBox    output: per block, 0.5*(max-min) of the atom values (half extents)
 * @param interactionCount    output: element 0 is reset to 0 by work-item 0
 */
__kernel void findBlockBounds(int numAtoms, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global float4* posq, __global float4* blockCenter, __global float4* blockBoundingBox, __global unsigned int* interactionCount) {
    int index = get_global_id(0);
    int base = index*TILE_SIZE;
    while (base < numAtoms) {
        float4 pos = posq[base];
#ifdef USE_PERIODIC
        // Wrap the first atom of the block into the primary periodic box.
        pos.x -= floor(pos.x*invPeriodicBoxSize.x)*periodicBoxSize.x;
        pos.y -= floor(pos.y*invPeriodicBoxSize.y)*periodicBoxSize.y;
        pos.z -= floor(pos.z*invPeriodicBoxSize.z)*periodicBoxSize.z;
        float4 firstPoint = pos;
#endif
        float4 minPos = pos;
        float4 maxPos = pos;
        int last = min(base+TILE_SIZE, numAtoms);
        for (int i = base+1; i < last; i++) {
            pos = posq[i];
#ifdef USE_PERIODIC
            // Shift every other atom to the periodic image nearest to the first atom,
            // so the box is measured around a single cluster of images.
            pos.x -= floor((pos.x-firstPoint.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            pos.y -= floor((pos.y-firstPoint.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            pos.z -= floor((pos.z-firstPoint.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            // Component-wise min/max over all four components (w is carried along).
            minPos = min(minPos, pos);
            maxPos = max(maxPos, pos);
        }
        blockBoundingBox[index] = 0.5f*(maxPos-minPos);
        blockCenter[index] = 0.5f*(maxPos+minPos);
        index += get_global_size(0);
        base = index*TILE_SIZE;
    }

    // Reset the tile counter that storeInteractionData() later increments atomically.
    if (get_global_id(0) == 0)
        interactionCount[0] = 0;
}

/**
 * This is called by findBlocksWithInteractions().  It compacts the list of blocks and writes them
 * to global memory.
 *
 * The function contains barriers, so it has to be reached by every work-item of the work-group.
 * The loops stride by GROUP_SIZE starting at get_local_id(0).
 *
 * @param buffer              local array of BUFFER_SIZE candidate tiles (x, y); entries that were
 *                            valid are overwritten with 1 and then reused as per-tile flag words
 * @param valid               local array of BUFFER_SIZE flags marking which buffer entries hold a
 *                            tile; consumed entries are reset to false
 * @param sum                 local scratch array receiving the inclusive prefix sum of valid
 * @param temp                local scratch array; used as the two ping-pong buffers of the prefix
 *                            sum (.x and .y) and afterwards as the compacted tile list
 * @param baseIndex           local scalar receiving the offset returned by the atomic add
 * @param interactionCount    global counter; incremented by the number of tiles kept
 * @param interactingTiles    global output list of tiles
 * @param cutoffSquared       squared cutoff distance used by the per-atom filter
 * @param periodicBoxSize     periodic box lengths; only read when USE_PERIODIC is defined
 * @param invPeriodicBoxSize  only read when USE_PERIODIC is defined
 * @param posq                per-atom float4 values; x, y, z are treated as the position
 * @param blockCenter         per-block centers
 * @param blockBoundingBox    per-block half extents
 * @param maxTiles            tiles are only written if the reserved range ends at or below this
 */
void storeInteractionData(__local ushort2* buffer, __local int* valid, __local short* sum, __local ushort2* temp, __local int* baseIndex,
            __global unsigned int* interactionCount, __global ushort2* interactingTiles, float cutoffSquared, float4 periodicBoxSize,
            float4 invPeriodicBoxSize, __global float4* posq, __global float4* blockCenter, __global float4* blockBoundingBox, unsigned int maxTiles) {
    // The buffer is full, so we need to compact it and write out results.  Start by doing a parallel prefix sum.

    for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
        temp[i].x = (valid[i] ? 1 : 0);
    barrier(CLK_LOCAL_MEM_FENCE);

    // Each pass adds the element `offset` positions to the left, alternating between the
    // .x and .y components of temp as source and destination.
    int whichBuffer = 0;
    for (int offset = 1; offset < BUFFER_SIZE; offset *= 2) {
        if (whichBuffer == 0)
            for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
                temp[i].y = (i < offset ? temp[i].x : temp[i].x+temp[i-offset].x);
        else
            for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
                temp[i].x = (i < offset ? temp[i].y : temp[i].y+temp[i-offset].y);
        whichBuffer = 1-whichBuffer;
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Copy the final sums out of whichever component was written last.
    if (whichBuffer == 0)
        for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
            sum[i] = temp[i].x;
    else
        for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
            sum[i] = temp[i].y;
    barrier(CLK_LOCAL_MEM_FENCE);
    int numValid = sum[BUFFER_SIZE-1];
    barrier(CLK_LOCAL_MEM_FENCE);

    // Compact the buffer.

    for (int i = get_local_id(0); i < BUFFER_SIZE; i += GROUP_SIZE)
        if (valid[i]) {
            temp[sum[i]-1] = buffer[i];
            // NOTE(review): this value of sum[i] is not read again inside this function.
            sum[i] = valid[i];
            valid[i] = false;
            // Both components become 1; they are read as two ushort flags by the filter below.
            buffer[i] = (ushort2) 1;
        }
    barrier(CLK_LOCAL_MEM_FENCE);

#ifndef WARPS_ARE_ATOMIC
    // Filter the list of tiles by comparing the distance from each atom to the other bounding box.
    // We only do this if we aren't already optimizing the computation using flags.
    //
    // NOTE(review): the flag words are read at buffer[tile] (compacted index) while the reset
    // to 1 above is done at buffer[i] (uncompacted index), so slots that were not valid in
    // this call are not reset here -- confirm that this is intended.

    int index = get_local_id(0)&(TILE_SIZE-1);
    int group = get_local_id(0)/TILE_SIZE;
    float4 center, boxSize, pos;
    for (int tile = 0; tile < numValid; tile++) {
        int x = temp[tile].x;
        int y = temp[tile].y;
        if (x == y)
            continue;

        // Load an atom position and the bounding box the other block.
        // Group 0 tests atoms of block y against the box of block x; other groups do the reverse.

        center = blockCenter[(group == 0 ? x : y)];
        boxSize = blockBoundingBox[(group == 0 ? x : y)];
        pos = posq[(group == 0 ? y : x)*TILE_SIZE+index];

        // Find the distance of the atom from the bounding box.

        float4 delta = pos-center;
#ifdef USE_PERIODIC
        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
        delta = max((float4) 0.0f, fabs(delta)-boxSize);
        __local ushort* flag = (__local ushort*) &buffer[tile];
        // Any atom within the cutoff of the other box clears its group's flag.
        if (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < cutoffSquared)
            flag[group] = false;
        barrier(CLK_LOCAL_MEM_FENCE);
        if (flag[0] || flag[1]) {
            // This tile contains no interactions.
            // Replace it with the last tile and revisit the same slot on the next iteration.

            numValid--;
            if (get_local_id(0) == 0)
                temp[tile] = temp[numValid];
            tile--;
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
#endif

    // Store it to global memory.

    // One work-item reserves a contiguous range in the global list.
    if (get_local_id(0) == 0)
        *baseIndex = atom_add(interactionCount, numValid);
    barrier(CLK_LOCAL_MEM_FENCE);
    // The counter is advanced even when the range does not fit; only the writes are skipped.
    if (*baseIndex+numValid <= maxTiles)
        for (int i = get_local_id(0); i < numValid; i += GROUP_SIZE)
            interactingTiles[*baseIndex+i] = temp[i];
    barrier(CLK_LOCAL_MEM_FENCE);
}

/**
 * Compare the bounding boxes for each pair of blocks.  If they are sufficiently far apart,
 * mark them as non-interacting.
 *
 * Candidate tiles are collected in a local buffer of BUFFER_SIZE slots and flushed through
 * storeInteractionData() whenever some work-item has filled BUFFER_GROUPS slots, and once more
 * after the loop.  NUM_BLOCKS must be defined when this file is compiled.
 *
 * @param cutoffSquared       squared cutoff distance
 * @param periodicBoxSize     periodic box lengths; only read here when USE_PERIODIC is defined
 * @param invPeriodicBoxSize  only read here when USE_PERIODIC is defined
 * @param blockCenter         per-block centers
 * @param blockBoundingBox    per-block half extents
 * @param interactionCount    global tile counter, forwarded to storeInteractionData()
 * @param interactingTiles    global output list, forwarded to storeInteractionData()
 * @param interactionFlags    not referenced in this kernel
 * @param posq                per-atom float4 values, forwarded to storeInteractionData()
 * @param maxTiles            capacity limit, forwarded to storeInteractionData()
 * @param startTileIndex      first linear tile index to examine
 * @param endTileIndex        one past the last linear tile index to examine
 */
__kernel void findBlocksWithInteractions(float cutoffSquared, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global float4* blockCenter,
        __global float4* blockBoundingBox, __global unsigned int* interactionCount, __global ushort2* interactingTiles,
        __global unsigned int* interactionFlags, __global float4* posq, unsigned int maxTiles, unsigned int startTileIndex,
        unsigned int endTileIndex) {
    __local ushort2 buffer[BUFFER_SIZE];
    __local int valid[BUFFER_SIZE];
    __local short sum[BUFFER_SIZE];
    __local ushort2 temp[BUFFER_SIZE];
    __local int bufferFull;
    __local int globalIndex;
#ifdef AMD_ATOMIC_WORK_AROUND
    // Do a byte write to force all memory accesses to interactionCount to use the complete path.
    // This avoids the atomic access from causing all word accesses to other buffers from using the slow complete path.
    // The IF actually causes the write to never be executed, its presence is all that is needed.
    // AMD APP SDK 2.4 has this problem.
    if (get_global_id(0) == get_local_id(0)+1)
        ((__global char*)interactionCount)[sizeof(unsigned int)+1] = 0;
#endif

    // Start with an empty buffer: each work-item clears its own BUFFER_GROUPS slots.
    int valuesInBuffer = 0;
    if (get_local_id(0) == 0)
        bufferFull = false;
    for (int i = 0; i < BUFFER_GROUPS; ++i)
        valid[i*GROUP_SIZE+get_local_id(0)] = false;
    barrier(CLK_LOCAL_MEM_FENCE);

    // baseIndex depends only on the group id, so the loop condition and the barriers inside
    // the loop are evaluated identically by every work-item of the group.
    for (int baseIndex = startTileIndex+get_group_id(0)*get_local_size(0); baseIndex < endTileIndex; baseIndex += get_global_size(0)) {
        // Identify the pair of blocks to compare.

        int index = baseIndex+get_local_id(0);
        if (index < endTileIndex) {
            // Decode the linear tile index into a block pair (x, y) with y <= x < NUM_BLOCKS.
            unsigned int y = (unsigned int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*index));
            unsigned int x = (index-y*NUM_BLOCKS+y*(y+1)/2);
            if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                y += (x < y ? -1 : 1);
                x = (index-y*NUM_BLOCKS+y*(y+1)/2);
            }

            // Find the distance between the bounding boxes of the two cells.

            float4 delta = blockCenter[x]-blockCenter[y];
#ifdef USE_PERIODIC
            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            // Subtract both half extents from the center distance; negative gaps clamp to zero.
            float4 boxSizea = blockBoundingBox[x];
            float4 boxSizeb = blockBoundingBox[y];
            delta.x = max(0.0f, fabs(delta.x)-boxSizea.x-boxSizeb.x);
            delta.y = max(0.0f, fabs(delta.y)-boxSizea.y-boxSizeb.y);
            delta.z = max(0.0f, fabs(delta.z)-boxSizea.z-boxSizeb.z);
            if (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < cutoffSquared) {
                // Add this tile to the buffer.

                int bufferIndex = valuesInBuffer*GROUP_SIZE+get_local_id(0);
                valid[bufferIndex] = true;
                buffer[bufferIndex] = (ushort2) (x, y);
                valuesInBuffer++;
                // A single work-item running out of slots is enough to force a flush.
                if (!bufferFull && valuesInBuffer == BUFFER_GROUPS)
                    bufferFull = true;
            }
        }
        barrier(CLK_LOCAL_MEM_FENCE);
        if (bufferFull) {
            storeInteractionData(buffer, valid, sum, temp, &globalIndex, interactionCount, interactingTiles, cutoffSquared, periodicBoxSize, invPeriodicBoxSize, posq, blockCenter, blockBoundingBox, maxTiles);
            valuesInBuffer = 0;
            if (get_local_id(0) == 0)
                bufferFull = false;
            barrier(CLK_LOCAL_MEM_FENCE);
        }
    }

    // Flush whatever is still buffered.
    storeInteractionData(buffer, valid, sum, temp, &globalIndex, interactionCount, interactingTiles, cutoffSquared, periodicBoxSize, invPeriodicBoxSize, posq, blockCenter, blockBoundingBox, maxTiles);
}

/**
 * Compare each atom in one block to the bounding box of another block, and set
 * flags for which ones are interacting.
 *
 * Work-items are grouped into runs of TILE_SIZE consecutive global ids ("warps" below); each
 * warp handles a contiguous slice of the tile list.  Nothing is written if the tile count in
 * interactionCount[0] exceeds maxTiles.
 *
 * @param cutoffSquared       squared cutoff distance
 * @param periodicBoxSize     periodic box lengths; only read when USE_PERIODIC is defined
 * @param invPeriodicBoxSize  only read when USE_PERIODIC is defined
 * @param posq                per-atom float4 values; x, y, z are treated as the position
 * @param tiles               list of tiles (x, y) to process
 * @param blockCenter         per-block centers
 * @param blockBoundingBox    per-block half extents
 * @param interactionFlags    output: one 32-bit mask per tile; bit i is set when atom i of block y
 *                            lies within the cutoff of block x's bounding box.  0xFFFFFFFF is
 *                            stored for tiles with x == y and when more than 12 bits are set.
 * @param interactionCount    element 0 holds the number of tiles in the list
 * @param flags               local scratch, indexed by get_local_id(0)
 * @param maxTiles            capacity of the tile list
 */
__kernel void findInteractionsWithinBlocks(float cutoffSquared, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global float4* posq, __global ushort2* tiles, __global float4* blockCenter,
            __global float4* blockBoundingBox, __global unsigned int* interactionFlags, __global unsigned int* interactionCount, __local unsigned int* flags, unsigned int maxTiles) {
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;
    unsigned int numTiles = interactionCount[0];
    // Form the products in 64 bits: warp*numTiles can exceed the 32-bit range for large
    // tile lists, which would silently produce a wrong slice.
    unsigned int pos = (unsigned int) (((ulong) warp*numTiles)/totalWarps);
    unsigned int end = (unsigned int) (((ulong) (warp+1)*numTiles)/totalWarps);
    unsigned int index = get_local_id(0) & (TILE_SIZE - 1);

    // The list overflowed, so its contents are incomplete; leave the flags untouched.
    if (numTiles > maxTiles)
        return;
    unsigned int lasty = 0xFFFFFFFF;
    float4 apos;
    while (pos < end) {
        // Extract the coordinates of this tile
        ushort2 tileIndices = tiles[pos];
        unsigned int x = tileIndices.x;
        unsigned int y = tileIndices.y;
        if (x == y) {
            // A block paired with itself always gets every flag set.
            if (index == 0)
                interactionFlags[pos] = 0xFFFFFFFF;
        }
        else {
            // Load the bounding box for x and the atom positions for y.

            float4 center = blockCenter[x];
            float4 boxSize = blockBoundingBox[x];
            // Reuse the previously loaded atom when consecutive tiles share the same y.
            if (y != lasty)
                apos = posq[y*TILE_SIZE+index];

            // Find the distance of the atom from the bounding box.

            float4 delta = apos-center;
#ifdef USE_PERIODIC
            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            delta = max((float4) 0.0f, fabs(delta)-boxSize);
            int thread = get_local_id(0);
            // Unsigned shift so that bit 31 is produced without touching a sign bit.
            flags[thread] = (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z > cutoffSquared ? 0u : 1u << index);

            // Sum the flags.
            // Every work-item contributes a distinct bit, so adding is equivalent to OR-ing.

#ifdef WARPS_ARE_ATOMIC
            if (index % 4 == 0)
                flags[thread] += flags[thread+1]+flags[thread+2]+flags[thread+3];
#else
            // NOTE(review): these barriers sit inside a loop and branch whose conditions depend
            // on the per-warp tile slice; if a work-group holds more than one warp they may not
            // be reached uniformly -- confirm the launch configuration used for this path.
            barrier(CLK_LOCAL_MEM_FENCE);
            if (index % 4 == 0)
                flags[thread] += flags[thread+1]+flags[thread+2]+flags[thread+3];
            barrier(CLK_LOCAL_MEM_FENCE);
#endif
            if (index == 0) {
                unsigned int allFlags = flags[thread]+flags[thread+4]+flags[thread+8]+flags[thread+12]+flags[thread+16]+flags[thread+20]+flags[thread+24]+flags[thread+28];

                // Count how many flags are set, and based on that decide whether to compute all interactions
                // or only a fraction of them.

                unsigned int bits = (allFlags&0x55555555) + ((allFlags>>1)&0x55555555);
                bits = (bits&0x33333333) + ((bits>>2)&0x33333333);
                bits = (bits&0x0F0F0F0F) + ((bits>>4)&0x0F0F0F0F);
                bits = (bits&0x00FF00FF) + ((bits>>8)&0x00FF00FF);
                bits = (bits&0x0000FFFF) + ((bits>>16)&0x0000FFFF);
                interactionFlags[pos] = (bits > 12 ? 0xFFFFFFFF : allFlags);
            }
            lasty = y;
        }
        pos++;
    }
}