// The kernels in this file use 32-bit global atomics unconditionally and, when
// the host defines SUPPORTS_64_BIT_ATOMICS, 64-bit atom_add() on long buffers.
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
#ifdef SUPPORTS_64_BIT_ATOMICS
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
#endif

// Number of atoms along each edge of a tile.  Work-items are also grouped in
// units of TILE_SIZE (the kernels call such a unit a "warp").  The hand-written
// reductions in the kernels (+1..+3, then +4..+28) assume this value is 32.
#define TILE_SIZE 32
6

7
8
9
10
11
typedef struct {
    float x, y, z;
    float q;
    float radius, scaledRadius;
    float bornSum;
Peter Eastman's avatar
Peter Eastman committed
12
} AtomData1;
13

14
15
16
/**
 * Compute the Born sum.
 */
17
18
19
20
21
22
23
__kernel void computeBornSum(
#ifdef SUPPORTS_64_BIT_ATOMICS
        __global long* global_bornSum,
#else
        __global float* global_bornSum,
#endif
        __global float4* posq, __global float2* global_params,
Peter Eastman's avatar
Peter Eastman committed
24
        __local AtomData1* localData, __local float* tempBuffer,
25
#ifdef USE_CUTOFF
26
        __global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags) {
27
28
29
#else
        unsigned int numTiles) {
#endif
30
31
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;
32
33
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
34
35
    unsigned int pos = warp*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
    unsigned int end = (warp+1)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
36
#else
37
38
    unsigned int pos = warp*numTiles/totalWarps;
    unsigned int end = (warp+1)*numTiles/totalWarps;
39
#endif
40
    unsigned int lasty = 0xFFFFFFFF;
41
42
43
    __local int2 reservedBlocks[WARPS_PER_GROUP];
    
    do {
44
        // Extract the coordinates of this tile
45
46
47
        const unsigned int tgx = get_local_id(0) & (TILE_SIZE-1);
        const unsigned int tbx = get_local_id(0) - tgx;
        const unsigned int localGroupIndex = get_local_id(0)/TILE_SIZE;
48
        unsigned int x, y;
49
50
        float bornSum = 0.0f;
        if (pos < end) {
51
#ifdef USE_CUTOFF
52
53
54
55
56
57
            if (numTiles <= maxTiles) {
                ushort2 tileIndices = tiles[pos];
                x = tileIndices.x;
                y = tileIndices.y;
            }
            else
58
#endif
59
60
            {
                y = (unsigned int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
61
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
62
63
64
65
                if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                    y += (x < y ? -1 : 1);
                    x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                }
66
            }
67
68
69
70
71
72
73
            unsigned int atom1 = x*TILE_SIZE + tgx;
            float4 posq1 = posq[atom1];
            float2 params1 = global_params[atom1];
            if (pos >= end)
                ; // This warp is done.
            else if (x == y) {
                // This tile is on the diagonal.
74

75
76
77
78
79
80
81
82
                localData[get_local_id(0)].x = posq1.x;
                localData[get_local_id(0)].y = posq1.y;
                localData[get_local_id(0)].z = posq1.z;
                localData[get_local_id(0)].q = posq1.w;
                localData[get_local_id(0)].radius = params1.x;
                localData[get_local_id(0)].scaledRadius = params1.y;
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    float4 delta = (float4) (localData[tbx+j].x-posq1.x, localData[tbx+j].y-posq1.y, localData[tbx+j].z-posq1.z, 0.0f);
83
#ifdef USE_PERIODIC
84
85
86
                    delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                    delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                    delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
87
#endif
88
                    float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
89
#ifdef USE_CUTOFF
90
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
91
#else
92
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
93
#endif
94
95
96
97
98
99
100
101
102
103
                        float invR = RSQRT(r2);
                        float r = RECIP(invR);
                        float2 params2 = (float2) (localData[tbx+j].radius, localData[tbx+j].scaledRadius);
                        float rScaledRadiusJ = r+params2.y;
                        if ((j != tgx) && (params1.x < rScaledRadiusJ)) {
                            float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
                            float u_ij = RECIP(rScaledRadiusJ);
                            float l_ij2 = l_ij*l_ij;
                            float u_ij2 = u_ij*u_ij;
                            float ratio = LOG(u_ij * RECIP(l_ij));
Peter Eastman's avatar
Peter Eastman committed
104
105
                            bornSum += l_ij - u_ij + (0.50f*invR*ratio) + 0.25f*(r*(u_ij2-l_ij2) +
                                             (params2.y*params2.y*invR)*(l_ij2-u_ij2));
106
107
108
                            if (params1.x < params2.x-r)
                                bornSum += 2.0f*(RECIP(params1.x)-l_ij);
                        }
109
110
111
                    }
                }
            }
112
113
            else {
                // This is an off-diagonal tile.
114

115
116
117
118
119
120
121
122
123
124
                if (lasty != y) {
                    unsigned int j = y*TILE_SIZE + tgx;
                    float4 tempPosq = posq[j];
                    localData[get_local_id(0)].x = tempPosq.x;
                    localData[get_local_id(0)].y = tempPosq.y;
                    localData[get_local_id(0)].z = tempPosq.z;
                    localData[get_local_id(0)].q = tempPosq.w;
                    float2 tempParams = global_params[j];
                    localData[get_local_id(0)].radius = tempParams.x;
                    localData[get_local_id(0)].scaledRadius = tempParams.y;
125
                }
126
127
128
129
130
131
132
133
134
                localData[get_local_id(0)].bornSum = 0.0f;
#ifdef USE_CUTOFF
                unsigned int flags = (numTiles <= maxTiles ? interactionFlags[pos] : 0xFFFFFFFF);
                if (flags != 0xFFFFFFFF && false) { // TODO: Fix this: should be checking for exclusions
                    if (flags == 0) {
                        // No interactions in this tile.
                    }
                    else {
                        // Compute only a subset of the interactions in this tile.
135

136
137
138
                        for (unsigned int j = 0; j < TILE_SIZE; j++) {
                            if ((flags&(1<<j)) != 0) {
                                float4 delta = (float4) (localData[tbx+j].x-posq1.x, localData[tbx+j].y-posq1.y, localData[tbx+j].z-posq1.z, 0.0f);
139
#ifdef USE_PERIODIC
140
141
142
                                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
143
#endif
144
145
                                float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                                tempBuffer[get_local_id(0)] = 0.0f;
146
#ifdef USE_CUTOFF
147
                                if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
148
#else
149
                                if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
150
#endif
151
152
153
154
155
156
157
158
159
160
                                    float invR = RSQRT(r2);
                                    float r = RECIP(invR);
                                    float2 params2 = (float2) (localData[tbx+j].radius, localData[tbx+j].scaledRadius);
                                    float rScaledRadiusJ = r+params2.y;
                                    if (params1.x < rScaledRadiusJ) {
                                        float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
                                        float u_ij = RECIP(rScaledRadiusJ);
                                        float l_ij2 = l_ij*l_ij;
                                        float u_ij2 = u_ij*u_ij;
                                        float ratio = LOG(u_ij * RECIP(l_ij));
Peter Eastman's avatar
Peter Eastman committed
161
162
                                        bornSum += l_ij - u_ij + (0.50f*invR*ratio) + 0.25f*(r*(u_ij2-l_ij2) +
                                                         (params2.y*params2.y*invR)*(l_ij2-u_ij2));
163
164
165
166
167
168
169
170
171
172
                                        if (params1.x < params2.x-r)
                                            bornSum += 2.0f*(RECIP(params1.x)-l_ij);
                                    }
                                    float rScaledRadiusI = r+params1.y;
                                    if (params2.x < rScaledRadiusI) {
                                        float l_ij = RECIP(max(params2.x, fabs(r-params1.y)));
                                        float u_ij = RECIP(rScaledRadiusI);
                                        float l_ij2 = l_ij*l_ij;
                                        float u_ij2 = u_ij*u_ij;
                                        float ratio = LOG(u_ij * RECIP(l_ij));
Peter Eastman's avatar
Peter Eastman committed
173
174
                                        float term = l_ij - u_ij + (0.50f*invR*ratio) + 0.25f*(r*(u_ij2-l_ij2) +
                                                         (params1.y*params1.y*invR)*(l_ij2-u_ij2));
175
176
177
178
                                        if (params2.x < params1.x-r)
                                            term += 2.0f*(RECIP(params2.x)-l_ij);
                                        tempBuffer[get_local_id(0)] = term;
                                    }
179
180
                                }

181
                                // Sum the forces on atom j.
182

183
                                if (tgx % 4 == 0)
184
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+1]+tempBuffer[get_local_id(0)+2]+tempBuffer[get_local_id(0)+3];
185
                                if (tgx == 0)
186
                                    localData[tbx+j].bornSum += tempBuffer[get_local_id(0)]+tempBuffer[get_local_id(0)+4]+tempBuffer[get_local_id(0)+8]+tempBuffer[get_local_id(0)+12]+tempBuffer[get_local_id(0)+16]+tempBuffer[get_local_id(0)+20]+tempBuffer[get_local_id(0)+24]+tempBuffer[get_local_id(0)+28];
187
                            }
188
189
190
                        }
                    }
                }
