// gbsaObc_nvidia.cl: OpenCL kernels computing the Born sums and the first part of the GBSA forces.

#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable

// 64-bit integer atomics are needed only when the host defines SUPPORTS_64_BIT_ATOMICS, in which
// case the kernels accumulate fixed-point results with atom_add on __global long buffers.
#ifdef SUPPORTS_64_BIT_ATOMICS
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
#endif

// Number of atoms per tile, which is also the number of work items the kernels treat as one warp
// (warp = get_global_id(0)/TILE_SIZE).  The kernels index with get_local_id(0) & (TILE_SIZE-1), so
// this must be a power of two, and their hand-written reductions (offsets +4 ... +28) and 32-bit
// interaction flag masks are written for exactly 32.
#define TILE_SIZE 32
/**
 * Per-atom data staged in __local memory (the localData argument of the kernels in this file).
 * Field order and count are unchanged so that any host-side size computation stays valid.
 */
typedef struct {
    float x, y, z;             // position, copied from posq.xyz
    float q;                   // posq.w (multiplied with PREFACTOR in the energy term)
    float fx, fy, fz, fw;      // accumulated force (fx..fz) and Born-radius derivative term (fw)
    float radius, scaledRadius; // global_params.x and global_params.y
    float bornSum;             // partial Born sum accumulated for the atom in computeBornSum
    float bornRadius;          // copied from global_bornRadii
    float bornForce;           // not referenced by the kernels in this file
} AtomData;
/**
 * Compute the Born sum.
 */
