// gbsaObc_cpu.cl
// Number of atoms in each block; a tile pairs two blocks of TILE_SIZE atoms.
#define TILE_SIZE 32

/**
 * Per-atom data for one block of TILE_SIZE atoms, cached in __local memory by
 * computeBornSum.
 */
typedef struct {
    real x, y, z;               // position (posq.xyz)
    real q;                     // posq.w
    float radius, scaledRadius; // global_params.x and global_params.y
    real bornSum;               // contributions accumulated for this atom while processing an off-diagonal tile
} AtomData1;

/**
 * Born-sum contribution that atom j makes to atom i, for a pair at distance r.
 *
 * Only meaningful when radiusI < r+scaledRadiusJ; every caller in computeBornSum
 * checks that condition first and skips the pair otherwise.
 *
 * The radii are float (as stored in global_params), so products such as
 * 0.25f*scaledRadiusJ*scaledRadiusJ are evaluated in float even when real is double.
 *
 * @param r              distance between the two atoms
 * @param invR           1/r
 * @param radiusI        radius (global_params.x) of the atom receiving the contribution
 * @param scaledRadiusJ  scaled radius (global_params.y) of the other atom
 * @return the amount to add to atom i's Born sum
 */
real computeBornSumPairTerm(real r, real invR, float radiusI, float scaledRadiusJ) {
    real rScaledRadiusJ = r+scaledRadiusJ;
    real l_ij = RECIP(max((real) radiusI, fabs(r-scaledRadiusJ)));
    real u_ij = RECIP(rScaledRadiusJ);
    real l_ij2 = l_ij*l_ij;
    real u_ij2 = u_ij*u_ij;
    real ratio = LOG(u_ij * RECIP(l_ij));
    real term = l_ij - u_ij + 0.25f*r*(u_ij2-l_ij2) + (0.50f*invR*ratio) +
                (0.25f*scaledRadiusJ*scaledRadiusJ*invR)*(l_ij2-u_ij2);

    // Extra term when radiusI < scaledRadiusJ-r, i.e. a sphere of radius radiusI around
    // atom i lies entirely inside the sphere of radius scaledRadiusJ around atom j.
    if (radiusI < scaledRadiusJ-r)
        term += 2.0f*(RECIP(radiusI)-l_ij);
    return term;
}

/**
 * Compute the Born sum.
 *
 * For every atom i, adds computeBornSumPairTerm(...) over the atoms j it is paired
 * with into global_bornSum. Work is organized into tiles: tile (x, y) pairs block x
 * (whose atoms are read one at a time from global memory) with block y (cached in
 * the __local array localData, reloaded only when y changes). Each work group
 * processes the contiguous tile range [pos, end) and writes into its own slice of
 * global_bornSum starting at get_group_id(0)*PADDED_NUM_ATOMS. Values are added with
 * +=, so the buffer is presumably zeroed beforehand and the per-group slices reduced
 * afterwards -- TODO confirm against the host code.
 *
 * NOTE(review): the loops contain no barriers and do not depend on the local id, so
 * this kernel appears to assume one work item per work group -- confirm launch size.
 *
 * Parameters:
 *   global_bornSum      per-work-group output slices (see above)
 *   posq                atom positions in xyz; w is copied into localData but not read here
 *   global_params       per-atom radius (x) and scaled radius (y)
 *   tiles               [USE_CUTOFF] explicit tile coordinates, used when numTiles <= maxTiles
 *   interactionCount    [USE_CUTOFF] interactionCount[0] is the number of tiles in the list
 *   periodicBoxSize,
 *   invPeriodicBoxSize  [USE_CUTOFF] box size and its reciprocal, read only under USE_PERIODIC
 *   maxTiles            [USE_CUTOFF] capacity of tiles; if exceeded, every tile is processed
 *   interactionFlags    [USE_CUTOFF] not read by this kernel
 *   numTiles            [no USE_CUTOFF] number of tiles to process
 */
