findInteractingBlocks.cl 10.4 KB
Newer Older
1
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable

// Number of consecutive atoms grouped into one block (and the width/height of one tile).
#define TILE_SIZE 32
// Stride used by the work-items of a group when they sweep the local tile buffer.
#define GROUP_SIZE 64
// Number of tiles each work-item can queue before the local buffer has to be flushed.
#define BUFFER_GROUPS 4
// Total number of slots in the local tile buffer.  Parenthesized so that the macro
// behaves as a single value inside larger expressions (e.g. n/BUFFER_SIZE, n%BUFFER_SIZE).
#define BUFFER_SIZE (BUFFER_GROUPS*GROUP_SIZE)
6
7
8
9

/**
 * Compute an axis-aligned bounding box for every block of TILE_SIZE consecutive atoms.
 *
 * For block b the kernel stores the box center in blockCenter[b] and the box half-widths
 * in blockBoundingBox[b].  When USE_PERIODIC is defined, the first atom of the block is
 * wrapped into the periodic box and every other atom is shifted by whole box lengths so
 * that it is the image closest to that first atom.  Work-item 0 also resets
 * interactionCount[0] to zero.
 */
__kernel void findBlockBounds(int numAtoms, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global float4* posq, __global float4* blockCenter, __global float4* blockBoundingBox, __global unsigned int* interactionCount) {
    // Each work-item handles blocks get_global_id(0), get_global_id(0)+get_global_size(0), ...

    for (int block = get_global_id(0); block*TILE_SIZE < numAtoms; block += get_global_size(0)) {
        const int firstAtom = block*TILE_SIZE;
        const int endAtom = min(firstAtom+TILE_SIZE, numAtoms);

        // The first atom seeds both corners of the box.

        float4 lower = posq[firstAtom];
#ifdef USE_PERIODIC
        lower.x -= floor(lower.x*invPeriodicBoxSize.x)*periodicBoxSize.x;
        lower.y -= floor(lower.y*invPeriodicBoxSize.y)*periodicBoxSize.y;
        lower.z -= floor(lower.z*invPeriodicBoxSize.z)*periodicBoxSize.z;
        const float4 reference = lower;
#endif
        float4 upper = lower;

        // Grow the box to enclose the remaining atoms of the block.

        for (int atom = firstAtom+1; atom < endAtom; atom++) {
            float4 p = posq[atom];
#ifdef USE_PERIODIC
            p.x -= floor((p.x-reference.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            p.y -= floor((p.y-reference.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            p.z -= floor((p.z-reference.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            lower = min(lower, p);
            upper = max(upper, p);
        }
        blockBoundingBox[block] = 0.5f*(upper-lower);
        blockCenter[block] = 0.5f*(upper+lower);
    }

    // Clear the tile counter that the later kernels accumulate into.

    if (get_global_id(0) == 0)
        interactionCount[0] = 0;
}

/**
 * Helper for findBlocksWithInteractions().  Packs the buffer entries that are flagged as
 * valid into a contiguous list, reserves room for them in the global tile list by atomically
 * advancing interactionCount, and writes them out encoded as (x<<17)+(y<<2).  Every valid
 * flag is cleared on the way, so the buffer is empty when the function returns.
 *
 * All work-items of the group must call this together, since it contains barriers.
 */
void storeInteractionData(__local short2* buffer, __local bool* valid, __local int* sum, __local int* sum2, __local short2* temp, __local int* baseIndex,
            __global unsigned int* interactionCount, __global unsigned int* interactingTiles) {
    const int lane = get_local_id(0);

    // Seed the scan: one count for each occupied slot.

    for (int slot = lane; slot < BUFFER_SIZE; slot += GROUP_SIZE) {
        if (valid[slot])
            sum[slot] = 1;
        else
            sum[slot] = 0;
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    // Inclusive parallel prefix sum, ping-ponging between the two scratch arrays.

    __local int* src = sum;
    __local int* dst = sum2;
    for (int stride = 1; stride < BUFFER_SIZE; stride *= 2) {
        for (int slot = lane; slot < BUFFER_SIZE; slot += GROUP_SIZE) {
            int total = src[slot];
            if (slot >= stride)
                total += src[slot-stride];
            dst[slot] = total;
        }
        __local int* swapped = src;
        src = dst;
        dst = swapped;
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if (src == sum2) {
        // The final pass landed in sum2; the code below reads the totals from sum.

        for (int slot = lane; slot < BUFFER_SIZE; slot += GROUP_SIZE)
            sum[slot] = sum2[slot];
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Move each valid entry to its compacted position and mark its slot as free.

    for (int slot = lane; slot < BUFFER_SIZE; slot += GROUP_SIZE) {
        if (!valid[slot])
            continue;
        temp[sum[slot]-1] = buffer[slot];
        valid[slot] = false;
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    // One work-item reserves a contiguous range of the global list for the whole group.

    const int numValid = sum[BUFFER_SIZE-1];
    if (lane == 0)
        *baseIndex = atom_add(interactionCount, numValid);
    barrier(CLK_LOCAL_MEM_FENCE);

    // Write the compacted tiles into the reserved range.

    const int firstOutput = *baseIndex;
    for (int k = lane; k < numValid; k += GROUP_SIZE) {
        const short2 tile = temp[k];
        interactingTiles[firstOutput+k] = (tile.x<<17)+(tile.y<<2);
    }
    barrier(CLK_LOCAL_MEM_FENCE);
}

/**
 * Compare the bounding boxes for each pair of blocks.  If they are sufficiently far apart,
 * mark them as non-interacting.
 *
 * Every pair (x, y) with y <= x < NUM_BLOCKS is examined once.  A pair is kept when the gap
 * between the two bounding boxes is shorter than the cutoff; kept pairs are queued in a local
 * buffer and appended to interactingTiles by storeInteractionData(), which also advances
 * interactionCount[0].
 *
 * The buffer indexing assumes the work-group size equals GROUP_SIZE.
 */
__kernel void findBlocksWithInteractions(float cutoffSquared, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global float4* blockCenter,
        __global float4* blockBoundingBox, __global unsigned int* interactionCount, __global unsigned int* interactingTiles) {
    __local short2 buffer[BUFFER_SIZE];
    __local bool valid[BUFFER_SIZE];
    __local int sum[BUFFER_SIZE];
    __local int sum2[BUFFER_SIZE];
    __local short2 temp[BUFFER_SIZE];
    __local int bufferFull;
    __local int globalIndex;
    int valuesInBuffer = 0;
    if (get_local_id(0) == 0)
        bufferFull = false;
    for (int i = 0; i < BUFFER_GROUPS; ++i)
        valid[i*GROUP_SIZE+get_local_id(0)] = false;
    barrier(CLK_LOCAL_MEM_FENCE);
    const int numTiles = (NUM_BLOCKS*(NUM_BLOCKS+1))/2;
    for (int baseIndex = get_group_id(0)*get_local_size(0); baseIndex < numTiles; baseIndex += get_global_size(0)) {
        // Identify the pair of blocks to compare.

        int index = baseIndex+get_local_id(0);
        if (index < numTiles) {
            // Invert the triangular numbering of the tiles.  The square root is evaluated in
            // single precision, so the row estimate can be off by one in either direction.
            // Signed arithmetic is used so that an estimate that is too large shows up as
            // x < y instead of wrapping around.

            int y = (int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*index));
            int x = (index-y*NUM_BLOCKS+y*(y+1)/2);
            if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                y += (x < y ? -1 : 1);
                x = (index-y*NUM_BLOCKS+y*(y+1)/2);
            }

            // Find the distance between the bounding boxes of the two cells.

            float4 delta = blockCenter[x]-blockCenter[y];
#ifdef USE_PERIODIC
            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            float4 boxSizea = blockBoundingBox[x];
            float4 boxSizeb = blockBoundingBox[y];
            delta.x = max(0.0f, fabs(delta.x)-boxSizea.x-boxSizeb.x);
            delta.y = max(0.0f, fabs(delta.y)-boxSizea.y-boxSizeb.y);
            delta.z = max(0.0f, fabs(delta.z)-boxSizea.z-boxSizeb.z);
            if (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z < cutoffSquared) {
                // Add this tile to the buffer.  Each work-item owns the slots
                // get_local_id(0), GROUP_SIZE+get_local_id(0), ...

                int bufferIndex = valuesInBuffer*GROUP_SIZE+get_local_id(0);
                valid[bufferIndex] = true;
                buffer[bufferIndex] = (short2) (x, y);
                valuesInBuffer++;

                // As soon as any work-item has used all of its slots, the whole group flushes.

                if (!bufferFull && valuesInBuffer == BUFFER_GROUPS)
                    bufferFull = true;
            }
        }
        barrier(CLK_LOCAL_MEM_FENCE);
        if (bufferFull) {
            // The flush empties every slot, so all work-items restart their slot counter.

            storeInteractionData(buffer, valid, sum, sum2, temp, &globalIndex, interactionCount, interactingTiles);
            valuesInBuffer = 0;
            if (get_local_id(0) == 0)
                bufferFull = false;
            barrier(CLK_LOCAL_MEM_FENCE);
        }
    }

    // Write out whatever is still queued.

    storeInteractionData(buffer, valid, sum, sum2, temp, &globalIndex, interactionCount, interactingTiles);
}

/**
 * Compare each atom in one block to the bounding box of another block, and set
 * flags for which ones are interacting.
 *
 * The tiles listed in `tiles` (interactionCount[0] of them) are divided evenly among groups
 * of TILE_SIZE consecutive work-items ("warps").  Each tile word holds x in bits 17 and up,
 * y in bits 2-16 and a has-exclusions flag in bit 0.  For every tile the kernel writes one
 * word to interactionFlags: bit i is set when atom i of block y lies within the cutoff of the
 * bounding box of block x.  Diagonal tiles, tiles with exclusions and tiles with more than 12
 * atoms in range are written as 0xFFFFFFFF.
 *
 * `flags` is local scratch space indexed by get_local_id(0).
 */
__kernel void findInteractionsWithinBlocks(float cutoffSquared, float4 periodicBoxSize, float4 invPeriodicBoxSize, __global float4* posq, __global unsigned int* tiles, __global float4* blockCenter,
            __global float4* blockBoundingBox, __global unsigned int* interactionFlags, __global unsigned int* interactionCount, __local unsigned int* flags) {
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;
    unsigned int numTiles = interactionCount[0];

    // Range of tiles handled by this warp.  The products are formed in 64 bits because
    // warp*numTiles can exceed 32 bits when the tile list is long; the quotients never
    // exceed numTiles, so they fit back into 32 bits.

    unsigned int pos = (unsigned int) (((ulong) warp)*numTiles/totalWarps);
    unsigned int end = (unsigned int) (((ulong) (warp+1))*numTiles/totalWarps);
    unsigned int index = get_local_id(0) & (TILE_SIZE - 1);

    // Block whose atom is currently held in apos, so it is not reloaded for consecutive
    // tiles that share the same y.

    unsigned int lasty = 0xFFFFFFFF;
    float4 apos;
    while (pos < end) {
        // Extract the coordinates of this tile

        unsigned int x = tiles[pos];
        unsigned int y = ((x >> 2) & 0x7fff);
        bool hasExclusions = (x & 0x1);
        x = (x >> 17);
        if (x == y || hasExclusions) {
            // Assume this tile will be dense.

            if (index == 0)
                interactionFlags[pos] = 0xFFFFFFFF;
        }
        else {
            // Load the bounding box for x and the atom positions for y.

            float4 center = blockCenter[x];
            float4 boxSize = blockBoundingBox[x];
            if (y != lasty)
                apos = posq[y*TILE_SIZE+index];

            // Find the distance of the atom from the bounding box.

            float4 delta = apos-center;
#ifdef USE_PERIODIC
            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
            delta = max((float4) 0.0f, fabs(delta)-boxSize);
            int thread = get_local_id(0);
            flags[thread] = (delta.x*delta.x+delta.y*delta.y+delta.z*delta.z > cutoffSquared ? 0 : 1u << index);

            // Sum the flags.  Every work-item contributes a distinct bit, so the sum of the
            // 32 values is the same as their bitwise OR.
            // NOTE(review): the barriers in the #else branch sit inside per-warp control flow
            // (the tile loop and this else branch).  That presumably requires the work-group
            // to consist of a single warp when WARPS_ARE_ATOMIC is not defined - confirm
            // against the host code that launches this kernel.

#ifdef WARPS_ARE_ATOMIC
            if (index % 2 == 0)
                flags[thread] += flags[thread+1];
            if (index % 4 == 0)
                flags[thread] += flags[thread+2];
            if (index % 8 == 0)
                flags[thread] += flags[thread+4];
            if (index % 16 == 0)
                flags[thread] += flags[thread+8];
#else
            barrier(CLK_LOCAL_MEM_FENCE);
            if (index % 2 == 0)
                flags[thread] += flags[thread+1];
            barrier(CLK_LOCAL_MEM_FENCE);
            if (index % 4 == 0)
                flags[thread] += flags[thread+2];
            barrier(CLK_LOCAL_MEM_FENCE);
            if (index % 8 == 0)
                flags[thread] += flags[thread+4];
            barrier(CLK_LOCAL_MEM_FENCE);
            if (index % 16 == 0)
                flags[thread] += flags[thread+8];
            barrier(CLK_LOCAL_MEM_FENCE);
#endif
            if (index == 0) {
                // Combine the two 16-wide partial sums into the full 32 bit mask.

                unsigned int allFlags = flags[thread] + flags[thread+16];

                // Count how many flags are set, and based on that decide whether to compute all interactions
                // or only a fraction of them.

                unsigned int bits = (allFlags&0x55555555) + ((allFlags>>1)&0x55555555);
                bits = (bits&0x33333333) + ((bits>>2)&0x33333333);
                bits = (bits&0x0F0F0F0F) + ((bits>>4)&0x0F0F0F0F);
                bits = (bits&0x00FF00FF) + ((bits>>8)&0x00FF00FF);
                bits = (bits&0x0000FFFF) + ((bits>>16)&0x0000FFFF);
                interactionFlags[pos] = (bits > 12 ? 0xFFFFFFFF : allFlags);
            }
            lasty = y;
        }
        pos++;
    }
}