/*
 * gbsaObc_nvidia.cl (29.7 KB)
 *
 * Provenance note (residue from a repository web view, kept as a comment):
 *   "vscode:/vscode.git/clone" did not exist on "e2c80f2e19e299e1d463b87e08a34af1b25755d9"
 */
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable

/*
 * Number of atoms in one block/tile, and the number of consecutive work-items ("warp")
 * that cooperate on one tile.  Lane indices are computed as get_local_id(0) & (TILE_SIZE-1),
 * so this must be a power of two; the tree reductions in this file also hard-code strides
 * up to 16, which assumes a value of 32.
 */
#define TILE_SIZE 32
/*
 * Per-work-item scratch record held in the kernels' localData buffer (one entry per
 * work-item).  Each entry caches one atom of the tile's "y" block so that the other
 * work-items of the same warp can read it, and collects contributions destined for that atom.
 */
typedef struct {
    float x, y, z;              // position, copied from posq.xyz
    float q;                    // posq.w; multiplied pairwise into the GB energy (presumably the charge)
    float fx, fy, fz, fw;       // off-diagonal accumulators: force (fx..fz) and Born-radius derivative (fw) on this atom
    float radius, scaledRadius; // copied from global_params.x / global_params.y
    float bornSum;              // off-diagonal accumulator for this atom's Born sum
    float bornRadius;           // copied from global_bornRadii
    float bornForce;            // NOTE(review): not read or written by the kernels in this section of the file
} AtomData;
/**
 * Compute the Born sum.
 *
 * Work is split into tiles.  A tile pairs block x of TILE_SIZE atoms with block y (x >= y).
 * Each group of TILE_SIZE consecutive work-items (a "warp" below) processes the contiguous
 * tile range [pos, end); lane tgx of the warp owns atom x*TILE_SIZE + tgx.  For every atom
 * pair in a tile the pairwise integral terms are evaluated from the atoms' (radius,
 * scaledRadius) and accumulated:
 *   - into the private bornSum of the block-x atom, and
 *   - for off-diagonal tiles, also into localData[].bornSum of the block-y atom.
 * Results are added to global_bornSum at atom + get_group_id(0)*PADDED_NUM_ATOMS, so each
 * work-group owns its own PADDED_NUM_ATOMS-sized slice (presumably reduced by a later kernel).
 *
 * Expects these macros at program build time (presumably supplied by the host): NUM_BLOCKS,
 * NUM_ATOMS, PADDED_NUM_ATOMS, WARPS_PER_GROUP, RSQRT, RECIP, LOG, and CUTOFF_SQUARED when
 * USE_CUTOFF is defined.  NOTE(review): periodicBoxSize/invPeriodicBoxSize are only declared
 * under USE_CUTOFF, so USE_PERIODIC appears to require USE_CUTOFF as well.
 *
 * @param global_bornSum     output accumulator, one PADDED_NUM_ATOMS slice per work-group
 * @param posq               per-atom position (xyz) and posq.w
 * @param global_params      per-atom (radius, scaledRadius)
 * @param localData          local scratch, one AtomData per work-item
 * @param tempBuffer         local scratch, one float per work-item (only used by the currently
 *                           disabled interaction-flags path)
 * @param tiles              (USE_CUTOFF) list of (x, y) tile coordinates
 * @param interactionCount   (USE_CUTOFF) interactionCount[0] is the number of listed tiles; if it
 *                           exceeds maxTiles the list is ignored and every tile is processed
 * @param periodicBoxSize    (USE_CUTOFF) box edge lengths, applied when USE_PERIODIC is defined
 * @param invPeriodicBoxSize (USE_CUTOFF) reciprocal box edge lengths
 * @param maxTiles           (USE_CUTOFF) capacity of the tile list
 * @param interactionFlags   (USE_CUTOFF) per-tile interaction bit masks
 * @param numTiles           (no cutoff) number of tiles to process
 */
