// customGBEnergyN2.cu
// Helpers that add one per-atom derivative value into derivBuffers. The value is multiplied by
// 0x100000000 (2^32), converted to a signed 64 bit integer and then reinterpreted as unsigned so
// it can be accumulated with the unsigned long long overload of atomicAdd. Buffer number INDEX
// starts at element (INDEX-1)*PADDED_NUM_ATOMS; "offset" must be a variable in scope at the point
// of use that holds the index of the atom being written.
//
// STORE_DERIVATIVE_1 writes the local variable deriv<INDEX>_1; STORE_DERIVATIVE_2 writes the
// deriv<INDEX> field of this thread's localData entry.
//
// NOTE(review): both bodies end with ';', presumably because the generated STORE_DERIVATIVES_1/2
// expansions invoke them without one - confirm against the host code before changing it.
#define STORE_DERIVATIVE_1(INDEX) atomicAdd(&derivBuffers[offset+(INDEX-1)*PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (deriv##INDEX##_1*0x100000000)));
#define STORE_DERIVATIVE_2(INDEX) atomicAdd(&derivBuffers[offset+(INDEX-1)*PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (localData[threadIdx.x].deriv##INDEX*0x100000000)));

/**
 * Per-thread record kept in shared memory while a tile is processed. Each thread stores the data
 * for one "atom 2" of the tile here so that the other threads handling the same tile can read it.
 */
typedef struct {
    // Atom data loaded from the posq array; the kernel reads x, y and z as the position.
    real4 posq;
    // Force accumulated on this atom by the threads of the tile.
    real3 force;
    // Placeholder replaced before compilation. Presumably expands to the extra per-atom fields
    // (for example the deriv<N> members used by STORE_DERIVATIVE_2) - confirm against the host code.
    ATOM_PARAMETER_DATA
#ifdef NEED_PADDING
    // Extra member that only changes the size of the struct; it is never read or written here.
    float padding;
#endif
} AtomData;

/**
 * Compute a force based on pair interactions.
 *
 * Atoms are grouped into blocks of TILE_SIZE, and interactions are processed one tile (a pair of
 * atom blocks x, y) at a time. TILE_SIZE consecutive threads cooperate on a tile: each thread owns
 * one atom of block x ("atom 1") in registers and publishes one atom of block y ("atom 2") through
 * the shared localData array. The kernel uses only the x dimension of the grid and of the block,
 * and its shared arrays hold THREAD_BLOCK_SIZE entries, so blockDim.x must not exceed that value.
 * TILE_SIZE must be a power of two because lane indices are computed with "& (TILE_SIZE-1)".
 *
 * Arguments:
 *   forceBuffers      three consecutive arrays of PADDED_NUM_ATOMS elements (x, y, z). Forces are
 *                     added atomically as 64 bit fixed point values scaled by 0x100000000.
 *   energyBuffer      one element per thread; the energy accumulated by the thread is added to it.
 *   posq              per-atom data; x, y and z are used as the position.
 *   exclusions        one 32 bit mask per thread for each tile listed in exclusionTiles. A set bit
 *                     means the corresponding pair is NOT excluded.
 *   exclusionTiles    (x, y) block indices of the tiles that carry exclusion masks.
 *   tiles, interactionCount, maxTiles, interactingAtoms (USE_CUTOFF only)
 *                     neighbor list. interactionCount[0] is the number of tiles in it. If that
 *                     exceeds maxTiles the list is ignored and every tile is enumerated instead.
 *                     For a tile taken from the list, tiles[pos].x is block x, tiles[pos].y is
 *                     used as the "single periodic copy" flag, and interactingAtoms supplies the
 *                     atom 2 index of every thread.
 *   periodicBoxSize, invPeriodicBoxSize, blockCenter (USE_CUTOFF only)
 *                     used to apply periodic boundary conditions.
 *   numTiles (without USE_CUTOFF)
 *                     number of tile positions to divide between the warps in the second loop.
 *   PARAMETER_ARGUMENTS
 *                     placeholder replaced before compilation, presumably with further arguments.
 *
 * NOTE(review): threads exchange data through localData, atomIndices and skipTiles without any
 * explicit barrier, which appears to rely on the threads of a tile executing in lockstep. Confirm
 * that this holds on the targeted architectures (independent thread scheduling needs __syncwarp).
 */