20
21
22
23
24
25
26
__kernel void computeBornSum(
#ifdef SUPPORTS_64_BIT_ATOMICS
        __global long* global_bornSum,
#else
        __global float* global_bornSum,
#endif
        __global float4* posq, __global float2* global_params,
27
        __local AtomData* localData, __local float* tempBuffer,
28
#ifdef USE_CUTOFF
29
        __global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags) {
30
31
32
#else
        unsigned int numTiles) {
#endif
33
34
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;
35
36
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
37
38
    unsigned int pos = warp*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
    unsigned int end = (warp+1)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
39
#else
40
41
    unsigned int pos = warp*numTiles/totalWarps;
    unsigned int end = (warp+1)*numTiles/totalWarps;
42
#endif
43
    unsigned int lasty = 0xFFFFFFFF;
44
45
46
    __local int2 reservedBlocks[WARPS_PER_GROUP];
    
    do {
47
        // Extract the coordinates of this tile
48
49
50
        const unsigned int tgx = get_local_id(0) & (TILE_SIZE-1);
        const unsigned int tbx = get_local_id(0) - tgx;
        const unsigned int localGroupIndex = get_local_id(0)/TILE_SIZE;
51
        unsigned int x, y;
52
53
        float bornSum = 0.0f;
        if (pos < end) {
54
#ifdef USE_CUTOFF
55
56
57
58
59
60
            if (numTiles <= maxTiles) {
                ushort2 tileIndices = tiles[pos];
                x = tileIndices.x;
                y = tileIndices.y;
            }
            else
61
#endif
62
63
            {
                y = (unsigned int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
64
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
65
66
67
68
                if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                    y += (x < y ? -1 : 1);
                    x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                }
69
            }
70
71
72
73
74
75
76
            unsigned int atom1 = x*TILE_SIZE + tgx;
            float4 posq1 = posq[atom1];
            float2 params1 = global_params[atom1];
            if (pos >= end)
                ; // This warp is done.
            else if (x == y) {
                // This tile is on the diagonal.
77

78
79
80
81
82
83
84
85
                localData[get_local_id(0)].x = posq1.x;
                localData[get_local_id(0)].y = posq1.y;
                localData[get_local_id(0)].z = posq1.z;
                localData[get_local_id(0)].q = posq1.w;
                localData[get_local_id(0)].radius = params1.x;
                localData[get_local_id(0)].scaledRadius = params1.y;
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    float4 delta = (float4) (localData[tbx+j].x-posq1.x, localData[tbx+j].y-posq1.y, localData[tbx+j].z-posq1.z, 0.0f);
86
#ifdef USE_PERIODIC
87
88
89
                    delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                    delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                    delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
90
#endif
91
                    float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
92
#ifdef USE_CUTOFF
93
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
94
#else
95
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
96
#endif
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
                        float invR = RSQRT(r2);
                        float r = RECIP(invR);
                        float2 params2 = (float2) (localData[tbx+j].radius, localData[tbx+j].scaledRadius);
                        float rScaledRadiusJ = r+params2.y;
                        if ((j != tgx) && (params1.x < rScaledRadiusJ)) {
                            float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
                            float u_ij = RECIP(rScaledRadiusJ);
                            float l_ij2 = l_ij*l_ij;
                            float u_ij2 = u_ij*u_ij;
                            float ratio = LOG(u_ij * RECIP(l_ij));
                            bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                                             (0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2);
                            if (params1.x < params2.x-r)
                                bornSum += 2.0f*(RECIP(params1.x)-l_ij);
                        }
112
113
114
                    }
                }
            }
115
116
            else {
                // This is an off-diagonal tile.
117

118
119
120
121
122
123
124
125
126
127
                if (lasty != y) {
                    unsigned int j = y*TILE_SIZE + tgx;
                    float4 tempPosq = posq[j];
                    localData[get_local_id(0)].x = tempPosq.x;
                    localData[get_local_id(0)].y = tempPosq.y;
                    localData[get_local_id(0)].z = tempPosq.z;
                    localData[get_local_id(0)].q = tempPosq.w;
                    float2 tempParams = global_params[j];
                    localData[get_local_id(0)].radius = tempParams.x;
                    localData[get_local_id(0)].scaledRadius = tempParams.y;
128
                }
129
130
131
132
133
134
135
136
137
                localData[get_local_id(0)].bornSum = 0.0f;
#ifdef USE_CUTOFF
                unsigned int flags = (numTiles <= maxTiles ? interactionFlags[pos] : 0xFFFFFFFF);
                if (flags != 0xFFFFFFFF && false) { // TODO: Fix this: should be checking for exclusions
                    if (flags == 0) {
                        // No interactions in this tile.
                    }
                    else {
                        // Compute only a subset of the interactions in this tile.
138

139
140
141
                        for (unsigned int j = 0; j < TILE_SIZE; j++) {
                            if ((flags&(1<<j)) != 0) {
                                float4 delta = (float4) (localData[tbx+j].x-posq1.x, localData[tbx+j].y-posq1.y, localData[tbx+j].z-posq1.z, 0.0f);
142
#ifdef USE_PERIODIC
143
144
145
                                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
146
#endif
147
148
                                float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                                tempBuffer[get_local_id(0)] = 0.0f;
149
#ifdef USE_CUTOFF
150
                                if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
151
#else
152
                                if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
153
#endif
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
                                    float invR = RSQRT(r2);
                                    float r = RECIP(invR);
                                    float2 params2 = (float2) (localData[tbx+j].radius, localData[tbx+j].scaledRadius);
                                    float rScaledRadiusJ = r+params2.y;
                                    if (params1.x < rScaledRadiusJ) {
                                        float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
                                        float u_ij = RECIP(rScaledRadiusJ);
                                        float l_ij2 = l_ij*l_ij;
                                        float u_ij2 = u_ij*u_ij;
                                        float ratio = LOG(u_ij * RECIP(l_ij));
                                        bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                                                         (0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2);
                                        if (params1.x < params2.x-r)
                                            bornSum += 2.0f*(RECIP(params1.x)-l_ij);
                                    }
                                    float rScaledRadiusI = r+params1.y;
                                    if (params2.x < rScaledRadiusI) {
                                        float l_ij = RECIP(max(params2.x, fabs(r-params1.y)));
                                        float u_ij = RECIP(rScaledRadiusI);
                                        float l_ij2 = l_ij*l_ij;
                                        float u_ij2 = u_ij*u_ij;
                                        float ratio = LOG(u_ij * RECIP(l_ij));
                                        float term = l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                                                         (0.25f*params1.y*params1.y*invR)*(l_ij2-u_ij2);
                                        if (params2.x < params1.x-r)
                                            term += 2.0f*(RECIP(params2.x)-l_ij);
                                        tempBuffer[get_local_id(0)] = term;
                                    }
182
183
                                }

184
                                // Sum the forces on atom j.
185

186
                                if (tgx % 4 == 0)
187
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+1]+tempBuffer[get_local_id(0)+2]+tempBuffer[get_local_id(0)+3];
188
                                if (tgx == 0)
189
                                    localData[tbx+j].bornSum += tempBuffer[get_local_id(0)]+tempBuffer[get_local_id(0)+4]+tempBuffer[get_local_id(0)+8]+tempBuffer[get_local_id(0)+12]+tempBuffer[get_local_id(0)+16]+tempBuffer[get_local_id(0)+20]+tempBuffer[get_local_id(0)+24]+tempBuffer[get_local_id(0)+28];
190
                            }
191
192
193
                        }
                    }
                }