__kernel void computeBornSum(__global float* global_bornSum, __global float4* posq, __global float2* global_params,
        __local AtomData* localData, __local float* tempBuffer,
#ifdef USE_CUTOFF
        __global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags) {
#else
        unsigned int numTiles) {
#endif
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
    unsigned int pos = warp*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
    unsigned int end = (warp+1)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
#else
    unsigned int pos = warp*numTiles/totalWarps;
    unsigned int end = (warp+1)*numTiles/totalWarps;
#endif
    // Block index whose atoms are currently cached in localData (none yet).
    unsigned int lasty = 0xFFFFFFFF;
    // Per-warp (x, y) block reservation used to serialize conflicting global writes; -1 = none.
    __local int2 reservedBlocks[WARPS_PER_GROUP];

    // NOTE(review): end-pos can differ by one between warps of the same work-group (and a warp
    // with pos == end still runs one iteration), so warps may leave this loop at different times
    // while others still call barrier() below.  The OpenCL spec requires every work-item of a
    // work-group to reach each barrier, so this relies on vendor-specific behavior.
    do {
        // Extract the coordinates of this tile
        const unsigned int tgx = get_local_id(0) & (TILE_SIZE-1);
        const unsigned int tbx = get_local_id(0) - tgx;
        const unsigned int localGroupIndex = get_local_id(0)/TILE_SIZE;
        unsigned int x, y;
        float bornSum = 0.0f;
        if (pos < end) {
#ifdef USE_CUTOFF
            if (numTiles <= maxTiles) {
                ushort2 tileIndices = tiles[pos];
                x = tileIndices.x;
                y = tileIndices.y;
            }
            else
#endif
            {
                // Map the linear index pos onto the upper triangle of (x, y) block pairs with
                // x >= y: row y starts at index y*NUM_BLOCKS - y*(y-1)/2.
                y = (unsigned int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                    y += (x < y ? -1 : 1);
                    x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                }
            }
            unsigned int atom1 = x*TILE_SIZE + tgx;
            float4 posq1 = posq[atom1];
            float2 params1 = global_params[atom1];
            if (x == y) {
                // This tile is on the diagonal.
                // NOTE(review): no barrier separates these localData writes from reads by other
                // lanes below; this assumes the TILE_SIZE lanes of a warp run in lockstep.
                localData[get_local_id(0)].x = posq1.x;
                localData[get_local_id(0)].y = posq1.y;
                localData[get_local_id(0)].z = posq1.z;
                localData[get_local_id(0)].q = posq1.w;
                localData[get_local_id(0)].radius = params1.x;
                localData[get_local_id(0)].scaledRadius = params1.y;
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    float4 delta = (float4) (localData[tbx+j].x-posq1.x, localData[tbx+j].y-posq1.y, localData[tbx+j].z-posq1.z, 0.0f);
#ifdef USE_PERIODIC
                    delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                    delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                    delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                    float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
#else
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
#endif
                        float invR = RSQRT(r2);
                        float r = RECIP(invR);
                        float2 params2 = (float2) (localData[tbx+j].radius, localData[tbx+j].scaledRadius);
                        float rScaledRadiusJ = r+params2.y;
                        // Skip the atom's interaction with itself (j == tgx).
                        if ((j != tgx) && (params1.x < rScaledRadiusJ)) {
                            float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
                            float u_ij = RECIP(rScaledRadiusJ);
                            float l_ij2 = l_ij*l_ij;
                            float u_ij2 = u_ij*u_ij;
                            float ratio = LOG(u_ij * RECIP(l_ij));
                            bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                                             (0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2);
                            if (params1.x < params2.x-r)
                                bornSum += 2.0f*(RECIP(params1.x)-l_ij);
                        }
                    }
                }
            }
            else {
                // This is an off-diagonal tile.
                // localData still holds block lasty from the previous tile; reload only when y changes.
                if (lasty != y) {
                    unsigned int j = y*TILE_SIZE + tgx;
                    float4 tempPosq = posq[j];
                    localData[get_local_id(0)].x = tempPosq.x;
                    localData[get_local_id(0)].y = tempPosq.y;
                    localData[get_local_id(0)].z = tempPosq.z;
                    localData[get_local_id(0)].q = tempPosq.w;
                    float2 tempParams = global_params[j];
                    localData[get_local_id(0)].radius = tempParams.x;
                    localData[get_local_id(0)].scaledRadius = tempParams.y;
                }
                localData[get_local_id(0)].bornSum = 0.0f;
#ifdef USE_CUTOFF
                unsigned int flags = (numTiles <= maxTiles ? interactionFlags[pos] : 0xFFFFFFFF);
                if (flags != 0xFFFFFFFF && false) { // TODO: Fix this: should be checking for exclusions
                    if (flags == 0) {
                        // No interactions in this tile.
                    }
                    else {
                        // Compute only a subset of the interactions in this tile.
                        for (unsigned int j = 0; j < TILE_SIZE; j++) {
                            if ((flags&(1<<j)) != 0) {
                                float4 delta = (float4) (localData[tbx+j].x-posq1.x, localData[tbx+j].y-posq1.y, localData[tbx+j].z-posq1.z, 0.0f);
#ifdef USE_PERIODIC
                                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                                float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                                tempBuffer[get_local_id(0)] = 0.0f;
                                // Already inside a USE_CUTOFF region, so the cutoff test always applies.
                                if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
                                    float invR = RSQRT(r2);
                                    float r = RECIP(invR);
                                    float2 params2 = (float2) (localData[tbx+j].radius, localData[tbx+j].scaledRadius);
                                    float rScaledRadiusJ = r+params2.y;
                                    if (params1.x < rScaledRadiusJ) {
                                        float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
                                        float u_ij = RECIP(rScaledRadiusJ);
                                        float l_ij2 = l_ij*l_ij;
                                        float u_ij2 = u_ij*u_ij;
                                        float ratio = LOG(u_ij * RECIP(l_ij));
                                        bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                                                         (0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2);
                                        if (params1.x < params2.x-r)
                                            bornSum += 2.0f*(RECIP(params1.x)-l_ij);
                                    }
                                    float rScaledRadiusI = r+params1.y;
                                    if (params2.x < rScaledRadiusI) {
                                        float l_ij = RECIP(max(params2.x, fabs(r-params1.y)));
                                        float u_ij = RECIP(rScaledRadiusI);
                                        float l_ij2 = l_ij*l_ij;
                                        float u_ij2 = u_ij*u_ij;
                                        float ratio = LOG(u_ij * RECIP(l_ij));
                                        float term = l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                                                         (0.25f*params1.y*params1.y*invR)*(l_ij2-u_ij2);
                                        if (params2.x < params1.x-r)
                                            term += 2.0f*(RECIP(params2.x)-l_ij);
                                        tempBuffer[get_local_id(0)] = term;
                                    }
                                }

                                // Sum the Born-sum contributions to atom j across the warp
                                // (tree reduction; relies on the warp running in lockstep).
                                if (tgx % 2 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+1];
                                if (tgx % 4 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+2];
                                if (tgx % 8 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+4];
                                if (tgx % 16 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+8];
                                if (tgx == 0)
                                    localData[tbx+j].bornSum += tempBuffer[get_local_id(0)] + tempBuffer[get_local_id(0)+16];
                            }
                        }
                    }
                }
                else
#endif
                {
                    // Compute the full set of interactions in this tile.
                    // Lane tgx visits partner (tgx + j) mod TILE_SIZE at step j, so at any step the
                    // lanes touch distinct localData entries and can accumulate into them directly.
                    unsigned int tj = tgx;
                    for (unsigned int j = 0; j < TILE_SIZE; j++) {
                        float4 delta = (float4) (localData[tbx+tj].x-posq1.x, localData[tbx+tj].y-posq1.y, localData[tbx+tj].z-posq1.z, 0.0f);
#ifdef USE_PERIODIC
                        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                        float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                        if (atom1 < NUM_ATOMS && y*TILE_SIZE+tj < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
#else
                        if (atom1 < NUM_ATOMS && y*TILE_SIZE+tj < NUM_ATOMS) {
#endif
                            float invR = RSQRT(r2);
                            float r = RECIP(invR);
                            float2 params2 = (float2) (localData[tbx+tj].radius, localData[tbx+tj].scaledRadius);
                            // Contribution of atom tj to atom1's Born sum.
                            float rScaledRadiusJ = r+params2.y;
                            if (params1.x < rScaledRadiusJ) {
                                float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
                                float u_ij = RECIP(rScaledRadiusJ);
                                float l_ij2 = l_ij*l_ij;
                                float u_ij2 = u_ij*u_ij;
                                float ratio = LOG(u_ij * RECIP(l_ij));
                                bornSum += l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                                                 (0.25f*params2.y*params2.y*invR)*(l_ij2-u_ij2);
                                if (params1.x < params2.x-r)
                                    bornSum += 2.0f*(RECIP(params1.x)-l_ij);
                            }
                            // Contribution of atom1 to atom tj's Born sum.
                            float rScaledRadiusI = r+params1.y;
                            if (params2.x < rScaledRadiusI) {
                                float l_ij = RECIP(max(params2.x, fabs(r-params1.y)));
                                float u_ij = RECIP(rScaledRadiusI);
                                float l_ij2 = l_ij*l_ij;
                                float u_ij2 = u_ij*u_ij;
                                float ratio = LOG(u_ij * RECIP(l_ij));
                                float term = l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                                                 (0.25f*params1.y*params1.y*invR)*(l_ij2-u_ij2);
                                if (params2.x < params1.x-r)
                                    term += 2.0f*(RECIP(params2.x)-l_ij);
                                localData[tbx+tj].bornSum += term;
                            }
                        }
                        tj = (tj + 1) & (TILE_SIZE - 1);
                    }
                }
            }
        }

        // Write results.  We need to coordinate between warps to make sure no two of them
        // ever try to write to the same piece of memory at the same time.
        // Each warp publishes the blocks it will update; a warp writes only once no
        // lower-indexed warp in the group still reserves either of those blocks.
        int writeX = (pos < end ? x : -1);
        int writeY = (pos < end && x != y ? y : -1);
        if (tgx == 0)
            reservedBlocks[localGroupIndex] = (int2)(writeX, writeY);
        bool done = false;
        int doneIndex = 0;
        int checkIndex = 0;
        while (true) {
            // See if any warp still needs to write its data.
            bool allDone = true;
            barrier(CLK_LOCAL_MEM_FENCE);
            while (doneIndex < WARPS_PER_GROUP && allDone) {
                if (reservedBlocks[doneIndex].x != -1)
                    allDone = false;
                else
                    doneIndex++;
            }
            if (allDone)
                break;
            if (!done) {
                // See whether this warp can write its data.  This requires that no previous warp
                // is trying to write to the same block of the buffer.
                bool canWrite = (writeX != -1);
                while (checkIndex < localGroupIndex && canWrite) {
                    if ((reservedBlocks[checkIndex].x == x || reservedBlocks[checkIndex].y == x) ||
                            (writeY != -1 && (reservedBlocks[checkIndex].x == y || reservedBlocks[checkIndex].y == y)))
                        canWrite = false;
                    else
                        checkIndex++;
                }
                if (canWrite) {
                    // Write the data to global memory, then mark this warp as done.
                    if (writeX > -1) {
                        const unsigned int offset = x*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        global_bornSum[offset] += bornSum;
                    }
                    if (writeY > -1) {
                        const unsigned int offset = y*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        global_bornSum[offset] += localData[get_local_id(0)].bornSum;
                    }
                    done = true;
                    if (tgx == 0)
                        reservedBlocks[localGroupIndex] = (int2)(-1, -1);
                }
            }
        }
        lasty = y;
        pos++;
    } while (pos < end);
}
/**
 * First part of computing the GBSA interaction.
 *
 * Uses the same tiling and per-warp tile ranges as computeBornSum.  For each atom pair with
 * Born radii alpha_i, alpha_j and squared distance r2 it evaluates
 *   D = r2 / (4 alpha_i alpha_j),  f2 = r2 + alpha_i alpha_j exp(-D),
 *   E = PREFACTOR q_i q_j / sqrt(f2)
 * together with the radial force factor dEdR and dGpol_dalpha2_ij.  Per-atom results go to:
 *   - forceBuffers[].xyz: the pair force; atom1 receives -delta*dEdR and atom j +delta*dEdR,
 *   - global_bornForce[]: the dGpol_dalpha2_ij * (partner's Born radius) terms,
 *   - energyBuffer[get_global_id(0)]: this work-item's energy.  Diagonal tiles count each pair
 *     from both sides, including the j == tgx self term, and so add 0.5*E.
 * Force and Born-force outputs are offset by get_group_id(0)*PADDED_NUM_ATOMS, one slice per
 * work-group.
 *
 * Expects these macros at program build time (presumably supplied by the host): NUM_BLOCKS,
 * NUM_ATOMS, PADDED_NUM_ATOMS, WARPS_PER_GROUP, PREFACTOR, RECIP, EXP, SQRT, and CUTOFF_SQUARED
 * when USE_CUTOFF is defined.
 *
 * @param forceBuffers       output force accumulators (xyz), one slice per work-group
 * @param energyBuffer       output energy, one entry per global work-item
 * @param posq               per-atom position (xyz) and posq.w (multiplied into the energy prefactor)
 * @param global_bornRadii   per-atom Born radius
 * @param global_bornForce   output Born-radius derivative accumulators, one slice per work-group
 * @param localData          local scratch, one AtomData per work-item
 * @param tempBuffer         local scratch, one float4 per work-item (only used by the currently
 *                           disabled interaction-flags path)
 * @param tiles, interactionCount, periodicBoxSize, invPeriodicBoxSize, maxTiles, interactionFlags
 *                           (USE_CUTOFF) as for computeBornSum
 * @param numTiles           (no cutoff) number of tiles to process
 */
