// nonbonded_nvidia.cl

// Atomic operations on 32-bit integers in global memory.
#pragma OPENCL EXTENSION cl_khr_global_int32_base_atomics : enable
// 64-bit atomics are only requested when the host defines SUPPORTS_64_BIT_ATOMICS.
#ifdef SUPPORTS_64_BIT_ATOMICS
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
#endif

// Number of atoms along each side of an interaction tile; also the number of
// work-items that cooperate on one tile.
#define TILE_SIZE 32
// Number of TILE_SIZE-wide groups of work-items in one work-group.
// FORCE_WORK_GROUP_SIZE is expected to be defined by the host at compile time.
#define WARPS_PER_GROUP (FORCE_WORK_GROUP_SIZE/TILE_SIZE)

/**
 * Per-atom record used for the "atom2" side of a tile.
 *
 * x, y, z    - atom position.
 * q          - fourth component (w) of the atom's posq entry.
 * fx, fy, fz - force accumulated on this atom while a tile is processed.
 *
 * ATOM_PARAMETER_DATA is a macro defined outside this file that appends any
 * additional per-atom fields.
 */
typedef struct {
    real x, y, z;
    real q;
    real fx, fy, fz;
    ATOM_PARAMETER_DATA
#ifndef PARAMETER_SIZE_IS_EVEN
    // NOTE(review): presumably pads the struct to an even number of `real`
    // fields when ATOM_PARAMETER_DATA adds an odd count -- confirm with the host code.
    real padding;
#endif
} AtomData;

/**
 * Compute nonbonded interactions.
 *
 * The atoms are divided into blocks of TILE_SIZE, and the work is divided into
 * tiles (x, y) of block pairs.  Each group of TILE_SIZE consecutive work-items
 * (a "warp") processes a contiguous range [pos, end) of tiles, one tile per
 * iteration of the outer loop.  Work-item `tgx` of the warp owns atom
 * x*TILE_SIZE+tgx ("atom1") and accumulates the force on it in a private
 * variable; the forces on the atoms of block y ("atom2") are accumulated in
 * local memory and written out at the end of the iteration.
 *
 * @param forceBuffers        output forces.  With SUPPORTS_64_BIT_ATOMICS this is a
 *                            fixed-point `long` buffer holding the x, y and z components
 *                            in three consecutive sections of PADDED_NUM_ATOMS elements;
 *                            otherwise it is a real4 buffer with one section of
 *                            PADDED_NUM_ATOMS elements per work-group
 * @param energyBuffer        energy accumulator, one element per work-item
 * @param posq                per-atom xyz position; the w component is stored in AtomData.q
 * @param exclusions          per-atom exclusion bit masks, TILE_SIZE words per stored tile;
 *                            a cleared bit marks an excluded pair
 * @param exclusionIndices    for each stored exclusion tile, the y index of that tile
 * @param exclusionRowIndices for each x, the range of entries in exclusionIndices
 * @param startTileIndex      first tile of the range handled by this kernel invocation
 * @param endTileIndex        end of the tile range (used when the tile list is not used)
 * @param tiles               (USE_CUTOFF) list of (x, y) tile indices to process
 * @param interactionCount    (USE_CUTOFF) element 0 holds the number of entries in `tiles`
 * @param periodicBoxSize     (USE_CUTOFF) periodic box dimensions
 * @param invPeriodicBoxSize  (USE_CUTOFF) reciprocals of the box dimensions
 * @param maxTiles            (USE_CUTOFF) capacity of `tiles`; when interactionCount[0]
 *                            exceeds it, the tile list is ignored and every tile in
 *                            [startTileIndex, endTileIndex) is processed instead
 * @param interactionFlags    (USE_CUTOFF) per-tile bit mask selecting which atom2 indices
 *                            need to be evaluated
 * @param numTiles            (no cutoff) number of tiles to process
 *
 * PARAMETER_ARGUMENTS, LOAD_ATOM1_PARAMETERS, LOAD_ATOM2_PARAMETERS,
 * LOAD_LOCAL_PARAMETERS_FROM_1, LOAD_LOCAL_PARAMETERS_FROM_GLOBAL and
 * COMPUTE_INTERACTION are macros supplied from outside this file.  Several locals
 * below (atom1, atom2, r, invR, isExcluded, localAtomIndex, ...) are not referenced
 * again in this file -- presumably for use by those macros -- so they must not be removed.
 */