194
                else
195
#endif
196
197
                {
                    // Compute the full set of interactions in this tile.
198

199
200
201
                    unsigned int tj = tgx;
                    for (unsigned int j = 0; j < TILE_SIZE; j++) {
                        float4 delta = (float4) (localData[tbx+tj].x-posq1.x, localData[tbx+tj].y-posq1.y, localData[tbx+tj].z-posq1.z, 0.0f);
202
#ifdef USE_PERIODIC
203
204
205
                        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
206
#endif
207
                        float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
208
#ifdef USE_CUTOFF
209
                        if (atom1 < NUM_ATOMS && y*TILE_SIZE+tj < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
210
#else
211
                        if (atom1 < NUM_ATOMS && y*TILE_SIZE+tj < NUM_ATOMS) {
212
#endif
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
                            float invR = RSQRT(r2);
                            float r = RECIP(invR);
                            float2 params2 = (float2) (localData[tbx+tj].radius, localData[tbx+tj].scaledRadius);
                            float rScaledRadiusJ = r+params2.y;
                            if (params1.x < rScaledRadiusJ) {
                                float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
                                float u_ij = RECIP(rScaledRadiusJ);
                                float l_ij2 = l_ij*l_ij;
                                float u_ij2 = u_ij*u_ij;
                                float ratio = LOG(u_ij * RECIP(l_ij));
                                bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                                                 (0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2);
                                if (params1.x < params2.x-r)
                                    bornSum += 2.0f*(RECIP(params1.x)-l_ij);
                            }
                            float rScaledRadiusI = r+params1.y;
                            if (params2.x < rScaledRadiusI) {
                                float l_ij = RECIP(max(params2.x, fabs(r-params1.y)));
                                float u_ij = RECIP(rScaledRadiusI);
                                float l_ij2 = l_ij*l_ij;
                                float u_ij2 = u_ij*u_ij;
                                float ratio = LOG(u_ij * RECIP(l_ij));
                                float term = l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                                                 (0.25f*params1.y*params1.y*invR)*(l_ij2-u_ij2);
                                if (params2.x < params1.x-r)
                                    term += 2.0f*(RECIP(params2.x)-l_ij);
                                localData[tbx+tj].bornSum += term;
                            }
241
                        }
242
                        tj = (tj + 1) & (TILE_SIZE - 1);
243
244
245
                    }
                }
            }
246
247
248
249
250
        }
        
        // Write results.  We need to coordinate between warps to make sure no two of them
        // ever try to write to the same piece of memory at the same time.
        