__kernel void computeBornSum(__global real* restrict global_bornSum, __global const real4* restrict posq, __global const float2* restrict global_params,
#ifdef USE_CUTOFF
        __global const ushort2* restrict tiles, __global const unsigned int* restrict interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize, unsigned int maxTiles, __global const unsigned int* restrict interactionFlags) {
#else
        unsigned int numTiles) {
#endif
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
    // If the tile list overflowed (numTiles > maxTiles), ignore it and process all
    // NUM_BLOCKS*(NUM_BLOCKS+1)/2 tiles instead.
    unsigned int pos = get_group_id(0)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/get_num_groups(0);
    unsigned int end = (get_group_id(0)+1)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/get_num_groups(0);
#else
    unsigned int pos = get_group_id(0)*numTiles/get_num_groups(0);
    unsigned int end = (get_group_id(0)+1)*numTiles/get_num_groups(0);
#endif
    unsigned int lasty = 0xFFFFFFFF; // No block is cached in localData yet.
    __local AtomData1 localData[TILE_SIZE];

    while (pos < end) {
        // Extract the coordinates of this tile.
        unsigned int x, y;
#ifdef USE_CUTOFF
        if (numTiles <= maxTiles) {
            ushort2 tileIndices = tiles[pos];
            x = tileIndices.x;
            y = tileIndices.y;
        }
        else
#endif
        {
            // Invert pos = y*NUM_BLOCKS - y*(y+1)/2 + x, which enumerates the tiles
            // with y <= x < NUM_BLOCKS row by row.
            y = (unsigned int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
            x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
            if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                y += (x < y ? -1 : 1);
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
            }
        }

        // Load the data for block y unless it is already cached.
        if (lasty != y) {
            for (int localAtomIndex = 0; localAtomIndex < TILE_SIZE; localAtomIndex++) {
                unsigned int j = y*TILE_SIZE + localAtomIndex;
                real4 tempPosq = posq[j];
                localData[localAtomIndex].x = tempPosq.x;
                localData[localAtomIndex].y = tempPosq.y;
                localData[localAtomIndex].z = tempPosq.z;
                localData[localAtomIndex].q = tempPosq.w;
                float2 tempParams = global_params[j];
                localData[localAtomIndex].radius = tempParams.x;
                localData[localAtomIndex].scaledRadius = tempParams.y;
            }
        }
        if (x == y) {
            // This tile is on the diagonal: both blocks are the same, so every ordered
            // pair is visited and only atom1's own sum is accumulated.
            for (unsigned int tgx = 0; tgx < TILE_SIZE; tgx++) {
                unsigned int atom1 = x*TILE_SIZE+tgx;
                real bornSum = 0.0f;
                real4 posq1 = posq[atom1];
                float2 params1 = global_params[atom1];
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    real4 posq2 = (real4) (localData[j].x, localData[j].y, localData[j].z, localData[j].q);
                    real4 delta = (real4) (posq2.xyz - posq1.xyz, 0);
#ifdef USE_PERIODIC
                    delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
#endif
                    real r2 = dot(delta.xyz, delta.xyz);
#ifdef USE_CUTOFF
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
#else
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
#endif
                        real invR = RSQRT(r2);
                        real r = RECIP(invR);
                        float2 params2 = (float2) (localData[j].radius, localData[j].scaledRadius);
                        real rScaledRadiusJ = r+params2.y;
                        // j == tgx is atom1 itself (r2 == 0) and contributes nothing.
                        if ((j != tgx) && (params1.x < rScaledRadiusJ))
                            bornSum += computeBornSumPairTerm(r, invR, params1.x, params2.y);
                    }
                }

                // Write results.
                unsigned int offset = x*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                global_bornSum[offset] += bornSum;
            }
        }
        else {
            // This is an off-diagonal tile: each pair is visited once, so contributions
            // are accumulated for both atoms (atom1 in bornSum, atom j in localData).
            for (int tgx = 0; tgx < TILE_SIZE; tgx++)
                localData[tgx].bornSum = 0.0f;

            for (unsigned int tgx = 0; tgx < TILE_SIZE; tgx++) {
                unsigned int atom1 = x*TILE_SIZE+tgx;
                real bornSum = 0.0f;
                real4 posq1 = posq[atom1];
                float2 params1 = global_params[atom1];
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    real4 posq2 = (real4) (localData[j].x, localData[j].y, localData[j].z, localData[j].q);
                    real4 delta = (real4) (posq2.xyz - posq1.xyz, 0);
#ifdef USE_PERIODIC
                    delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
#endif
                    real r2 = dot(delta.xyz, delta.xyz);
#ifdef USE_CUTOFF
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
#else
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
#endif
                        real invR = RSQRT(r2);
                        real r = RECIP(invR);
                        float2 params2 = (float2) (localData[j].radius, localData[j].scaledRadius);

                        // Contribution of atom j to atom1.
                        real rScaledRadiusJ = r+params2.y;
                        if (params1.x < rScaledRadiusJ)
                            bornSum += computeBornSumPairTerm(r, invR, params1.x, params2.y);

                        // Contribution of atom1 to atom j.
                        real rScaledRadiusI = r+params1.y;
                        if (params2.x < rScaledRadiusI)
                            localData[j].bornSum += computeBornSumPairTerm(r, invR, params2.x, params1.y);
                    }
                }

                // Write results for atom1.
                unsigned int offset = atom1 + get_group_id(0)*PADDED_NUM_ATOMS;
                global_bornSum[offset] += bornSum;
            }

            // Write results for the atoms of block y.
            for (int tgx = 0; tgx < TILE_SIZE; tgx++) {
                unsigned int offset = y*TILE_SIZE+tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                global_bornSum[offset] += localData[tgx].bornSum;
            }
        }
        lasty = y;
        pos++;
    }
}