extern "C" __global__ void computeN2Energy(unsigned long long* __restrict__ forceBuffers, real* __restrict__ energyBuffer,
        const real4* __restrict__ posq, const unsigned int* __restrict__ exclusions, const ushort2* __restrict__ exclusionTiles,
#ifdef USE_CUTOFF
        const ushort2* __restrict__ tiles, const unsigned int* __restrict__ interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize, unsigned int maxTiles, const real4* __restrict__ blockCenter, const unsigned int* __restrict__ interactingAtoms
#else
        unsigned int numTiles
#endif
        PARAMETER_ARGUMENTS) {
    const unsigned int totalWarps = (blockDim.x*gridDim.x)/TILE_SIZE;
    const unsigned int warp = (blockIdx.x*blockDim.x+threadIdx.x)/TILE_SIZE;
    const unsigned int tgx = threadIdx.x & (TILE_SIZE-1); // index of this thread within its tile group
    const unsigned int tbx = threadIdx.x - tgx;           // index of the first thread of the group
    real energy = 0;
    __shared__ AtomData localData[THREAD_BLOCK_SIZE];

    // First loop: process tiles that contain exclusions.

    // Each warp takes a contiguous share of the tiles. The product is formed in 64 bit arithmetic
    // so that it cannot wrap around when the number of warps and tiles are both large.

    const unsigned int firstExclusionTile = (unsigned int) (FIRST_EXCLUSION_TILE+warp*(long long) (LAST_EXCLUSION_TILE-FIRST_EXCLUSION_TILE)/totalWarps);
    const unsigned int lastExclusionTile = (unsigned int) (FIRST_EXCLUSION_TILE+(warp+1)*(long long) (LAST_EXCLUSION_TILE-FIRST_EXCLUSION_TILE)/totalWarps);
    for (int pos = firstExclusionTile; pos < lastExclusionTile; pos++) {
        const ushort2 tileIndices = exclusionTiles[pos];
        const unsigned int x = tileIndices.x;
        const unsigned int y = tileIndices.y;
        real3 force = make_real3(0);
        DECLARE_ATOM1_DERIVATIVES
        unsigned int atom1 = x*TILE_SIZE + tgx;
        real4 posq1 = posq[atom1];
        LOAD_ATOM1_PARAMETERS
#ifdef USE_EXCLUSIONS
        unsigned int excl = exclusions[pos*TILE_SIZE+tgx];
#endif
        if (x == y) {
            // This tile is on the diagonal.

            const unsigned int localAtomIndex = threadIdx.x;
            localData[localAtomIndex].posq = posq1;
            LOAD_LOCAL_PARAMETERS_FROM_1
            for (unsigned int j = 0; j < TILE_SIZE; j++) {
                int atom2 = tbx+j;
                real4 posq2 = localData[atom2].posq;
                real3 delta = make_real3(posq2.x-posq1.x, posq2.y-posq1.y, posq2.z-posq1.z);
#ifdef USE_PERIODIC
                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                if (r2 < CUTOFF_SQUARED) {
#endif
                    real invR = RSQRT(r2);
                    real r = RECIP(invR);
                    LOAD_ATOM2_PARAMETERS
                    atom2 = y*TILE_SIZE+j;
                    real dEdR = 0;
                    real tempEnergy = 0;
#ifdef USE_EXCLUSIONS
                    bool isExcluded = !(excl & 0x1);
#endif
                    if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2) {
                        COMPUTE_INTERACTION
                        dEdR /= -r;
                    }

                    // Every pair of a diagonal tile is visited twice (once by each of its two
                    // threads), so only half of the energy is counted per visit.

                    energy += 0.5f*tempEnergy;
                    delta *= dEdR;
                    force.x -= delta.x;
                    force.y -= delta.y;
                    force.z -= delta.z;
#ifdef USE_CUTOFF
                }
#endif
#ifdef USE_EXCLUSIONS
                excl >>= 1;
#endif
            }
        }
        else {
            // This is an off-diagonal tile.

            const unsigned int localAtomIndex = threadIdx.x;
            unsigned int j = y*TILE_SIZE + tgx;
            localData[localAtomIndex].posq = posq[j];
            LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
            localData[localAtomIndex].force = make_real3(0);
            CLEAR_LOCAL_DERIVATIVES
#ifdef USE_EXCLUSIONS
            // Rotate the mask right by tgx so that bit 0 belongs to the first atom this thread
            // visits (tj starts at tgx). The left shift count is masked because shifting by the
            // full width of the word, as would happen for tgx == 0, is undefined behavior.

            excl = (excl >> tgx) | (excl << ((TILE_SIZE - tgx) & (TILE_SIZE - 1)));
#endif
            unsigned int tj = tgx;
            for (j = 0; j < TILE_SIZE; j++) {
                int atom2 = tbx+tj;
                real4 posq2 = localData[atom2].posq;
                real3 delta = make_real3(posq2.x-posq1.x, posq2.y-posq1.y, posq2.z-posq1.z);
#ifdef USE_PERIODIC
                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                if (r2 < CUTOFF_SQUARED) {
#endif
                    real invR = RSQRT(r2);
                    real r = RECIP(invR);
                    LOAD_ATOM2_PARAMETERS
                    atom2 = y*TILE_SIZE+tj;
                    real dEdR = 0;
                    real tempEnergy = 0;
#ifdef USE_EXCLUSIONS
                    bool isExcluded = !(excl & 0x1);
#endif
                    if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
                        COMPUTE_INTERACTION
                        dEdR /= -r;
                    }
                    energy += tempEnergy;
                    delta *= dEdR;
                    force.x -= delta.x;
                    force.y -= delta.y;
                    force.z -= delta.z;

                    // Switch atom2 back to the shared memory index to record the opposite force.

                    atom2 = tbx+tj;
                    localData[atom2].force.x += delta.x;
                    localData[atom2].force.y += delta.y;
                    localData[atom2].force.z += delta.z;
                    RECORD_DERIVATIVE_2
#ifdef USE_CUTOFF
                }
#endif
#ifdef USE_EXCLUSIONS
                excl >>= 1;
#endif
                tj = (tj + 1) & (TILE_SIZE - 1);
            }
        }

        // Write results.

        unsigned int offset = x*TILE_SIZE + tgx;
        atomicAdd(&forceBuffers[offset], static_cast<unsigned long long>((long long) (force.x*0x100000000)));
        atomicAdd(&forceBuffers[offset+PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (force.y*0x100000000)));
        atomicAdd(&forceBuffers[offset+2*PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (force.z*0x100000000)));
        STORE_DERIVATIVES_1
        if (x != y) {
            offset = y*TILE_SIZE + tgx;
            atomicAdd(&forceBuffers[offset], static_cast<unsigned long long>((long long) (localData[threadIdx.x].force.x*0x100000000)));
            atomicAdd(&forceBuffers[offset+PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (localData[threadIdx.x].force.y*0x100000000)));
            atomicAdd(&forceBuffers[offset+2*PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (localData[threadIdx.x].force.z*0x100000000)));
            STORE_DERIVATIVES_2
        }
    }

    // Second loop: tiles without exclusions, either from the neighbor list (with cutoff) or just enumerating all
    // of them (no cutoff).

    // As above, the share of each warp is computed in 64 bit arithmetic to avoid wrap-around.