191
                else
192
#endif
193
194
                {
                    // Compute the full set of interactions in this tile.
195

196
197
198
                    unsigned int tj = tgx;
                    for (unsigned int j = 0; j < TILE_SIZE; j++) {
                        float4 delta = (float4) (localData[tbx+tj].x-posq1.x, localData[tbx+tj].y-posq1.y, localData[tbx+tj].z-posq1.z, 0.0f);
199
#ifdef USE_PERIODIC
200
201
202
                        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
203
#endif
204
                        float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
205
#ifdef USE_CUTOFF
206
                        if (atom1 < NUM_ATOMS && y*TILE_SIZE+tj < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
207
#else
208
                        if (atom1 < NUM_ATOMS && y*TILE_SIZE+tj < NUM_ATOMS) {
209
#endif
210
211
212
213
214
215
216
217
218
219
                            float invR = RSQRT(r2);
                            float r = RECIP(invR);
                            float2 params2 = (float2) (localData[tbx+tj].radius, localData[tbx+tj].scaledRadius);
                            float rScaledRadiusJ = r+params2.y;
                            if (params1.x < rScaledRadiusJ) {
                                float l_ij = RECIP(max(params1.x, fabs(r-params2.y)));
                                float u_ij = RECIP(rScaledRadiusJ);
                                float l_ij2 = l_ij*l_ij;
                                float u_ij2 = u_ij*u_ij;
                                float ratio = LOG(u_ij * RECIP(l_ij));
Peter Eastman's avatar
Peter Eastman committed
220
221
                                bornSum += l_ij - u_ij + (0.50f*invR*ratio) + 0.25f*(r*(u_ij2-l_ij2) +
                                                 (params2.y*params2.y*invR)*(l_ij2-u_ij2));
222
223
224
225
226
227
228
229
230
231
                                if (params1.x < params2.x-r)
                                    bornSum += 2.0f*(RECIP(params1.x)-l_ij);
                            }
                            float rScaledRadiusI = r+params1.y;
                            if (params2.x < rScaledRadiusI) {
                                float l_ij = RECIP(max(params2.x, fabs(r-params1.y)));
                                float u_ij = RECIP(rScaledRadiusI);
                                float l_ij2 = l_ij*l_ij;
                                float u_ij2 = u_ij*u_ij;
                                float ratio = LOG(u_ij * RECIP(l_ij));
Peter Eastman's avatar
Peter Eastman committed
232
233
                                float term = l_ij - u_ij + (0.50f*invR*ratio) + 0.25f*(r*(u_ij2-l_ij2) +
                                                 (params1.y*params1.y*invR)*(l_ij2-u_ij2));
234
235
236
237
                                if (params2.x < params1.x-r)
                                    term += 2.0f*(RECIP(params2.x)-l_ij);
                                localData[tbx+tj].bornSum += term;
                            }
238
                        }
239
                        tj = (tj + 1) & (TILE_SIZE - 1);
240
241
242
                    }
                }
            }