/**
 * Per-atom data for one block of TILE_SIZE atoms, cached in __local memory by
 * computeGBSAForce1.
 */
typedef struct {
    real x, y, z;           // position (posq.xyz)
    real q;                 // charge (posq.w)
    real fx, fy, fz, fw;    // force (fx, fy, fz) and Born-radius term (fw) accumulated during an off-diagonal tile
    real bornRadius;        // global_bornRadii entry
} AtomData2;

/**
 * First part of computing the GBSA interaction.
 *
 * For every visited pair of atoms (i, j), with Born radii b_i, b_j and charges
 * q_i, q_j (posq.w), evaluates
 *   denominator2 = r^2 + b_i*b_j*exp(-r^2/(4*b_i*b_j))
 *   tempEnergy   = PREFACTOR*q_i*q_j/sqrt(denominator2)
 * and accumulates
 *   - energy into energyBuffer[get_global_id(0)],
 *   - Cartesian forces into forceBuffers (xyz only),
 *   - sums of dGpol_dalpha2_ij*b_j into global_bornForce. This presumably feeds the
 *     Born-radius part of the force in a later kernel -- TODO confirm.
 *
 * Tile (x, y) pairs block x (atoms read one at a time from global memory) with
 * block y (cached in the __local array localData, reloaded only when y changes).
 * Each work group processes the tile range [pos, end) and writes into its own slice
 * of forceBuffers / global_bornForce starting at get_group_id(0)*PADDED_NUM_ATOMS,
 * always with +=.
 *
 * NOTE(review): the loops contain no barriers and do not depend on the local id, so
 * this kernel appears to assume one work item per work group -- confirm launch size.
 *
 * Parameters:
 *   forceBuffers        per-work-group force slices (xyz accumulated, w preserved)
 *   global_bornForce    per-work-group Born-radius accumulation slices
 *   energyBuffer        one energy accumulator per work item
 *   posq                atom positions (xyz) and charges (w)
 *   global_bornRadii    per-atom Born radii
 *   tiles               [USE_CUTOFF] explicit tile coordinates, used when numTiles <= maxTiles
 *   interactionCount    [USE_CUTOFF] interactionCount[0] is the number of tiles in the list
 *   periodicBoxSize,
 *   invPeriodicBoxSize  [USE_CUTOFF] box size and its reciprocal, read only under USE_PERIODIC
 *   maxTiles            [USE_CUTOFF] capacity of tiles; if exceeded, every tile is processed
 *   interactionFlags    [USE_CUTOFF] not read by this kernel
 *   numTiles            [no USE_CUTOFF] number of tiles to process
 */
