// customGBValueN2.cc
/**
 * Compute a value based on pair interactions.
 */
4
5
KERNEL void computeN2Value(GLOBAL const real4* RESTRICT posq, GLOBAL const unsigned int* RESTRICT exclusions,
        GLOBAL const ushort2* exclusionTiles,
6
#ifdef SUPPORTS_64_BIT_ATOMICS
7
        GLOBAL mm_ulong* RESTRICT global_value,
8
#else
9
        GLOBAL real* RESTRICT global_value,
10
11
#endif
#ifdef USE_CUTOFF
12
13
14
        GLOBAL const int* RESTRICT tiles, GLOBAL const unsigned int* RESTRICT interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize,
        real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ, unsigned int maxTiles, GLOBAL const real4* RESTRICT blockCenter,
        GLOBAL const real4* RESTRICT blockSize, GLOBAL const int* RESTRICT interactingAtoms
15
16
17
18
#else
        unsigned int numTiles
#endif
        PARAMETER_ARGUMENTS) {
19
20
21
22
23
24
25
    const unsigned int totalWarps = GLOBAL_SIZE/TILE_SIZE;
    const unsigned int warp = GLOBAL_ID/TILE_SIZE;
    const unsigned int tgx = LOCAL_ID & (TILE_SIZE-1);
    const unsigned int tbx = LOCAL_ID - tgx;
    LOCAL real3 local_pos[LOCAL_BUFFER_SIZE];
    LOCAL real local_value[LOCAL_BUFFER_SIZE];
    ATOM_PARAMETER_DATA
26
27
28

    // First loop: process tiles that contain exclusions.
    
29
30
    const int firstExclusionTile = FIRST_EXCLUSION_TILE+warp*(LAST_EXCLUSION_TILE-FIRST_EXCLUSION_TILE)/totalWarps;
    const int lastExclusionTile = FIRST_EXCLUSION_TILE+(warp+1)*(LAST_EXCLUSION_TILE-FIRST_EXCLUSION_TILE)/totalWarps;
31
32
33
34
35
36
    for (int pos = firstExclusionTile; pos < lastExclusionTile; pos++) {
        const ushort2 tileIndices = exclusionTiles[pos];
        const unsigned int x = tileIndices.x;
        const unsigned int y = tileIndices.y;
        real value = 0;
        unsigned int atom1 = x*TILE_SIZE + tgx;
37
        real3 pos1 = trimTo3(posq[atom1]);
38
39
40
41
42
43
44
        LOAD_ATOM1_PARAMETERS
#ifdef USE_EXCLUSIONS
        unsigned int excl = exclusions[pos*TILE_SIZE+tgx];
#endif
        if (x == y) {
            // This tile is on the diagonal.

45
46
            const unsigned int localAtomIndex = LOCAL_ID;
            local_pos[localAtomIndex] = pos1;
47
48
49
50
            LOAD_LOCAL_PARAMETERS_FROM_1
            SYNC_WARPS;
            for (unsigned int j = 0; j < TILE_SIZE; j++) {
                int atom2 = tbx+j;
51
52
                real3 pos2 = local_pos[atom2];
                real3 delta = make_real3(pos2.x-pos1.x, pos2.y-pos1.y, pos2.z-pos1.z);
53
#ifdef USE_PERIODIC
54
                APPLY_PERIODIC_TO_DELTA(delta)
55
56
57
58
59
60
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                if (r2 < CUTOFF_SQUARED) {
#endif
                    real invR = RSQRT(r2);
peastman's avatar
peastman committed
61
                    real r = r2*invR;
62
63
64
65
66
67
68
69
70
71
72
73
74
                    LOAD_ATOM2_PARAMETERS
                    atom2 = y*TILE_SIZE+j;
                    real tempValue1 = 0;
                    real tempValue2 = 0;
#ifdef USE_EXCLUSIONS
                    bool isExcluded = (atom1 >= NUM_ATOMS || atom2 >= NUM_ATOMS || !(excl & 0x1));
                    if (!isExcluded && atom1 != atom2) {
#else
                    if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2) {
#endif
                        COMPUTE_VALUE
                    }
                    value += tempValue1;
75
                    ADD_TEMP_DERIVS1
76
77
78
79
80
81
82
83
84
85
86
87
#ifdef USE_CUTOFF
                }
#endif
#ifdef USE_EXCLUSIONS
                excl >>= 1;
#endif
                SYNC_WARPS;
            }
        }
        else {
            // This is an off-diagonal tile.

88
            const unsigned int localAtomIndex = LOCAL_ID;
89
            unsigned int j = y*TILE_SIZE + tgx;
90
            local_pos[localAtomIndex] = trimTo3(posq[j]);
91
92
93
94
95
96
97
98
99
            LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
            local_value[localAtomIndex] = 0;
            SYNC_WARPS;
#ifdef USE_EXCLUSIONS
            excl = (excl >> tgx) | (excl << (TILE_SIZE - tgx));
#endif
            unsigned int tj = tgx;
            for (j = 0; j < TILE_SIZE; j++) {
                int atom2 = tbx+tj;
100
101
                real3 pos2 = local_pos[atom2];
                real3 delta = make_real3(pos2.x-pos1.x, pos2.y-pos1.y, pos2.z-pos1.z);
102
#ifdef USE_PERIODIC
103
                APPLY_PERIODIC_TO_DELTA(delta)
104
105
106
107
108
109
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                if (r2 < CUTOFF_SQUARED) {
#endif
                    real invR = RSQRT(r2);
peastman's avatar
peastman committed
110
                    real r = r2*invR;
111
112
113
114
115
116
117
118
119
120
121
122
123
124
                    LOAD_ATOM2_PARAMETERS
                    atom2 = y*TILE_SIZE+tj;
                    real tempValue1 = 0;
                    real tempValue2 = 0;
#ifdef USE_EXCLUSIONS
                    bool isExcluded = (atom1 >= NUM_ATOMS || atom2 >= NUM_ATOMS || !(excl & 0x1));
                    if (!isExcluded) {
#else
                    if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
#endif
                        COMPUTE_VALUE
                    }
                    value += tempValue1;
                    local_value[tbx+tj] += tempValue2;
125
126
                    ADD_TEMP_DERIVS1
                    ADD_TEMP_DERIVS2
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#ifdef USE_CUTOFF
                }
#endif
#ifdef USE_EXCLUSIONS
                excl >>= 1;
#endif
                tj = (tj + 1) & (TILE_SIZE - 1);
                SYNC_WARPS;
            }
        }

        // Write results.

#ifdef SUPPORTS_64_BIT_ATOMICS
141
        unsigned int offset1 = x*TILE_SIZE + tgx;
142
        ATOMIC_ADD(&global_value[offset1], (mm_ulong) ((mm_long) (value*0x100000000)));
143
        STORE_PARAM_DERIVS1
144
        if (x != y) {
145
            unsigned int offset2 = y*TILE_SIZE + tgx;
146
            ATOMIC_ADD(&global_value[offset2], (mm_ulong) ((mm_long) (local_value[LOCAL_ID]*0x100000000)));
147
            STORE_PARAM_DERIVS2
148
149
150
151
152
        }
#else
        unsigned int offset1 = x*TILE_SIZE + tgx + warp*PADDED_NUM_ATOMS;
        unsigned int offset2 = y*TILE_SIZE + tgx + warp*PADDED_NUM_ATOMS;
        global_value[offset1] += value;
153
154
        STORE_PARAM_DERIVS1
        if (x != y) {
155
            global_value[offset2] += local_value[LOCAL_ID];
156
157
            STORE_PARAM_DERIVS2
        }
158
159
160
161
162
163
164
165
#endif
    }

    // Second loop: tiles without exclusions, either from the neighbor list (with cutoff) or just enumerating all
    // of them (no cutoff).

#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
166
167
    if (numTiles > maxTiles)
        return; // There wasn't enough memory for the neighbor list.
168
169
    int pos = (int) (warp*(numTiles > maxTiles ? NUM_BLOCKS*((mm_long)NUM_BLOCKS+1)/2 : (mm_long)numTiles)/totalWarps);
    int end = (int) ((warp+1)*(numTiles > maxTiles ? NUM_BLOCKS*((mm_long)NUM_BLOCKS+1)/2 : (mm_long)numTiles)/totalWarps);
170
#else
171
172
    int pos = (int) (warp*(mm_long)numTiles/totalWarps);
    int end = (int) ((warp+1)*(mm_long)numTiles/totalWarps);
173
174
175
#endif
    int skipBase = 0;
    int currentSkipIndex = tbx;
176
177
178
    LOCAL int atomIndices[LOCAL_BUFFER_SIZE];
    LOCAL volatile int skipTiles[LOCAL_BUFFER_SIZE];
    skipTiles[LOCAL_ID] = -1;
179
180
181
182
183
184
185

    while (pos < end) {
        real value = 0;
        bool includeTile = true;

        // Extract the coordinates of this tile.
        
186
        int x, y;
187
188
        bool singlePeriodicCopy = false;
#ifdef USE_CUTOFF
189
190
191
192
193
194
195
196
197
198
        x = tiles[pos];
        real4 blockSizeX = blockSize[x];
        singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= CUTOFF &&
                              0.5f*periodicBoxSize.y-blockSizeX.y >= CUTOFF &&
                              0.5f*periodicBoxSize.z-blockSizeX.z >= CUTOFF);
#else
        y = (int) floor(NUM_BLOCKS+0.5f-SQRT((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
        x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
        if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
            y += (x < y ? -1 : 1);
199
            x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
200
        }
201

202
        // Skip over tiles that have exclusions, since they were already processed.
203

204
205
        SYNC_WARPS;
        while (skipTiles[tbx+TILE_SIZE-1] < pos) {
206
            SYNC_WARPS;
207
208
            if (skipBase+tgx < NUM_TILES_WITH_EXCLUSIONS) {
                ushort2 tile = exclusionTiles[skipBase+tgx];
209
                skipTiles[LOCAL_ID] = tile.x + tile.y*NUM_BLOCKS - tile.y*(tile.y+1)/2;
210
            }
211
            else
212
                skipTiles[LOCAL_ID] = end;
213
214
215
            skipBase += TILE_SIZE;            
            currentSkipIndex = tbx;
            SYNC_WARPS;
216
        }
217
218
219
220
        while (skipTiles[currentSkipIndex] < pos)
            currentSkipIndex++;
        includeTile = (skipTiles[currentSkipIndex] != pos);
#endif
221
222
223
224
225
        if (includeTile) {
            unsigned int atom1 = x*TILE_SIZE + tgx;

            // Load atom data for this tile.
            
226
            real3 pos1 = trimTo3(posq[atom1]);
227
            LOAD_ATOM1_PARAMETERS
228
            const unsigned int localAtomIndex = LOCAL_ID;
229
#ifdef USE_CUTOFF
peastman's avatar
peastman committed
230
            unsigned int j = interactingAtoms[pos*TILE_SIZE+tgx];
231
232
233
#else
            unsigned int j = y*TILE_SIZE + tgx;
#endif
234
            atomIndices[LOCAL_ID] = j;
235
            if (j < PADDED_NUM_ATOMS) {
236
                local_pos[localAtomIndex] = trimTo3(posq[j]);
237
238
239
240
241
242
243
244
245
246
                LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
                local_value[localAtomIndex] = 0;
            }
            SYNC_WARPS;
#ifdef USE_PERIODIC
            if (singlePeriodicCopy) {
                // The box is small enough that we can just translate all the atoms into a single periodic
                // box, then skip having to apply periodic boundary conditions later.

                real4 blockCenterX = blockCenter[x];
247
248
                APPLY_PERIODIC_TO_POS_WITH_CENTER(pos1, blockCenterX)
                APPLY_PERIODIC_TO_POS_WITH_CENTER(local_pos[LOCAL_ID], blockCenterX)
249
250
251
252
                SYNC_WARPS;
                unsigned int tj = tgx;
                for (j = 0; j < TILE_SIZE; j++) {
                    int atom2 = tbx+tj;
253
254
                    real3 pos2 = local_pos[atom2];
                    real3 delta = make_real3(pos2.x-pos1.x, pos2.y-pos1.y, pos2.z-pos1.z);
255
256
257
                    real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                    if (r2 < CUTOFF_SQUARED) {
                        real invR = RSQRT(r2);
peastman's avatar
peastman committed
258
                        real r = r2*invR;
259
260
261
262
263
264
265
266
267
                        LOAD_ATOM2_PARAMETERS
                        atom2 = atomIndices[tbx+tj];
                        real tempValue1 = 0;
                        real tempValue2 = 0;
                        if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
                            COMPUTE_VALUE
                        }
                        value += tempValue1;
                        local_value[tbx+tj] += tempValue2;
268
269
                        ADD_TEMP_DERIVS1
                        ADD_TEMP_DERIVS2
270
271
272
273
274
275
276
277
278
                    }
                    tj = (tj + 1) & (TILE_SIZE - 1);
                    SYNC_WARPS;
                }
            }
            else
#endif
            {
                // We need to apply periodic boundary conditions separately for each interaction.
279

280
281
282
                unsigned int tj = tgx;
                for (j = 0; j < TILE_SIZE; j++) {
                    int atom2 = tbx+tj;
283
284
                    real3 pos2 = local_pos[atom2];
                    real3 delta = make_real3(pos2.x-pos1.x, pos2.y-pos1.y, pos2.z-pos1.z);
285
#ifdef USE_PERIODIC
286
                    APPLY_PERIODIC_TO_DELTA(delta)
287
288
289
290
291
292
#endif
                    real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                    if (r2 < CUTOFF_SQUARED) {
#endif
                        real invR = RSQRT(r2);
peastman's avatar
peastman committed
293
                        real r = r2*invR;
294
295
296
297
298
299
300
301
302
                        LOAD_ATOM2_PARAMETERS
                        atom2 = atomIndices[tbx+tj];
                        real tempValue1 = 0;
                        real tempValue2 = 0;
                        if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
                            COMPUTE_VALUE
                        }
                        value += tempValue1;
                        local_value[tbx+tj] += tempValue2;
303
304
                        ADD_TEMP_DERIVS1
                        ADD_TEMP_DERIVS2
305
306
307
308
309
310
311
312
313
#ifdef USE_CUTOFF
                    }
#endif
                    tj = (tj + 1) & (TILE_SIZE - 1);
                    SYNC_WARPS;
                }
            }
        
            // Write results.
314

315
#ifdef USE_CUTOFF
316
            unsigned int atom2 = atomIndices[LOCAL_ID];
317
318
319
320
#else
            unsigned int atom2 = y*TILE_SIZE + tgx;
#endif
#ifdef SUPPORTS_64_BIT_ATOMICS
321
            unsigned int offset1 = atom1;
322
            ATOMIC_ADD(&global_value[offset1], (mm_ulong) ((mm_long) (value*0x100000000)));
323
324
325
            STORE_PARAM_DERIVS1
            if (atom2 < PADDED_NUM_ATOMS) {
                unsigned int offset2 = atom2;
326
                ATOMIC_ADD(&global_value[offset2], (mm_ulong) ((mm_long) (local_value[LOCAL_ID]*0x100000000)));
327
328
                STORE_PARAM_DERIVS2
            }
329
330
331
#else
            unsigned int offset1 = atom1 + warp*PADDED_NUM_ATOMS;
            global_value[offset1] += value;
332
333
334
            STORE_PARAM_DERIVS1
            if (atom2 < PADDED_NUM_ATOMS) {
                unsigned int offset2 = atom2 + warp*PADDED_NUM_ATOMS;
335
                global_value[offset2] += local_value[LOCAL_ID];
336
337
                STORE_PARAM_DERIVS2
            }
338
339
340
341
342
#endif
        }
        pos++;
    }
}