243
244
245
246
247
        }
        
        // Write results.  We need to coordinate between warps to make sure no two of them
        // ever try to write to the same piece of memory at the same time.
        
248
249
250
251
252
253
254
255
256
257
#ifdef SUPPORTS_64_BIT_ATOMICS
        if (pos < end) {
            const unsigned int offset = x*TILE_SIZE + tgx;
            atom_add(&global_bornSum[offset], (long) (bornSum*0xFFFFFFFF));
        }
        if (pos < end && x != y) {
            const unsigned int offset = y*TILE_SIZE + tgx;
            atom_add(&global_bornSum[offset], (long) (localData[get_local_id(0)].bornSum*0xFFFFFFFF));
        }
#else
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
        int writeX = (pos < end ? x : -1);
        int writeY = (pos < end && x != y ? y : -1);
        if (tgx == 0)
            reservedBlocks[localGroupIndex] = (int2)(writeX, writeY);
        bool done = false;
        int doneIndex = 0;
        int checkIndex = 0;
        while (true) {
            // See if any warp still needs to write its data.

            bool allDone = true;
            barrier(CLK_LOCAL_MEM_FENCE);
            while (doneIndex < WARPS_PER_GROUP && allDone) {
                if (reservedBlocks[doneIndex].x != -1)
                    allDone = false;
                else
                    doneIndex++;
            }
            if (allDone)
                break;
            if (!done) {
                // See whether this warp can write its data.  This requires that no previous warp
                // is trying to write to the same block of the buffer.
281

282
283
284
285
286
287
288
289
290
291
                bool canWrite = (writeX != -1);
                while (checkIndex < localGroupIndex && canWrite) {
                    if ((reservedBlocks[checkIndex].x == x || reservedBlocks[checkIndex].y == x) ||
                            (writeY != -1 && (reservedBlocks[checkIndex].x == y || reservedBlocks[checkIndex].y == y)))
                        canWrite = false;
                    else
                        checkIndex++;
                }
                if (canWrite) {
                    // Write the data to global memory, then mark this warp as done.
292

293
294
295
296
297
298
299
300
301
302
303
304
305
                    if (writeX > -1) {
                        const unsigned int offset = x*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        global_bornSum[offset] += bornSum;
                    }
                    if (writeY > -1) {
                        const unsigned int offset = y*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        global_bornSum[offset] += localData[get_local_id(0)].bornSum;
                    }
                    done = true;
                    if (tgx == 0)
                        reservedBlocks[localGroupIndex] = (int2)(-1, -1);
                }
            }
306
        }