__kernel void computeGBSAForce1(__global float4* forceBuffers, __global float* energyBuffer,
        __global float4* posq, __global float* global_bornRadii, __global float* global_bornForce,
        __local AtomData* localData, __local float4* tempBuffer,
#ifdef USE_CUTOFF
        __global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags) {
#else
        unsigned int numTiles) {
#endif
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
    unsigned int pos = warp*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
    unsigned int end = (warp+1)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
#else
    unsigned int pos = warp*numTiles/totalWarps;
    unsigned int end = (warp+1)*numTiles/totalWarps;
#endif
    float energy = 0.0f;
    // Block index whose atoms are currently cached in localData (none yet).
    unsigned int lasty = 0xFFFFFFFF;
    // Per-warp (x, y) block reservation used to serialize conflicting global writes; -1 = none.
    __local int2 reservedBlocks[WARPS_PER_GROUP];

    // NOTE(review): warps of one work-group can run different numbers of iterations of this
    // loop, so the barrier() in the write phase may not be reached by every work-item, which the
    // OpenCL spec does not allow.  This relies on vendor-specific behavior.
    do {
        // Extract the coordinates of this tile
        const unsigned int tgx = get_local_id(0) & (TILE_SIZE-1);
        const unsigned int tbx = get_local_id(0) - tgx;
        const unsigned int localGroupIndex = get_local_id(0)/TILE_SIZE;
        unsigned int x, y;
        float4 force = 0.0f;
        if (pos < end) {
#ifdef USE_CUTOFF
            if (numTiles <= maxTiles) {
                ushort2 tileIndices = tiles[pos];
                x = tileIndices.x;
                y = tileIndices.y;
            }
            else
#endif
            {
                // Map the linear index pos onto the upper triangle of (x, y) block pairs with x >= y.
                y = (unsigned int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                    y += (x < y ? -1 : 1);
                    x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                }
            }
            unsigned int atom1 = x*TILE_SIZE + tgx;
            float4 posq1 = posq[atom1];
            float bornRadius1 = global_bornRadii[atom1];
            if (x == y) {
                // This tile is on the diagonal.
                // NOTE(review): no barrier separates these localData writes from reads by other
                // lanes below; this assumes the TILE_SIZE lanes of a warp run in lockstep.
                localData[get_local_id(0)].x = posq1.x;
                localData[get_local_id(0)].y = posq1.y;
                localData[get_local_id(0)].z = posq1.z;
                localData[get_local_id(0)].q = posq1.w;
                localData[get_local_id(0)].bornRadius = bornRadius1;
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
                        float4 posq2 = (float4) (localData[tbx+j].x, localData[tbx+j].y, localData[tbx+j].z, localData[tbx+j].q);
                        float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
#ifdef USE_PERIODIC
                        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                        float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                        float bornRadius2 = localData[tbx+j].bornRadius;
                        float alpha2_ij = bornRadius1*bornRadius2;
                        float D_ij = r2*RECIP(4.0f*alpha2_ij);
                        float expTerm = EXP(-D_ij);
                        float denominator2 = r2 + alpha2_ij*expTerm;
                        float denominator = SQRT(denominator2);
                        float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                        float Gpol = tempEnergy*RECIP(denominator2);
                        float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                        float dEdR = Gpol*(1.0f - 0.25f*expTerm);
#ifdef USE_CUTOFF
                        if (r2 > CUTOFF_SQUARED) {
                            dEdR = 0.0f;
                            tempEnergy  = 0.0f;
                            dGpol_dalpha2_ij = 0.0f;
                        }
#endif
                        force.w += dGpol_dalpha2_ij*bornRadius2;
                        // Every pair (and the self term) is visited from both sides of a diagonal tile.
                        energy += 0.5f*tempEnergy;
                        delta.xyz *= dEdR;
                        force.xyz -= delta.xyz;
                    }
                }
            }
            else {
                // This is an off-diagonal tile.
                // localData still holds block lasty from the previous tile; reload only when y changes.
                if (lasty != y) {
                    unsigned int j = y*TILE_SIZE + tgx;
                    float4 tempPosq = posq[j];
                    localData[get_local_id(0)].x = tempPosq.x;
                    localData[get_local_id(0)].y = tempPosq.y;
                    localData[get_local_id(0)].z = tempPosq.z;
                    localData[get_local_id(0)].q = tempPosq.w;
                    localData[get_local_id(0)].bornRadius = global_bornRadii[j];
                }
                localData[get_local_id(0)].fx = 0.0f;
                localData[get_local_id(0)].fy = 0.0f;
                localData[get_local_id(0)].fz = 0.0f;
                localData[get_local_id(0)].fw = 0.0f;
#ifdef USE_CUTOFF
                unsigned int flags = (numTiles <= maxTiles ? interactionFlags[pos] : 0xFFFFFFFF);
                if (flags != 0xFFFFFFFF && false) { // TODO: Fix this: should be checking for exclusions
                    if (flags == 0) {
                        // No interactions in this tile.
                    }
                    else {
                        // Compute only a subset of the interactions in this tile.
                        for (unsigned int j = 0; j < TILE_SIZE; j++) {
                            if ((flags&(1<<j)) != 0) {
                                float4 posq2 = (float4) (localData[tbx+j].x, localData[tbx+j].y, localData[tbx+j].z, localData[tbx+j].q);
                                float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
#ifdef USE_PERIODIC
                                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                                float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                                float bornRadius2 = localData[tbx+j].bornRadius;
                                float alpha2_ij = bornRadius1*bornRadius2;
                                float D_ij = r2*RECIP(4.0f*alpha2_ij);
                                float expTerm = EXP(-D_ij);
                                float denominator2 = r2 + alpha2_ij*expTerm;
                                float denominator = SQRT(denominator2);
                                float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                                float Gpol = tempEnergy*RECIP(denominator2);
                                float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                                float dEdR = Gpol*(1.0f - 0.25f*expTerm);
                                // Already inside a USE_CUTOFF region, so the cutoff test always applies.
                                if (atom1 >= NUM_ATOMS || y*TILE_SIZE+j >= NUM_ATOMS || r2 > CUTOFF_SQUARED) {
                                    dEdR = 0.0f;
                                    dGpol_dalpha2_ij = 0.0f;
                                    tempEnergy = 0.0f;
                                }
                                energy += tempEnergy;
                                force.w += dGpol_dalpha2_ij*bornRadius2;
                                delta.xyz *= dEdR;
                                force.xyz -= delta.xyz;
                                tempBuffer[get_local_id(0)] = (float4) (delta.xyz, dGpol_dalpha2_ij*bornRadius1);

                                // Sum the forces on atom j across the warp
                                // (tree reduction; relies on the warp running in lockstep).
                                if (tgx % 2 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+1];
                                if (tgx % 4 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+2];
                                if (tgx % 8 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+4];
                                if (tgx % 16 == 0)
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+8];
                                if (tgx == 0) {
                                    float4 sum = tempBuffer[get_local_id(0)] + tempBuffer[get_local_id(0)+16];
                                    localData[tbx+j].fx += sum.x;
                                    localData[tbx+j].fy += sum.y;
                                    localData[tbx+j].fz += sum.z;
                                    localData[tbx+j].fw += sum.w;
                                }
                            }
                        }
                    }
                }
                else
#endif
                {
                    // Compute the full set of interactions in this tile.
                    // Lane tgx visits partner (tgx + j) mod TILE_SIZE at step j, so at any step the
                    // lanes touch distinct localData entries and can accumulate into them directly.
                    unsigned int tj = tgx;
                    for (unsigned int j = 0; j < TILE_SIZE; j++) {
                        if (atom1 < NUM_ATOMS && y*TILE_SIZE+tj < NUM_ATOMS) {
                            float4 posq2 = (float4) (localData[tbx+tj].x, localData[tbx+tj].y, localData[tbx+tj].z, localData[tbx+tj].q);
                            float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
#ifdef USE_PERIODIC
                            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                            float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                            float bornRadius2 = localData[tbx+tj].bornRadius;
                            float alpha2_ij = bornRadius1*bornRadius2;
                            float D_ij = r2*RECIP(4.0f*alpha2_ij);
                            float expTerm = EXP(-D_ij);
                            float denominator2 = r2 + alpha2_ij*expTerm;
                            float denominator = SQRT(denominator2);
                            float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                            float Gpol = tempEnergy*RECIP(denominator2);
                            float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                            float dEdR = Gpol*(1.0f - 0.25f*expTerm);
#ifdef USE_CUTOFF
                            if (r2 > CUTOFF_SQUARED) {
                                dEdR = 0.0f;
                                tempEnergy  = 0.0f;
                                dGpol_dalpha2_ij = 0.0f;
                            }
#endif
                            force.w += dGpol_dalpha2_ij*bornRadius2;
                            energy += tempEnergy;
                            delta.xyz *= dEdR;
                            force.xyz -= delta.xyz;
                            // Equal and opposite force, plus the Born-radius term, for atom tj.
                            localData[tbx+tj].fx += delta.x;
                            localData[tbx+tj].fy += delta.y;
                            localData[tbx+tj].fz += delta.z;
                            localData[tbx+tj].fw += dGpol_dalpha2_ij*bornRadius1;
                        }
                        tj = (tj + 1) & (TILE_SIZE - 1);
                    }
                }
            }
        }

        // Write results.  We need to coordinate between warps to make sure no two of them
        // ever try to write to the same piece of memory at the same time.
        // Each warp publishes the blocks it will update; a warp writes only once no
        // lower-indexed warp in the group still reserves either of those blocks.
        int writeX = (pos < end ? x : -1);
        int writeY = (pos < end && x != y ? y : -1);
        if (tgx == 0)
            reservedBlocks[localGroupIndex] = (int2)(writeX, writeY);
        bool done = false;
        int doneIndex = 0;
        int checkIndex = 0;
        while (true) {
            // See if any warp still needs to write its data.
            bool allDone = true;
            barrier(CLK_LOCAL_MEM_FENCE);
            while (doneIndex < WARPS_PER_GROUP && allDone) {
                if (reservedBlocks[doneIndex].x != -1)
                    allDone = false;
                else
                    doneIndex++;
            }
            if (allDone)
                break;
            if (!done) {
                // See whether this warp can write its data.  This requires that no previous warp
                // is trying to write to the same block of the buffer.
                bool canWrite = (writeX != -1);
                while (checkIndex < localGroupIndex && canWrite) {
                    if ((reservedBlocks[checkIndex].x == x || reservedBlocks[checkIndex].y == x) ||
                            (writeY != -1 && (reservedBlocks[checkIndex].x == y || reservedBlocks[checkIndex].y == y)))
                        canWrite = false;
                    else
                        checkIndex++;
                }
                if (canWrite) {
                    // Write the data to global memory, then mark this warp as done.
                    if (writeX > -1) {
                        const unsigned int offset = x*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        forceBuffers[offset].xyz += force.xyz;
                        global_bornForce[offset] += force.w;
                    }
                    if (writeY > -1) {
                        const unsigned int offset = y*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        forceBuffers[offset] += (float4) (localData[get_local_id(0)].fx, localData[get_local_id(0)].fy, localData[get_local_id(0)].fz, 0.0f);
                        global_bornForce[offset] += localData[get_local_id(0)].fw;
                    }
                    done = true;
                    if (tgx == 0)
                        reservedBlocks[localGroupIndex] = (int2)(-1, -1);
                }
            }
        }
        lasty = y;
        pos++;
    } while (pos < end);
    energyBuffer[get_global_id(0)] += energy;
}