__kernel void computeNonbonded(
#ifdef SUPPORTS_64_BIT_ATOMICS
        __global long* restrict forceBuffers,
#else
        __global real4* restrict forceBuffers,
#endif
        __global real* restrict energyBuffer, __global const real4* restrict posq, __global const unsigned int* restrict exclusions,
        __global const unsigned int* restrict exclusionIndices, __global const unsigned int* restrict exclusionRowIndices,
        unsigned int startTileIndex, unsigned int endTileIndex,
#ifdef USE_CUTOFF
        __global const ushort2* restrict tiles, __global const unsigned int* restrict interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize, unsigned int maxTiles, __global const unsigned int* restrict interactionFlags
#else
        unsigned int numTiles
#endif
        PARAMETER_ARGUMENTS) {
    unsigned int totalWarps = get_global_size(0)/TILE_SIZE;
    unsigned int warp = get_global_id(0)/TILE_SIZE;

    // Split the tiles evenly between warps.  The products warp*numTiles can exceed
    // 32 bits for large systems, so they are evaluated with a 64-bit intermediate;
    // the quotients always fit in 32 bits again.
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
    unsigned int pos = (unsigned int) (numTiles > maxTiles ? startTileIndex+warp*(unsigned long)(endTileIndex-startTileIndex)/totalWarps : warp*(unsigned long)numTiles/totalWarps);
    unsigned int end = (unsigned int) (numTiles > maxTiles ? startTileIndex+(warp+1)*(unsigned long)(endTileIndex-startTileIndex)/totalWarps : (warp+1)*(unsigned long)numTiles/totalWarps);
#else
    unsigned int pos = (unsigned int) (startTileIndex+warp*(unsigned long)numTiles/totalWarps);
    unsigned int end = (unsigned int) (startTileIndex+(warp+1)*(unsigned long)numTiles/totalWarps);
#endif
    real energy = 0;
    __local AtomData localData[FORCE_WORK_GROUP_SIZE];
    __local real tempBuffer[3*FORCE_WORK_GROUP_SIZE];
    __local unsigned int exclusionRange[2*WARPS_PER_GROUP];
    __local int exclusionIndex[WARPS_PER_GROUP];
    // The write-coordination table used at the end of each iteration shares its
    // storage with exclusionRange.
    __local int2* reservedBlocks = (__local int2*) exclusionRange;

    // The body runs at least once even for a warp with no tiles, so that such a warp
    // still takes part in the write-coordination phase below.
    do {
        // Extract the coordinates of this tile

        const unsigned int tgx = get_local_id(0) & (TILE_SIZE-1);       // index of this work-item within its warp
        const unsigned int tbx = get_local_id(0) - tgx;                 // local index of the first work-item of this warp
        const unsigned int localGroupIndex = get_local_id(0)/TILE_SIZE; // index of this warp within the work-group
        unsigned int x, y;
        real4 force = 0;
        if (pos < end) {
#ifdef USE_CUTOFF
            if (numTiles <= maxTiles) {
                ushort2 tileIndices = tiles[pos];
                x = tileIndices.x;
                y = tileIndices.y;
            }
            else
#endif
            {
                // Convert the linear tile index into (x, y) coordinates with y <= x < NUM_BLOCKS.
                y = (unsigned int) floor(NUM_BLOCKS+0.5f-SQRT((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                    y += (x < y ? -1 : 1);
                    x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
                }
            }
            unsigned int atom1 = x*TILE_SIZE + tgx;
            real4 posq1 = posq[atom1];
            LOAD_ATOM1_PARAMETERS

            // Locate the exclusion data for this tile.

#ifdef USE_EXCLUSIONS
            // NOTE(review): the local-memory values written here are read by the other
            // work-items of the warp without a barrier; this appears to rely on the
            // warp executing in lockstep -- confirm for the target device.
            if (tgx < 2)
                exclusionRange[2*localGroupIndex+tgx] = exclusionRowIndices[x+tgx];
            if (tgx == 0)
                exclusionIndex[localGroupIndex] = -1;
            for (unsigned int i = exclusionRange[2*localGroupIndex]+tgx; i < exclusionRange[2*localGroupIndex+1]; i += TILE_SIZE)
                if (exclusionIndices[i] == y)
                    exclusionIndex[localGroupIndex] = i*TILE_SIZE;
            bool hasExclusions = (exclusionIndex[localGroupIndex] > -1);
#else
            bool hasExclusions = false;
#endif
            if (x == y) {
                // This tile is on the diagonal.

                const unsigned int localAtomIndex = get_local_id(0);
                localData[localAtomIndex].x = posq1.x;
                localData[localAtomIndex].y = posq1.y;
                localData[localAtomIndex].z = posq1.z;
                localData[localAtomIndex].q = posq1.w;
                LOAD_LOCAL_PARAMETERS_FROM_1
#ifdef USE_EXCLUSIONS
                // NOTE(review): this read assumes an exclusion entry exists for every
                // diagonal tile (exclusionIndex != -1) -- confirm with the host code.
                unsigned int excl = exclusions[exclusionIndex[localGroupIndex]+tgx];
#endif
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
#ifdef USE_EXCLUSIONS
                    // Bit j of the mask is cleared when the pair (atom1, j) is excluded.
                    bool isExcluded = !(excl & 0x1);
#endif
                    int atom2 = tbx+j;
                    real4 posq2 = (real4) (localData[atom2].x, localData[atom2].y, localData[atom2].z, localData[atom2].q);
                    real4 delta = (real4) (posq2.xyz - posq1.xyz, 0);
#ifdef USE_PERIODIC
                    delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                    delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                    delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                    real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                    real invR = RSQRT(r2);
                    real r = RECIP(invR);
                    LOAD_ATOM2_PARAMETERS
                    atom2 = y*TILE_SIZE+j;
#ifdef USE_SYMMETRIC
                    real dEdR = 0;
#else
                    real4 dEdR1 = (real4) 0;
                    real4 dEdR2 = (real4) 0;
#endif
                    real tempEnergy = 0;
                    COMPUTE_INTERACTION
                    // On a diagonal tile every pair is visited twice (once from each
                    // atom), so each visit contributes half of the pair energy.
                    energy += 0.5f*tempEnergy;
#ifdef USE_SYMMETRIC
                    force.xyz -= delta.xyz*dEdR;
#else
                    force.xyz -= dEdR1.xyz;
#endif
#ifdef USE_EXCLUSIONS
                    excl >>= 1;
#endif
                }
            }
            else {
                // This is an off-diagonal tile.

                const unsigned int localAtomIndex = get_local_id(0);
                unsigned int j = y*TILE_SIZE + tgx;
                real4 tempPosq = posq[j];
                localData[localAtomIndex].x = tempPosq.x;
                localData[localAtomIndex].y = tempPosq.y;
                localData[localAtomIndex].z = tempPosq.z;
                localData[localAtomIndex].q = tempPosq.w;
                LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
                localData[localAtomIndex].fx = 0;
                localData[localAtomIndex].fy = 0;
                localData[localAtomIndex].fz = 0;
#ifdef USE_CUTOFF
                unsigned int flags = (numTiles <= maxTiles ? interactionFlags[pos] : 0xFFFFFFFF);
                if (!hasExclusions && flags != 0xFFFFFFFF) {
                    if (flags == 0) {
                        // No interactions in this tile.
                    }
                    else {
                        // Compute only a subset of the interactions in this tile.

                        for (j = 0; j < TILE_SIZE; j++) {
                            // Unsigned literal so that testing bit 31 does not shift into the sign bit.
                            if ((flags&(1u<<j)) != 0) {
                                bool isExcluded = false;
                                int atom2 = tbx+j;
                                int bufferIndex = 3*get_local_id(0);
#ifdef USE_SYMMETRIC
                                real dEdR = 0;
#else
                                real4 dEdR1 = (real4) 0;
                                real4 dEdR2 = (real4) 0;
#endif
                                real tempEnergy = 0;
                                real4 posq2 = (real4) (localData[atom2].x, localData[atom2].y, localData[atom2].z, localData[atom2].q);
                                real4 delta = (real4) (posq2.xyz - posq1.xyz, 0);
#ifdef USE_PERIODIC
                                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                                if (r2 < CUTOFF_SQUARED) {
                                    real invR = RSQRT(r2);
                                    real r = RECIP(invR);
                                    LOAD_ATOM2_PARAMETERS
                                    atom2 = y*TILE_SIZE+j;
                                    COMPUTE_INTERACTION
                                    energy += tempEnergy;
                                }
                                // Every work-item stores its contribution to the force on atom j
                                // (zero if the pair was beyond the cutoff) so it can be summed below.
#ifdef USE_SYMMETRIC
                                delta.xyz *= dEdR;
                                force.xyz -= delta.xyz;
                                tempBuffer[bufferIndex] = delta.x;
                                tempBuffer[bufferIndex+1] = delta.y;
                                tempBuffer[bufferIndex+2] = delta.z;
#else
                                force.xyz -= dEdR1.xyz;
                                tempBuffer[bufferIndex] = dEdR2.x;
                                tempBuffer[bufferIndex+1] = dEdR2.y;
                                tempBuffer[bufferIndex+2] = dEdR2.z;
#endif

                                // Sum the forces on atom2.
                                // Stage 1: every fourth work-item adds up four consecutive entries.
                                // Stage 2: work-item 0 adds up the eight partial sums.
                                // NOTE(review): there is no barrier between the stores above and
                                // these loads; this appears to rely on the warp executing in
                                // lockstep -- confirm for the target device.

                                if (tgx % 4 == 0) {
                                    tempBuffer[bufferIndex] += tempBuffer[bufferIndex+3]+tempBuffer[bufferIndex+6]+tempBuffer[bufferIndex+9];
                                    tempBuffer[bufferIndex+1] += tempBuffer[bufferIndex+4]+tempBuffer[bufferIndex+7]+tempBuffer[bufferIndex+10];
                                    tempBuffer[bufferIndex+2] += tempBuffer[bufferIndex+5]+tempBuffer[bufferIndex+8]+tempBuffer[bufferIndex+11];
                                }
                                if (tgx == 0) {
                                    localData[tbx+j].fx += tempBuffer[bufferIndex]+tempBuffer[bufferIndex+12]+tempBuffer[bufferIndex+24]+tempBuffer[bufferIndex+36]+tempBuffer[bufferIndex+48]+tempBuffer[bufferIndex+60]+tempBuffer[bufferIndex+72]+tempBuffer[bufferIndex+84];
                                    localData[tbx+j].fy += tempBuffer[bufferIndex+1]+tempBuffer[bufferIndex+13]+tempBuffer[bufferIndex+25]+tempBuffer[bufferIndex+37]+tempBuffer[bufferIndex+49]+tempBuffer[bufferIndex+61]+tempBuffer[bufferIndex+73]+tempBuffer[bufferIndex+85];
                                    localData[tbx+j].fz += tempBuffer[bufferIndex+2]+tempBuffer[bufferIndex+14]+tempBuffer[bufferIndex+26]+tempBuffer[bufferIndex+38]+tempBuffer[bufferIndex+50]+tempBuffer[bufferIndex+62]+tempBuffer[bufferIndex+74]+tempBuffer[bufferIndex+86];
                                }
                            }
                        }
                    }
                }
                else
#endif
                {
                    // Compute the full set of interactions in this tile.

#ifdef USE_EXCLUSIONS
                    unsigned int excl = (hasExclusions ? exclusions[exclusionIndex[localGroupIndex]+tgx] : 0xFFFFFFFF);
                    // Rotate the mask so that bit 0 corresponds to the first atom2 this
                    // work-item visits (tj == tgx).
                    excl = (excl >> tgx) | (excl << (TILE_SIZE - tgx));
#endif
                    // Each work-item starts at a different atom2 and wraps around, so at any
                    // step the work-items of the warp address TILE_SIZE different atom2 slots.
                    unsigned int tj = tgx;
                    for (j = 0; j < TILE_SIZE; j++) {
#ifdef USE_EXCLUSIONS
                        bool isExcluded = !(excl & 0x1);
#endif
                        int atom2 = tbx+tj;
                        real4 posq2 = (real4) (localData[atom2].x, localData[atom2].y, localData[atom2].z, localData[atom2].q);
                        real4 delta = (real4) (posq2.xyz - posq1.xyz, 0);
#ifdef USE_PERIODIC
                        delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                        delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                        delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                        real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                        if (r2 < CUTOFF_SQUARED) {
#endif
                            real invR = RSQRT(r2);
                            real r = RECIP(invR);
                            LOAD_ATOM2_PARAMETERS
                            atom2 = y*TILE_SIZE+tj;
#ifdef USE_SYMMETRIC
                            real dEdR = 0;
#else
                            real4 dEdR1 = (real4) 0;
                            real4 dEdR2 = (real4) 0;
#endif
                            real tempEnergy = 0;
                            COMPUTE_INTERACTION
                            energy += tempEnergy;
#ifdef USE_SYMMETRIC
                            delta.xyz *= dEdR;
                            force.xyz -= delta.xyz;
                            localData[tbx+tj].fx += delta.x;
                            localData[tbx+tj].fy += delta.y;
                            localData[tbx+tj].fz += delta.z;
#else
                            force.xyz -= dEdR1.xyz;
                            localData[tbx+tj].fx += dEdR2.x;
                            localData[tbx+tj].fy += dEdR2.y;
                            localData[tbx+tj].fz += dEdR2.z;
#endif
#ifdef USE_CUTOFF
                        }
#endif
#ifdef USE_EXCLUSIONS
                        excl >>= 1;
#endif
                        tj = (tj + 1) & (TILE_SIZE - 1);
                    }
                }
            }
        }

        // Write results.  We need to coordinate between warps to make sure no two of them
        // ever try to write to the same piece of memory at the same time.

#ifdef SUPPORTS_64_BIT_ATOMICS
        // Forces are scaled by 0xFFFFFFFF, converted to 64-bit fixed point and added
        // atomically, so no coordination between warps is needed on this path.
        if (pos < end) {
            const unsigned int offset = x*TILE_SIZE + tgx;
            atom_add(&forceBuffers[offset], (long) (force.x*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+PADDED_NUM_ATOMS], (long) (force.y*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+2*PADDED_NUM_ATOMS], (long) (force.z*0xFFFFFFFF));
        }
        if (pos < end && x != y) {
            const unsigned int offset = y*TILE_SIZE + tgx;
            atom_add(&forceBuffers[offset], (long) (localData[get_local_id(0)].fx*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+PADDED_NUM_ATOMS], (long) (localData[get_local_id(0)].fy*0xFFFFFFFF));
            atom_add(&forceBuffers[offset+2*PADDED_NUM_ATOMS], (long) (localData[get_local_id(0)].fz*0xFFFFFFFF));
        }
#else
        // Each warp publishes the blocks it wants to write (-1 meaning "none").  A warp
        // may write once no lower-numbered warp in the work-group still has a pending
        // write to either of the same blocks.
        // NOTE(review): the barrier below is inside the per-warp tile loop, so it assumes
        // the warps of a work-group reach it a compatible number of times -- confirm.
        int writeX = (pos < end ? x : -1);
        int writeY = (pos < end && x != y ? y : -1);
        if (tgx == 0)
            reservedBlocks[localGroupIndex] = (int2)(writeX, writeY);
        bool done = false;
        int doneIndex = 0;
        int checkIndex = 0;
        while (true) {
            // See if any warp still needs to write its data.

            bool allDone = true;
            barrier(CLK_LOCAL_MEM_FENCE);
            while (doneIndex < WARPS_PER_GROUP && allDone) {
                if (reservedBlocks[doneIndex].x != -1)
                    allDone = false;
                else
                    doneIndex++;
            }
            if (allDone)
                break;
            if (!done) {
                // See whether this warp can write its data.  This requires that no previous warp
                // is trying to write to the same block of the buffer.

                bool canWrite = (writeX != -1);
                while (checkIndex < localGroupIndex && canWrite) {
                    if ((reservedBlocks[checkIndex].x == x || reservedBlocks[checkIndex].y == x) ||
                            (writeY != -1 && (reservedBlocks[checkIndex].x == y || reservedBlocks[checkIndex].y == y)))
                        canWrite = false;
                    else
                        checkIndex++;
                }
                if (canWrite) {
                    // Write the data to global memory, then mark this warp as done.

                    if (writeX > -1) {
                        const unsigned int offset = x*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        forceBuffers[offset].xyz += force.xyz;
                    }
                    if (writeY > -1) {
                        const unsigned int offset = y*TILE_SIZE + tgx + get_group_id(0)*PADDED_NUM_ATOMS;
                        forceBuffers[offset] += (real4) (localData[get_local_id(0)].fx, localData[get_local_id(0)].fy, localData[get_local_id(0)].fz, 0);
                    }
                    done = true;
                    if (tgx == 0)
                        reservedBlocks[localGroupIndex] = (int2)(-1, -1);
                }
            }
        }
#endif
        pos++;
    } while (pos < end);
    energyBuffer[get_global_id(0)] += energy;
}