// customGBValueN2_nvidia.cl
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
2
3
#define TILE_SIZE 32

4
5
6
7
/**
 * Compute a value based on pair interactions.
 */
__kernel void computeN2Value(__global float4* posq, __local float4* local_posq, __global unsigned int* exclusions,
8
        __global unsigned int* exclusionIndices, __global unsigned int* exclusionRowIndices, __global float* global_value, __local float* local_value,
9
        __local float* tempBuffer,
10
#ifdef USE_CUTOFF
11
        __global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags
12
13
14
15
16
17
#else
        unsigned int numTiles
#endif
        PARAMETER_ARGUMENTS) {
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;
18
19
20
21
22
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
    unsigned int pos = warp*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
    unsigned int end = (warp+1)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
#else
23
24
    unsigned int pos = warp*numTiles/totalWarps;
    unsigned int end = (warp+1)*numTiles/totalWarps;
25
#endif
26
27
    float energy = 0.0f;
    unsigned int lasty = 0xFFFFFFFF;
28
29
    __local unsigned int exclusionRange[2*WARPS_PER_GROUP];
    __local int exclusionIndex[WARPS_PER_GROUP];
30
31
32
    __local int2* reservedBlocks = (__local int2*) exclusionRange;
    
    do {
33
        // Extract the coordinates of this tile
34
35
36
        const unsigned int tgx = get_local_id(0) & (TILE_SIZE-1);
        const unsigned int tbx = get_local_id(0) - tgx;
        const unsigned int localGroupIndex = get_local_id(0)/TILE_SIZE;
37
        unsigned int x, y;
38
39
        float value = 0.0f;
        if (pos < end) {
40
#ifdef USE_CUTOFF
41
42
43
44
45
46
            if (numTiles <= maxTiles) {
                ushort2 tileIndices = tiles[pos];
                x = tileIndices.x;
                y = tileIndices.y;
            }
            else
47
#endif
48
49
            {
                y = (unsigned int) floor(NUM_BLOCKS+0.5f-SQRT((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
50
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
51
52
53
54
                if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                    y += (x < y ? -1 : 1);
                    x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                }
55
            }
56
57
58
            unsigned int atom1 = x*TILE_SIZE + tgx;
            float4 posq1 = posq[atom1];
            LOAD_ATOM1_PARAMETERS
59

60
            // Locate the exclusion data for this tile.
61
62

#ifdef USE_EXCLUSIONS
63
64
65
66
67
68
69
70
            if (tgx < 2)
                exclusionRange[2*localGroupIndex+tgx] = exclusionRowIndices[x+tgx];
            if (tgx == 0)
                exclusionIndex[localGroupIndex] = -1;
            for (int i = exclusionRange[2*localGroupIndex]+tgx; i < exclusionRange[2*localGroupIndex+1]; i += TILE_SIZE)
                if (exclusionIndices[i] == y)
                    exclusionIndex[localGroupIndex] = i*TILE_SIZE;
            bool hasExclusions = (exclusionIndex[localGroupIndex] > -1);
71
#else
72
            bool hasExclusions = false;
73
#endif
74
75
76
77
            if (pos >= end)
                ; // This warp is done.
            else if (x == y) {
                // This tile is on the diagonal.
78

79
80
81
                const unsigned int localAtomIndex = get_local_id(0);
                local_posq[localAtomIndex] = posq1;
                LOAD_LOCAL_PARAMETERS_FROM_1
82
#ifdef USE_EXCLUSIONS
83
                unsigned int excl = exclusions[exclusionIndex[localGroupIndex]+tgx];
84
#endif
85
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
86
#ifdef USE_EXCLUSIONS
87
                    bool isExcluded = !(excl & 0x1);
88
#endif
89
90
91
                    int atom2 = tbx+j;
                    float4 posq2 = local_posq[atom2];
                    float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
92
#ifdef USE_PERIODIC
93
94
95
                    delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                    delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                    delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
96
#endif
97
                    float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
98
#ifdef USE_CUTOFF
99
                    if (r2 < CUTOFF_SQUARED) {
100
#endif
101
102
103
104
105
                    float r = SQRT(r2);
                    LOAD_ATOM2_PARAMETERS
                    atom2 = y*TILE_SIZE+j;
                    float tempValue1 = 0.0f;
                    float tempValue2 = 0.0f;
106
#ifdef USE_EXCLUSIONS
107
                    if (!isExcluded && atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2) {
108
#else
109
                    if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2) {
110
#endif
111
112
113
                        COMPUTE_VALUE
                    }
                    value += tempValue1;
114
#ifdef USE_CUTOFF
115
                    }
116
#endif
117
#ifdef USE_EXCLUSIONS
118
                    excl >>= 1;
119
#endif
120
                }
121
            }
122
123
            else {
                // This is an off-diagonal tile.
124

125
126
127
128
129
                if (lasty != y) {
                    unsigned int j = y*TILE_SIZE + tgx;
                    local_posq[get_local_id(0)] = posq[j];
                    const unsigned int localAtomIndex = get_local_id(0);
                    LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
130
                }
131
132
133
134
135
136
137
138
139
                local_value[get_local_id(0)] = 0.0f;
#ifdef USE_CUTOFF
                unsigned int flags = (numTiles <= maxTiles ? interactionFlags[pos] : 0xFFFFFFFF);
                if (!hasExclusions && flags != 0xFFFFFFFF) {
                    if (flags == 0) {
                        // No interactions in this tile.
                    }
                    else {
                        // Compute only a subset of the interactions in this tile.
140

141
142
143
144
145
                        for (unsigned int j = 0; j < TILE_SIZE; j++) {
                            if ((flags&(1<<j)) != 0) {
                                int atom2 = tbx+j;
                                float4 posq2 = local_posq[atom2];
                                float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
146
#ifdef USE_PERIODIC
147
148
149
                                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
150
#endif
151
152
153
154
155
156
157
158
159
160
161
                                float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                                float tempValue1 = 0.0f;
                                float tempValue2 = 0.0f;
                                if (r2 < CUTOFF_SQUARED) {
                                    float r = SQRT(r2);
                                    LOAD_ATOM2_PARAMETERS
                                    atom2 = y*TILE_SIZE+j;
                                    if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
                                        COMPUTE_VALUE
                                    }
                                    value += tempValue1;
162
                                }
163
                                tempBuffer[get_local_id(0)] = tempValue2;
164

165
                                // Sum the forces on atom2.
166

167
168
169
170
171
172
173
174
175
176
177
                                if (tgx % 2 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+1];
                                if (tgx % 4 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+2];
                                if (tgx % 8 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+4];
                                if (tgx % 16 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+8];
                                if (tgx == 0)
                                    local_value[tbx+j] += tempBuffer[get_local_id(0)] + tempBuffer[get_local_id(0)+16];
                            }
178
179
180
                        }
                    }
                }
181
                else
182
#endif
183
184
                {
                    // Compute the full set of interactions in this tile.
185
186

#ifdef USE_EXCLUSIONS
187
188
                    unsigned int excl = (hasExclusions ? exclusions[exclusionIndex[localGroupIndex]+tgx] : 0xFFFFFFFF);
                    excl = (excl >> tgx) | (excl << (TILE_SIZE - tgx));
189
#endif
190
191
                    unsigned int tj = tgx;
                    for (unsigned int j = 0; j < TILE_SIZE; j++) {
192
#ifdef USE_EXCLUSIONS
193
                        bool isExcluded = !(excl & 0x1);
194
#endif
195
196
197
                        int atom2 = tbx+tj;
                        float4 posq2 = local_posq[atom2];
                        float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
198
#ifdef USE_PERIODIC
199
200
201
                        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
202
#endif
203
                        float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
204
#ifdef USE_CUTOFF
205
                        if (r2 < CUTOFF_SQUARED) {
206
#endif
207
208
209
210
211
                        float r = SQRT(r2);
                        LOAD_ATOM2_PARAMETERS
                        atom2 = y*TILE_SIZE+tj;
                        float tempValue1 = 0.0f;
                        float tempValue2 = 0.0f;
212
#ifdef USE_EXCLUSIONS
213
                        if (!isExcluded && atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
214
#else
215
                        if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
216
#endif
217
218
219
220
                            COMPUTE_VALUE
                        }
                        value += tempValue1;
                        local_value[tbx+tj] += tempValue2;
221
#ifdef USE_CUTOFF
222
                        }
223
#endif
224
#ifdef USE_EXCLUSIONS
225
                        excl >>= 1;
226
#endif
227
228
                        tj = (tj + 1) & (TILE_SIZE - 1);
                    }
229
230
                }
            }
231
232
233
234
235
236
237
238
239
240
241
242
243
244
        }
        
        // Write results.  We need to coordinate between warps to make sure no two of them
        // ever try to write to the same piece of memory at the same time.
        
        int writeX = (pos < end ? x : -1);
        int writeY = (pos < end && x != y ? y : -1);
        if (tgx == 0)
            reservedBlocks[localGroupIndex] = (int2)(writeX, writeY);
        bool done = false;
        int doneIndex = 0;
        int checkIndex = 0;
        while (true) {
            // See if any warp still needs to write its data.
245

246
247
248
249
250
251
252
253
254
255
256
257
258
            bool allDone = true;
            barrier(CLK_LOCAL_MEM_FENCE);
            while (doneIndex < WARPS_PER_GROUP && allDone) {
                if (reservedBlocks[doneIndex].x != -1)
                    allDone = false;
                else
                    doneIndex++;
            }
            if (allDone)
                break;
            if (!done) {
                // See whether this warp can write its data.  This requires that no previous warp
                // is trying to write to the same block of the buffer.
259

260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
                bool canWrite = (writeX != -1);
                while (checkIndex < localGroupIndex && canWrite) {
                    if ((reservedBlocks[checkIndex].x == x || reservedBlocks[checkIndex].y == x) ||
                            (writeY != -1 && (reservedBlocks[checkIndex].x == y || reservedBlocks[checkIndex].y == y)))
                        canWrite = false;
                    else
                        checkIndex++;
                }
                if (canWrite) {
                    // Write the data to global memory, then mark this warp as done.

                    if (writeX > -1) {
                        const unsigned int offset = x*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        global_value[offset] += value;
                    }
                    if (writeY > -1) {
                        const unsigned int offset = y*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        global_value[offset] += local_value[get_local_id(0)];
                    }
                    done = true;
                    if (tgx == 0)
                        reservedBlocks[localGroupIndex] = (int2)(-1, -1);
                }
            }
284
        }
285
        lasty = y;
286
        pos++;
287
    } while (pos < end);
288
}