251
252
253
254
255
256
257
258
259
260
#ifdef SUPPORTS_64_BIT_ATOMICS
        if (pos < end) {
            const unsigned int offset = x*TILE_SIZE + tgx;
            atom_add(&global_bornSum[offset], (long) (bornSum*0xFFFFFFFF));
        }
        if (pos < end && x != y) {
            const unsigned int offset = y*TILE_SIZE + tgx;
            atom_add(&global_bornSum[offset], (long) (localData[get_local_id(0)].bornSum*0xFFFFFFFF));
        }
#else
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
        int writeX = (pos < end ? x : -1);
        int writeY = (pos < end && x != y ? y : -1);
        if (tgx == 0)
            reservedBlocks[localGroupIndex] = (int2)(writeX, writeY);
        bool done = false;
        int doneIndex = 0;
        int checkIndex = 0;
        while (true) {
            // See if any warp still needs to write its data.

            bool allDone = true;
            barrier(CLK_LOCAL_MEM_FENCE);
            while (doneIndex < WARPS_PER_GROUP && allDone) {
                if (reservedBlocks[doneIndex].x != -1)
                    allDone = false;
                else
                    doneIndex++;
            }
            if (allDone)
                break;
            if (!done) {
                // See whether this warp can write its data.  This requires that no previous warp
                // is trying to write to the same block of the buffer.
284

285
286
287
288
289
290
291
292
293
294
                bool canWrite = (writeX != -1);
                while (checkIndex < localGroupIndex && canWrite) {
                    if ((reservedBlocks[checkIndex].x == x || reservedBlocks[checkIndex].y == x) ||
                            (writeY != -1 && (reservedBlocks[checkIndex].x == y || reservedBlocks[checkIndex].y == y)))
                        canWrite = false;
                    else
                        checkIndex++;
                }
                if (canWrite) {
                    // Write the data to global memory, then mark this warp as done.
295

296
297
298
299
300
301
302
303
304
305
306
307
308
                    if (writeX > -1) {
                        const unsigned int offset = x*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        global_bornSum[offset] += bornSum;
                    }
                    if (writeY > -1) {
                        const unsigned int offset = y*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        global_bornSum[offset] += localData[get_local_id(0)].bornSum;
                    }
                    done = true;
                    if (tgx == 0)
                        reservedBlocks[localGroupIndex] = (int2)(-1, -1);
                }
            }
309
        }
310
#endif
311
        lasty = y;
312
        pos++;
313
    } while (pos < end);
314
315
}

316
317
318
319
/**
 * First part of computing the GBSA interaction.
 */