__kernel void computeGBSAForce1(__global real4* restrict forceBuffers, __global real* restrict global_bornForce,
        __global real* restrict energyBuffer, __global const real4* restrict posq, __global const real* restrict global_bornRadii,
#ifdef USE_CUTOFF
        __global const ushort2* restrict tiles, __global const unsigned int* restrict interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize, unsigned int maxTiles, __global const unsigned int* restrict interactionFlags) {
#else
        unsigned int numTiles) {
#endif
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
    // If the tile list overflowed (numTiles > maxTiles), ignore it and process all
    // NUM_BLOCKS*(NUM_BLOCKS+1)/2 tiles instead.
    unsigned int pos = get_group_id(0)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/get_num_groups(0);
    unsigned int end = (get_group_id(0)+1)*(numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles)/get_num_groups(0);
#else
    unsigned int pos = get_group_id(0)*numTiles/get_num_groups(0);
    unsigned int end = (get_group_id(0)+1)*numTiles/get_num_groups(0);
#endif
    real energy = 0.0f;
    unsigned int lasty = 0xFFFFFFFF; // No block is cached in localData yet.
    __local AtomData2 localData[TILE_SIZE];

    while (pos < end) {
        // Extract the coordinates of this tile.
        unsigned int x, y;
#ifdef USE_CUTOFF
        if (numTiles <= maxTiles) {
            ushort2 tileIndices = tiles[pos];
            x = tileIndices.x;
            y = tileIndices.y;
        }
        else
#endif
        {
            // Invert pos = y*NUM_BLOCKS - y*(y+1)/2 + x, which enumerates the tiles
            // with y <= x < NUM_BLOCKS row by row.
            y = (unsigned int) floor(NUM_BLOCKS+0.5f-sqrt((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
            x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
            if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                y += (x < y ? -1 : 1);
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
            }
        }

        // Load the data for block y unless it is already cached.
        if (lasty != y) {
            for (int localAtomIndex = 0; localAtomIndex < TILE_SIZE; localAtomIndex++) {
                unsigned int j = y*TILE_SIZE + localAtomIndex;
                real4 tempPosq = posq[j];
                localData[localAtomIndex].x = tempPosq.x;
                localData[localAtomIndex].y = tempPosq.y;
                localData[localAtomIndex].z = tempPosq.z;
                localData[localAtomIndex].q = tempPosq.w;
                localData[localAtomIndex].bornRadius = global_bornRadii[j];
            }
        }
        if (x == y) {
            // This tile is on the diagonal: both blocks are the same, so every ordered
            // pair is visited and only atom1's own accumulators are updated.
            for (unsigned int tgx = 0; tgx < TILE_SIZE; tgx++) {
                unsigned int atom1 = x*TILE_SIZE+tgx;
                real4 force = 0.0f;
                real4 posq1 = posq[atom1];
                real bornRadius1 = global_bornRadii[atom1];
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    real4 posq2 = (real4) (localData[j].x, localData[j].y, localData[j].z, localData[j].q);
                    real4 delta = (real4) (posq2.xyz - posq1.xyz, 0);
#ifdef USE_PERIODIC
                    delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
#endif
                    real r2 = dot(delta.xyz, delta.xyz);
#ifdef USE_CUTOFF
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
#else
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
#endif
                        real bornRadius2 = localData[j].bornRadius;
                        real alpha2_ij = bornRadius1*bornRadius2;
                        real D_ij = r2*RECIP(4.0f*alpha2_ij);
                        real expTerm = EXP(-D_ij);
                        real denominator2 = r2 + alpha2_ij*expTerm;
                        real denominator = SQRT(denominator2);
                        real tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                        real Gpol = tempEnergy*RECIP(denominator2);
                        real dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                        force.w += dGpol_dalpha2_ij*bornRadius2;
                        real dEdR = Gpol*(1.0f - 0.25f*expTerm);

                        // A diagonal tile visits each distinct pair twice (once from each
                        // atom), so each visit adds half the pair energy; the j == tgx self
                        // term is scaled by 0.5 as well.
                        energy += 0.5f*tempEnergy;
                        force.xyz -= delta.xyz*dEdR;
                    }
                }

                // Write results.
                unsigned int offset = x*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                forceBuffers[offset].xyz = forceBuffers[offset].xyz+force.xyz;
                global_bornForce[offset] += force.w;
            }
        }
        else {
            // This is an off-diagonal tile: each pair is visited once, so results are
            // accumulated for both atoms (atom1 in force, atom j in localData).
            for (int tgx = 0; tgx < TILE_SIZE; tgx++) {
                localData[tgx].fx = 0.0f;
                localData[tgx].fy = 0.0f;
                localData[tgx].fz = 0.0f;
                localData[tgx].fw = 0.0f;
            }

            for (unsigned int tgx = 0; tgx < TILE_SIZE; tgx++) {
                unsigned int atom1 = x*TILE_SIZE+tgx;
                real4 force = 0.0f;
                real4 posq1 = posq[atom1];
                real bornRadius1 = global_bornRadii[atom1];
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    real4 posq2 = (real4) (localData[j].x, localData[j].y, localData[j].z, localData[j].q);
                    real4 delta = (real4) (posq2.xyz - posq1.xyz, 0);
#ifdef USE_PERIODIC
                    delta.xyz -= floor(delta.xyz*invPeriodicBoxSize.xyz+0.5f)*periodicBoxSize.xyz;
#endif
                    real r2 = dot(delta.xyz, delta.xyz);
#ifdef USE_CUTOFF
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS && r2 < CUTOFF_SQUARED) {
#else
                    if (atom1 < NUM_ATOMS && y*TILE_SIZE+j < NUM_ATOMS) {
#endif
                        real bornRadius2 = localData[j].bornRadius;
                        real alpha2_ij = bornRadius1*bornRadius2;
                        real D_ij = r2*RECIP(4.0f*alpha2_ij);
                        real expTerm = EXP(-D_ij);
                        real denominator2 = r2 + alpha2_ij*expTerm;
                        real denominator = SQRT(denominator2);
                        real tempEnergy = (PREFACTOR*posq1.w*posq2.w)*RECIP(denominator);
                        real Gpol = tempEnergy*RECIP(denominator2);
                        real dGpol_dalpha2_ij = -0.5f*Gpol*expTerm*(1.0f+D_ij);
                        force.w += dGpol_dalpha2_ij*bornRadius2;
                        real dEdR = Gpol*(1.0f - 0.25f*expTerm);
                        energy += tempEnergy;

                        // Equal and opposite forces on atom1 and atom j.
                        delta.xyz *= dEdR;
                        force.xyz -= delta.xyz;
                        localData[j].fx += delta.x;
                        localData[j].fy += delta.y;
                        localData[j].fz += delta.z;
                        localData[j].fw += dGpol_dalpha2_ij*bornRadius1;
                    }
                }

                // Write results for atom1.
                unsigned int offset = atom1 + get_group_id(0)*PADDED_NUM_ATOMS;
                forceBuffers[offset].xyz = forceBuffers[offset].xyz+force.xyz;
                global_bornForce[offset] += force.w;
            }

            // Write results for the atoms of block y (the w component of the force
            // buffer is written back unchanged).
            for (int tgx = 0; tgx < TILE_SIZE; tgx++) {
                unsigned int offset = y*TILE_SIZE+tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                real4 f = forceBuffers[offset];
                f.x += localData[tgx].fx;
                f.y += localData[tgx].fy;
                f.z += localData[tgx].fz;
                forceBuffers[offset] = f;
                global_bornForce[offset] += localData[tgx].fw;
            }
        }
        lasty = y;
        pos++;
    }
    energyBuffer[get_global_id(0)] += energy;
}