307
#endif
308
        lasty = y;
309
        pos++;
310
    } while (pos < end);
311
312
}

/**
 * Per-atom record cached in local memory by computeGBSAForce1.
 */
typedef struct {
    float x, y, z;          // position, copied from posq.xyz
    float q;                // copied from posq.w
    float fx, fy, fz, fw;   // per-tile accumulators: fx/fy/fz go to forceBuffers, fw to global_bornForce
    float bornRadius;       // copied from global_bornRadii
} AtomData2;

/**
 * First part of computing the GBSA interaction.
 */

324
325
326
327
328
329
330
__kernel void computeGBSAForce1(
#ifdef SUPPORTS_64_BIT_ATOMICS
        __global long* forceBuffers, __global long* global_bornForce,
#else
        __global float4* forceBuffers, __global float* global_bornForce,
#endif
        __global float* energyBuffer, __global float4* posq, __global float* global_bornRadii,
Peter Eastman's avatar
Peter Eastman committed
331
        __local AtomData2* localData, __local float4* tempBuffer,
332
#ifdef USE_CUTOFF
333
        __global ushort2* tiles, __global unsigned int* interactionCount, float4 periodicBoxSize, float4 invPeriodicBoxSize, unsigned int maxTiles, __global unsigned int* interactionFlags) {
334
335
336
#else
        unsigned int numTiles) {
#endif
337
338
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;
339
340
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
341
342
    unsigned int pos = warp*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
    unsigned int end = (warp+1)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/totalWarps;
343
#else
344
345
    unsigned int pos = warp*numTiles/totalWarps;
    unsigned int end = (warp+1)*numTiles/totalWarps;
346
#endif
347
348
    float energy = 0.0f;
    unsigned int lasty = 0xFFFFFFFF;
349
350
351
    __local int2 reservedBlocks[WARPS_PER_GROUP];
    
    do {
352
        // Extract the coordinates of this tile
353
354
355
        const unsigned int tgx = get_local_id(0) & (TILE_SIZE-1);
        const unsigned int tbx = get_local_id(0) - tgx;
        const unsigned int localGroupIndex = get_local_id(0)/TILE_SIZE;
356
        unsigned int x, y;
357
358
        float4 force = 0.0f;
        if (pos < end) {
359
#ifdef USE_CUTOFF
360
361
362
363
364
365
            if (numTiles <= maxTiles) {
                ushort2 tileIndices = tiles[pos];
                x = tileIndices.x;
                y = tileIndices.y;
            }
            else
366
#endif
367
368
            {
                y = (unsigned int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
369
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
370
371
372
373
                if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                    y += (x < y ? -1 : 1);
                    x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                }
374
            }
375
376
377
378
379
            unsigned int atom1 = x*TILE_SIZE + tgx;
            float4 posq1 = posq[atom1];
            float bornRadius1 = global_bornRadii[atom1];
            if (x == y) {
                // This tile is on the diagonal.
380

381
382
383
384
385
386
387
388
389
                localData[get_local_id(0)].x = posq1.x;
                localData[get_local_id(0)].y = posq1.y;
                localData[get_local_id(0)].z = posq1.z;
                localData[get_local_id(0)].q = posq1.w;
                localData[get_local_id(0)].bornRadius = bornRadius1;
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
                        float4 posq2 = (float4) (localData[tbx+j].x, localData[tbx+j].y, localData[tbx+j].z, localData[tbx+j].q);
                        float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
390
#ifdef USE_PERIODIC
391
392
393
                        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
394
#endif
395
396
397
398
399
400
401
402
403
404
405
406
407
                        float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                        float invR = RSQRT(r2);
                        float r = RECIP(invR);
                        float bornRadius2 = localData[tbx+j].bornRadius;
                        float alpha2_ij = bornRadius1*bornRadius2;
                        float D_ij = r2*RECIP(4.0f*alpha2_ij);
                        float expTerm = EXP(-D_ij);
                        float denominator2 = r2 + alpha2_ij*expTerm;
                        float denominator = SQRT(denominator2);
                        float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                        float Gpol = tempEnergy*RECIP(denominator2);
                        float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                        float dEdR = Gpol*(1.0f - 0.25f*expTerm);
408
#ifdef USE_CUTOFF
409
410
411
412
413
                        if (r2 > CUTOFF_SQUARED) {
                            dEdR = 0.0f;
                            tempEnergy  = 0.0f;
                            dGpol_dalpha2_ij = 0.0f;
                        }
414
#endif
415
416
417
418
419
                        force.w += dGpol_dalpha2_ij*bornRadius2;
                        energy += 0.5f*tempEnergy;
                        delta.xyz *= dEdR;
                        force.xyz -= delta.xyz;
                    }
420
421
                }
            }
422
423
            else {
                // This is an off-diagonal tile.
424

425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
                if (lasty != y) {
                    unsigned int j = y*TILE_SIZE + tgx;
                    float4 tempPosq = posq[j];
                    localData[get_local_id(0)].x = tempPosq.x;
                    localData[get_local_id(0)].y = tempPosq.y;
                    localData[get_local_id(0)].z = tempPosq.z;
                    localData[get_local_id(0)].q = tempPosq.w;
                    localData[get_local_id(0)].bornRadius = global_bornRadii[j];
                }
                localData[get_local_id(0)].fx = 0.0f;
                localData[get_local_id(0)].fy = 0.0f;
                localData[get_local_id(0)].fz = 0.0f;
                localData[get_local_id(0)].fw = 0.0f;
#ifdef USE_CUTOFF
                unsigned int flags = (numTiles <= maxTiles ? interactionFlags[pos] : 0xFFFFFFFF);
                if (flags != 0xFFFFFFFF && false) { // TODO: Fix this: should be checking for exclusions
                    if (flags == 0) {
                        // No interactions in this tile.
                    }
                    else {
                        // Compute only a subset of the interactions in this tile.
446

447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
                        for (unsigned int j = 0; j < TILE_SIZE; j++) {
                            if ((flags&(1<<j)) != 0) {
                                float4 posq2 = (float4) (localData[tbx+j].x, localData[tbx+j].y, localData[tbx+j].z, localData[tbx+j].q);
                                float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
#ifdef USE_PERIODIC
                                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                                float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                                float invR = RSQRT(r2);
                                float r = RECIP(invR);
                                float bornRadius2 = localData[tbx+j].bornRadius;
                                float alpha2_ij = bornRadius1*bornRadius2;
                                float D_ij = r2*RECIP(4.0f*alpha2_ij);
                                float expTerm = EXP(-D_ij);
                                float denominator2 = r2 + alpha2_ij*expTerm;
                                float denominator = SQRT(denominator2);
                                float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                                float Gpol = tempEnergy*RECIP(denominator2);
                                float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                                float dEdR = Gpol*(1.0f - 0.25f*expTerm);
469
#ifdef USE_CUTOFF
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
                                if (atom1 >= NUM_ATOMS || y*TILE_SIZE+j >= NUM_ATOMS || r2 > CUTOFF_SQUARED) {
#else
                                if (atom1 >= NUM_ATOMS || y*TILE_SIZE+j >= NUM_ATOMS) {
#endif
                                    dEdR = 0.0f;
                                    dGpol_dalpha2_ij = 0.0f;
                                    tempEnergy = 0.0f;
                                }
                                energy += tempEnergy;
                                force.w += dGpol_dalpha2_ij*bornRadius2;
                                delta.xyz *= dEdR;
                                force.xyz -= delta.xyz;
                                tempBuffer[get_local_id(0)] = (float4) (delta.xyz, dGpol_dalpha2_ij*bornRadius1);

                                // Sum the forces on atom j.

                                if (tgx % 4 == 0)
487
                                    tempBuffer[get_local_id(0)] += tempBuffer[get_local_id(0)+1]+tempBuffer[get_local_id(0)+2]+tempBuffer[get_local_id(0)+3];
488
                                if (tgx == 0) {
489
                                    float4 sum = tempBuffer[get_local_id(0)]+tempBuffer[get_local_id(0)+4]+tempBuffer[get_local_id(0)+8]+tempBuffer[get_local_id(0)+12]+tempBuffer[get_local_id(0)+16]+tempBuffer[get_local_id(0)+20]+tempBuffer[get_local_id(0)+24]+tempBuffer[get_local_id(0)+28];
490
491
492
493
494
495
496
497
                                    localData[tbx+j].fx += sum.x;
                                    localData[tbx+j].fy += sum.y;
                                    localData[tbx+j].fz += sum.z;
                                    localData[tbx+j].fw += sum.w;
                                }
                            }
                        }
                    }
498
                }
499
500
501
502
                else
#endif
                {
                    // Compute the full set of interactions in this tile.
503

504
                    unsigned int tj = tgx;
505
                    for (unsigned int j = 0; j < TILE_SIZE; j++) {
506
507
                        if (atom1 < NUM_ATOMS && y*TILE_SIZE+tj < NUM_ATOMS) {
                            float4 posq2 = (float4) (localData[tbx+tj].x, localData[tbx+tj].y, localData[tbx+tj].z, localData[tbx+tj].q);
508
                            float4 delta = (float4) (posq2.xyz - posq1.xyz, 0.0f);
509
#ifdef USE_PERIODIC
510
511
512
                            delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                            delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                            delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
513
514
#endif
                            float r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
515
516
                            float invR = RSQRT(r2);
                            float r = RECIP(invR);
517
                            float bornRadius2 = localData[tbx+tj].bornRadius;
518
                            float alpha2_ij = bornRadius1*bornRadius2;
519
                            float D_ij = r2*RECIP(4.0f*alpha2_ij);
520
                            float expTerm = EXP(-D_ij);
521
                            float denominator2 = r2 + alpha2_ij*expTerm;
522
                            float denominator = SQRT(denominator2);
523
524
                            float tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                            float Gpol = tempEnergy*RECIP(denominator2);
525
526
527
                            float dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                            float dEdR = Gpol*(1.0f - 0.25f*expTerm);
#ifdef USE_CUTOFF
528
                            if (r2 > CUTOFF_SQUARED) {
529
                                dEdR = 0.0f;
530
                                tempEnergy  = 0.0f;
531
                                dGpol_dalpha2_ij = 0.0f;
532
                            }
533
#endif
534
                            force.w += dGpol_dalpha2_ij*bornRadius2;
535
                            energy += tempEnergy;
536
537
                            delta.xyz *= dEdR;
                            force.xyz -= delta.xyz;
538
539
540
541
                            localData[tbx+tj].fx += delta.x;
                            localData[tbx+tj].fy += delta.y;
                            localData[tbx+tj].fz += delta.z;
                            localData[tbx+tj].fw += dGpol_dalpha2_ij*bornRadius1;
542
                        }
543
                        tj = (tj + 1) & (TILE_SIZE - 1);
544
545
546
                    }
                }
            }
547
548
549
550
551
        }
        
        // Write results.  We need to coordinate between warps to make sure no two of them
        // ever try to write to the same piece of memory at the same time.
        
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
#ifdef SUPPORTS_64_BIT_ATOMICS
        if (pos < end) {
            const unsigned int offset = x*TILE_SIZE + tgx;
            atom_add(&forceBuffers[offset], (long) (force.x*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+PADDED_NUM_ATOMS], (long) (force.y*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+2*PADDED_NUM_ATOMS], (long) (force.z*0xFFFFFFFF));
            atom_add(&global_bornForce[offset], (long) (force.w*0xFFFFFFFF));
        }
        if (pos < end && x != y) {
            const unsigned int offset = y*TILE_SIZE + tgx;
            atom_add(&forceBuffers[offset], (long) (localData[get_local_id(0)].fx*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+PADDED_NUM_ATOMS], (long) (localData[get_local_id(0)].fy*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+2*PADDED_NUM_ATOMS], (long) (localData[get_local_id(0)].fz*0xFFFFFFFF));
            atom_add(&global_bornForce[offset], (long) (localData[get_local_id(0)].fw*0xFFFFFFFF));
        }
#else
568
569
570
571
572
573
574
575
576
        int writeX = (pos < end ? x : -1);
        int writeY = (pos < end && x != y ? y : -1);
        if (tgx == 0)
            reservedBlocks[localGroupIndex] = (int2)(writeX, writeY);
        bool done = false;
        int doneIndex = 0;
        int checkIndex = 0;
        while (true) {
            // See if any warp still needs to write its data.
577

578
579
580
581
582
583
584
            bool allDone = true;
            barrier(CLK_LOCAL_MEM_FENCE);
            while (doneIndex < WARPS_PER_GROUP && allDone) {
                if (reservedBlocks[doneIndex].x != -1)
                    allDone = false;
                else
                    doneIndex++;
585
            }
586
587
588
589
590
            if (allDone)
                break;
            if (!done) {
                // See whether this warp can write its data.  This requires that no previous warp
                // is trying to write to the same block of the buffer.
591

592
593
594
595
596
597
598
599
600
601
                bool canWrite = (writeX != -1);
                while (checkIndex < localGroupIndex && canWrite) {
                    if ((reservedBlocks[checkIndex].x == x || reservedBlocks[checkIndex].y == x) ||
                            (writeY != -1 && (reservedBlocks[checkIndex].x == y || reservedBlocks[checkIndex].y == y)))
                        canWrite = false;
                    else
                        checkIndex++;
                }
                if (canWrite) {
                    // Write the data to global memory, then mark this warp as done.
602

603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
                    if (writeX > -1) {
                        const unsigned int offset = x*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        forceBuffers[offset].xyz += force.xyz;
                        global_bornForce[offset] += force.w;
                    }
                    if (writeY > -1) {
                        const unsigned int offset = y*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        forceBuffers[offset] += (float4) (localData[get_local_id(0)].fx, localData[get_local_id(0)].fy, localData[get_local_id(0)].fz, 0.0f);
                        global_bornForce[offset] += localData[get_local_id(0)].fw;
                    }
                    done = true;
                    if (tgx == 0)
                        reservedBlocks[localGroupIndex] = (int2)(-1, -1);
                }
            }
618
        }
619
#endif
620
        lasty = y;
621
        pos++;
622
    } while (pos < end);
623
624
    energyBuffer[get_global_id(0)] += energy;
}