#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
    int pos = (int) (warp*(numTiles > maxTiles ? NUM_BLOCKS*((long long) NUM_BLOCKS+1)/2 : (long long) numTiles)/totalWarps);
    int end = (int) ((warp+1)*(numTiles > maxTiles ? NUM_BLOCKS*((long long) NUM_BLOCKS+1)/2 : (long long) numTiles)/totalWarps);
#else
    int pos = (int) (warp*(long long) numTiles/totalWarps);
    int end = (int) ((warp+1)*(long long) numTiles/totalWarps);
#endif
    int skipBase = 0;
    int currentSkipIndex = tbx;
    __shared__ int atomIndices[THREAD_BLOCK_SIZE];
    __shared__ volatile int skipTiles[THREAD_BLOCK_SIZE];
    skipTiles[threadIdx.x] = -1;

    while (pos < end) {
        const bool isExcluded = false;
        real3 force = make_real3(0);
        DECLARE_ATOM1_DERIVATIVES
        bool includeTile = true;

        // Extract the coordinates of this tile.

        unsigned int x, y;
        bool singlePeriodicCopy = false;
#ifdef USE_CUTOFF
        if (numTiles <= maxTiles) {
            // The neighbor list is valid. Only x is needed, because the atoms to interact with
            // are read from interactingAtoms below.

            ushort2 tileIndices = tiles[pos];
            x = tileIndices.x;
            singlePeriodicCopy = tileIndices.y;
        }
        else
#endif
        {
            // Invert the tile numbering pos = x + y*NUM_BLOCKS - y*(y+1)/2 to recover (x, y).

            y = (unsigned int) floor(NUM_BLOCKS+0.5f-SQRT((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
            x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
            if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                y += (x < y ? -1 : 1);
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
            }

            // Skip over tiles that have exclusions, since they were already processed.

            while (skipTiles[tbx+TILE_SIZE-1] < pos) {
                // Load the positions of the next TILE_SIZE tiles with exclusions, one per thread.
                // Entries past the end of the list are set to "end", which is never reached.

                if (skipBase+tgx < NUM_TILES_WITH_EXCLUSIONS) {
                    ushort2 tile = exclusionTiles[skipBase+tgx];
                    skipTiles[threadIdx.x] = tile.x + tile.y*NUM_BLOCKS - tile.y*(tile.y+1)/2;
                }
                else
                    skipTiles[threadIdx.x] = end;
                skipBase += TILE_SIZE;
                currentSkipIndex = tbx;
            }
            while (skipTiles[currentSkipIndex] < pos)
                currentSkipIndex++;
            includeTile = (skipTiles[currentSkipIndex] != pos);
        }
        if (includeTile) {
            unsigned int atom1 = x*TILE_SIZE + tgx;

            // Load atom data for this tile.

            real4 posq1 = posq[atom1];
            LOAD_ATOM1_PARAMETERS
            const unsigned int localAtomIndex = threadIdx.x;
#ifdef USE_CUTOFF
            unsigned int j = (numTiles <= maxTiles ? interactingAtoms[pos*TILE_SIZE+tgx] : y*TILE_SIZE + tgx);
#else
            unsigned int j = y*TILE_SIZE + tgx;
#endif
            atomIndices[threadIdx.x] = j;
            if (j < PADDED_NUM_ATOMS) {
                localData[localAtomIndex].posq = posq[j];
                LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
                localData[localAtomIndex].force = make_real3(0);
                CLEAR_LOCAL_DERIVATIVES
            }
#ifdef USE_PERIODIC
            if (singlePeriodicCopy) {
                // The box is small enough that we can just translate all the atoms into a single periodic
                // box, then skip having to apply periodic boundary conditions later.

                real4 blockCenterX = blockCenter[x];
                posq1.x -= floor((posq1.x-blockCenterX.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                posq1.y -= floor((posq1.y-blockCenterX.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                posq1.z -= floor((posq1.z-blockCenterX.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
                localData[threadIdx.x].posq.x -= floor((localData[threadIdx.x].posq.x-blockCenterX.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                localData[threadIdx.x].posq.y -= floor((localData[threadIdx.x].posq.y-blockCenterX.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                localData[threadIdx.x].posq.z -= floor((localData[threadIdx.x].posq.z-blockCenterX.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
                unsigned int tj = tgx;
                for (j = 0; j < TILE_SIZE; j++) {
                    int atom2 = tbx+tj;
                    real4 posq2 = localData[atom2].posq;
                    real3 delta = make_real3(posq2.x-posq1.x, posq2.y-posq1.y, posq2.z-posq1.z);
                    real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                    if (r2 < CUTOFF_SQUARED) {
#endif
                        real invR = RSQRT(r2);
                        real r = RECIP(invR);
                        LOAD_ATOM2_PARAMETERS
                        atom2 = atomIndices[tbx+tj];
                        real dEdR = 0;
                        real tempEnergy = 0;
                        if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
                            COMPUTE_INTERACTION
                            dEdR /= -r;
                        }
                        energy += tempEnergy;
                        delta *= dEdR;
                        force.x -= delta.x;
                        force.y -= delta.y;
                        force.z -= delta.z;
                        atom2 = tbx+tj;
                        localData[atom2].force.x += delta.x;
                        localData[atom2].force.y += delta.y;
                        localData[atom2].force.z += delta.z;
                        RECORD_DERIVATIVE_2
#ifdef USE_CUTOFF
                    }
#endif
                    tj = (tj + 1) & (TILE_SIZE - 1);
                }
            }
            else
#endif
            {
                // We need to apply periodic boundary conditions separately for each interaction.

                unsigned int tj = tgx;
                for (j = 0; j < TILE_SIZE; j++) {
                    int atom2 = tbx+tj;
                    real4 posq2 = localData[atom2].posq;
                    real3 delta = make_real3(posq2.x-posq1.x, posq2.y-posq1.y, posq2.z-posq1.z);
#ifdef USE_PERIODIC
                    delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                    delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                    delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                    real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                    if (r2 < CUTOFF_SQUARED) {
#endif
                        real invR = RSQRT(r2);
                        real r = RECIP(invR);
                        LOAD_ATOM2_PARAMETERS
                        atom2 = atomIndices[tbx+tj];
                        real dEdR = 0;
                        real tempEnergy = 0;
                        if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
                            COMPUTE_INTERACTION
                            dEdR /= -r;
                        }
                        energy += tempEnergy;
                        delta *= dEdR;
                        force.x -= delta.x;
                        force.y -= delta.y;
                        force.z -= delta.z;
                        atom2 = tbx+tj;
                        localData[atom2].force.x += delta.x;
                        localData[atom2].force.y += delta.y;
                        localData[atom2].force.z += delta.z;
                        RECORD_DERIVATIVE_2
#ifdef USE_CUTOFF
                    }
#endif
                    tj = (tj + 1) & (TILE_SIZE - 1);
                }
            }

            // Write results.

            atomicAdd(&forceBuffers[atom1], static_cast<unsigned long long>((long long) (force.x*0x100000000)));
            atomicAdd(&forceBuffers[atom1+PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (force.y*0x100000000)));
            atomicAdd(&forceBuffers[atom1+2*PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (force.z*0x100000000)));
            unsigned int offset = atom1;
            STORE_DERIVATIVES_1
#ifdef USE_CUTOFF
            unsigned int atom2 = atomIndices[threadIdx.x];
#else
            unsigned int atom2 = y*TILE_SIZE + tgx;
#endif
            if (atom2 < PADDED_NUM_ATOMS) {
                atomicAdd(&forceBuffers[atom2], static_cast<unsigned long long>((long long) (localData[threadIdx.x].force.x*0x100000000)));
                atomicAdd(&forceBuffers[atom2+PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (localData[threadIdx.x].force.y*0x100000000)));
                atomicAdd(&forceBuffers[atom2+2*PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (localData[threadIdx.x].force.z*0x100000000)));
                offset = atom2;
                STORE_DERIVATIVES_2
            }
        }
        pos++;
    }
    energyBuffer[blockIdx.x*blockDim.x+threadIdx.x] += energy;
}