320
321
322
323
324
325
326
__kernel void computeGBSAForce1(
#ifdef SUPPORTS_64_BIT_ATOMICS
        __global long* forceBuffers, __global long* global_bornForce,
#else
        __global float4* forceBuffers, __global float* global_bornForce,
#endif
        __global float* energyBuffer, __global float4* posq, __global float* global_bornRadii,
327
        __local AtomData* localData, __local float4* tempBuffer,
328
#ifdef USE_CUTOFF
329
        __global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags) {
330
331
332
#else
        unsigned int numTiles) {
#endif
333
334
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;
335
336
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
337
338
    unsigned int pos = warp*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
    unsigned int end = (warp+1)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
339
#else
340
341
    unsigned int pos = warp*numTiles/totalWarps;
    unsigned int end = (warp+1)*numTiles/totalWarps;
342
#endif
343
344
    float energy = 0.0f;
    unsigned int lasty = 0xFFFFFFFF;
345
346
347
    __local int2 reservedBlocks[WARPS_PER_GROUP];
    
    do {
348
        // Extract the coordinates of this tile
349
350
351
        const unsigned int tgx = get_local_id(0) & (TILE_SIZE-1);
        const unsigned int tbx = get_local_id(0) - tgx;
        const unsigned int localGroupIndex = get_local_id(0)/TILE_SIZE;
352
        unsigned int x, y;
353
354
        float4 force = 0.0f;
        if (pos < end) {
355
#ifdef USE_CUTOFF
356
357
358
359
360
361
            if (numTiles <= maxTiles) {
                ushort2 tileIndices = tiles[pos];
                x = tileIndices.x;
                y = tileIndices.y;
            }
            else
362
#endif
363
364
            {
                y = (unsigned int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
365
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
366
367
368
369
                if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                    y += (x < y ? -1 : 1);
                    x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                }
370
            }
371
372
373
374
375
            unsigned int atom1 = x*TILE_SIZE + tgx;
            float4 posq1 = posq[atom1];
            float bornRadius1 = global_bornRadii[atom1];
            if (x == y) {
                // This tile is on the diagonal.
376

377
378
379
380
381
382
383
384
385
                localData[get_local_id(0)].x = posq1.x;
                localData[get_local_id(0)].y = posq1.y;
                localData[get_local_id(0)].z = posq1.z;
                localData[get_local_id(0)].q = posq1.w;
                localData[get_local_id(0)].bornRadius = bornRadius1;
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
                        float4 posq2 = (float4) (localData[tbx+j].x, localData[tbx+j].y, localData[tbx+j].z, localData[tbx+j].q);
                        float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
386
#ifdef USE_PERIODIC
387
388
389
                        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
390
#endif
391
392
393
394
395
396
397
398
399
400
401
402
403
                        float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                        float invR = RSQRT(r2);
                        float r = RECIP(invR);
                        float bornRadius2 = localData[tbx+j].bornRadius;
                        float alpha2_ij = bornRadius1*bornRadius2;
                        float D_ij = r2*RECIP(4.0f*alpha2_ij);
                        float expTerm = EXP(-D_ij);
                        float denominator2 = r2 + alpha2_ij*expTerm;
                        float denominator = SQRT(denominator2);
                        float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                        float Gpol = tempEnergy*RECIP(denominator2);
                        float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                        float dEdR = Gpol*(1.0f - 0.25f*expTerm);
404
#ifdef USE_CUTOFF
405
406
407
408
409
                        if (r2 > CUTOFF_SQUARED) {
                            dEdR = 0.0f;
                            tempEnergy  = 0.0f;
                            dGpol_dalpha2_ij = 0.0f;
                        }
410
#endif
411
412
413
414
415
                        force.w += dGpol_dalpha2_ij*bornRadius2;
                        energy += 0.5f*tempEnergy;
                        delta.xyz *= dEdR;
                        force.xyz -= delta.xyz;
                    }
416
417
                }
            }
418
419
            else {
                // This is an off-diagonal tile.
420

421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
                if (lasty != y) {
                    unsigned int j = y*TILE_SIZE + tgx;
                    float4 tempPosq = posq[j];
                    localData[get_local_id(0)].x = tempPosq.x;
                    localData[get_local_id(0)].y = tempPosq.y;
                    localData[get_local_id(0)].z = tempPosq.z;
                    localData[get_local_id(0)].q = tempPosq.w;
                    localData[get_local_id(0)].bornRadius = global_bornRadii[j];
                }
                localData[get_local_id(0)].fx = 0.0f;
                localData[get_local_id(0)].fy = 0.0f;
                localData[get_local_id(0)].fz = 0.0f;
                localData[get_local_id(0)].fw = 0.0f;
#ifdef USE_CUTOFF
                unsigned int flags = (numTiles <= maxTiles ? interactionFlags[pos] : 0xFFFFFFFF);
                if (flags != 0xFFFFFFFF && false) { // TODO: Fix this: should be checking for exclusions
                    if (flags == 0) {
                        // No interactions in this tile.
                    }
                    else {
                        // Compute only a subset of the interactions in this tile.
442

443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
                        for (unsigned int j = 0; j < TILE_SIZE; j++) {
                            if ((flags&(1<<j)) != 0) {
                                float4 posq2 = (float4) (localData[tbx+j].x, localData[tbx+j].y, localData[tbx+j].z, localData[tbx+j].q);
                                float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
#ifdef USE_PERIODIC
                                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                                float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                                float invR = RSQRT(r2);
                                float r = RECIP(invR);
                                float bornRadius2 = localData[tbx+j].bornRadius;
                                float alpha2_ij = bornRadius1*bornRadius2;
                                float D_ij = r2*RECIP(4.0f*alpha2_ij);
                                float expTerm = EXP(-D_ij);
                                float denominator2 = r2 + alpha2_ij*expTerm;
                                float denominator = SQRT(denominator2);
                                float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                                float Gpol = tempEnergy*RECIP(denominator2);
                                float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                                float dEdR = Gpol*(1.0f - 0.25f*expTerm);
465
#ifdef USE_CUTOFF
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
                                if (atom1 >= NUM_ATOMS || y*TILE_SIZE+j >= NUM_ATOMS || r2 > CUTOFF_SQUARED) {
#else
                                if (atom1 >= NUM_ATOMS || y*TILE_SIZE+j >= NUM_ATOMS) {
#endif
                                    dEdR = 0.0f;
                                    dGpol_dalpha2_ij = 0.0f;
                                    tempEnergy = 0.0f;
                                }
                                energy += tempEnergy;
                                force.w += dGpol_dalpha2_ij*bornRadius2;
                                delta.xyz *= dEdR;
                                force.xyz -= delta.xyz;
                                tempBuffer[get_local_id(0)] = (float4) (delta.xyz, dGpol_dalpha2_ij*bornRadius1);

                                // Sum the forces on atom j.

                                if (tgx % 4 == 0)
483
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+1]+tempBuffer[get_local_id(0)+2]+tempBuffer[get_local_id(0)+3];
484
                                if (tgx == 0) {
485
                                    float4 sum = tempBuffer[get_local_id(0)]+tempBuffer[get_local_id(0)+4]+tempBuffer[get_local_id(0)+8]+tempBuffer[get_local_id(0)+12]+tempBuffer[get_local_id(0)+16]+tempBuffer[get_local_id(0)+20]+tempBuffer[get_local_id(0)+24]+tempBuffer[get_local_id(0)+28];
486
487
488
489
490
491
492
493
                                    localData[tbx+j].fx += sum.x;
                                    localData[tbx+j].fy += sum.y;
                                    localData[tbx+j].fz += sum.z;
                                    localData[tbx+j].fw += sum.w;
                                }
                            }
                        }
                    }
494
                }
495
496
497
498
                else
#endif
                {
                    // Compute the full set of interactions in this tile.
499

500
                    unsigned int tj = tgx;
501
                    for (unsigned int j = 0; j < TILE_SIZE; j++) {
502
503
                        if (atom1 < NUM_ATOMS && y*TILE_SIZE+tj < NUM_ATOMS) {
                            float4 posq2 = (float4) (localData[tbx+tj].x, localData[tbx+tj].y, localData[tbx+tj].z, localData[tbx+tj].q);
504
                            float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
505
#ifdef USE_PERIODIC
506
507
508
                            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
509
510
#endif
                            float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
511
512
                            float invR = RSQRT(r2);
                            float r = RECIP(invR);
513
                            float bornRadius2 = localData[tbx+tj].bornRadius;
514
                            float alpha2_ij = bornRadius1*bornRadius2;
515
                            float D_ij = r2*RECIP(4.0f*alpha2_ij);
516
                            float expTerm = EXP(-D_ij);
517
                            float denominator2 = r2 + alpha2_ij*expTerm;
518
                            float denominator = SQRT(denominator2);
519
520
                            float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                            float Gpol = tempEnergy*RECIP(denominator2);
521
522
523
                            float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                            float dEdR = Gpol*(1.0f - 0.25f*expTerm);
#ifdef USE_CUTOFF
524
                            if (r2 > CUTOFF_SQUARED) {
525
                                dEdR = 0.0f;
526
                                tempEnergy  = 0.0f;
527
                                dGpol_dalpha2_ij = 0.0f;
528
                            }
529
#endif
530
                            force.w += dGpol_dalpha2_ij*bornRadius2;
531
                            energy += tempEnergy;
532
533
                            delta.xyz *= dEdR;
                            force.xyz -= delta.xyz;
534
535
536
537
                            localData[tbx+tj].fx += delta.x;
                            localData[tbx+tj].fy += delta.y;
                            localData[tbx+tj].fz += delta.z;
                            localData[tbx+tj].fw += dGpol_dalpha2_ij*bornRadius1;
538
                        }
539
                        tj = (tj + 1) & (TILE_SIZE - 1);
540
541
542
                    }
                }
            }
543
544
545
546
547
        }
        
        // Write results.  We need to coordinate between warps to make sure no two of them
        // ever try to write to the same piece of memory at the same time.
        
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
#ifdef SUPPORTS_64_BIT_ATOMICS
        if (pos < end) {
            const unsigned int offset = x*TILE_SIZE + tgx;
            atom_add(&forceBuffers[offset], (long) (force.x*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+PADDED_NUM_ATOMS], (long) (force.y*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+2*PADDED_NUM_ATOMS], (long) (force.z*0xFFFFFFFF));
            atom_add(&global_bornForce[offset], (long) (force.w*0xFFFFFFFF));
        }
        if (pos < end && x != y) {
            const unsigned int offset = y*TILE_SIZE + tgx;
            atom_add(&forceBuffers[offset], (long) (localData[get_local_id(0)].fx*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+PADDED_NUM_ATOMS], (long) (localData[get_local_id(0)].fy*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+2*PADDED_NUM_ATOMS], (long) (localData[get_local_id(0)].fz*0xFFFFFFFF));
            atom_add(&global_bornForce[offset], (long) (localData[get_local_id(0)].fw*0xFFFFFFFF));
        }
#else
564
565
566
567
568
569
570
571
572
        int writeX = (pos < end ? x : -1);
        int writeY = (pos < end && x != y ? y : -1);
        if (tgx == 0)
            reservedBlocks[localGroupIndex] = (int2)(writeX, writeY);
        bool done = false;
        int doneIndex = 0;
        int checkIndex = 0;
        while (true) {
            // See if any warp still needs to write its data.
573

574
575
576
577
578
579
580
            bool allDone = true;
            barrier(CLK_LOCAL_MEM_FENCE);
            while (doneIndex < WARPS_PER_GROUP && allDone) {
                if (reservedBlocks[doneIndex].x != -1)
                    allDone = false;
                else
                    doneIndex++;
581
            }
582
583
584
585
586
            if (allDone)
                break;
            if (!done) {
                // See whether this warp can write its data.  This requires that no previous warp
                // is trying to write to the same block of the buffer.
587

588
589
590
591
592
593
594
595
596
597
                bool canWrite = (writeX != -1);
                while (checkIndex < localGroupIndex && canWrite) {
                    if ((reservedBlocks[checkIndex].x == x || reservedBlocks[checkIndex].y == x) ||
                            (writeY != -1 && (reservedBlocks[checkIndex].x == y || reservedBlocks[checkIndex].y == y)))
                        canWrite = false;
                    else
                        checkIndex++;
                }
                if (canWrite) {
                    // Write the data to global memory, then mark this warp as done.
598

599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
                    if (writeX > -1) {
                        const unsigned int offset = x*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        forceBuffers[offset].xyz += force.xyz;
                        global_bornForce[offset] += force.w;
                    }
                    if (writeY > -1) {
                        const unsigned int offset = y*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        forceBuffers[offset] += (float4) (localData[get_local_id(0)].fx, localData[get_local_id(0)].fy, localData[get_local_id(0)].fz, 0.0f);
                        global_bornForce[offset] += localData[get_local_id(0)].fw;
                    }
                    done = true;
                    if (tgx == 0)
                        reservedBlocks[localGroupIndex] = (int2)(-1, -1);
                }
            }
614
        }
615
#endif
616
        lasty = y;
617
        pos++;
618
    } while (pos < end);
619
620
    energyBuffer[get_global_id(0)] += energy;
}