// customGBValueN2.cu
/**
 * Per-thread record held in shared memory while a tile is being processed.
 * Additional per-atom fields are injected by the ATOM_PARAMETER_DATA macro.
 */
typedef struct {
    // Value loaded from the posq array for this atom (one real4 per atom).
    real4 posq;
    // Running total of the contributions accumulated for this atom.
    real value;
    // Not referenced directly by the kernel code in this file.
    // NOTE(review): presumably scratch space for the generated macros - confirm.
    real temp;
    ATOM_PARAMETER_DATA
#ifdef NEED_PADDING
    // Extra word that is only compiled in when NEED_PADDING is defined.
    float padding;
#endif
} AtomData;

/**
 * Compute a value based on pair interactions.
 *
 * Work is divided into tiles of TILE_SIZE x TILE_SIZE atom pairs, and each group of TILE_SIZE
 * consecutive threads (a "warp" in this code) processes a contiguous range of tiles.  The kernel
 * makes two passes:
 *
 *   1. Tiles listed in exclusionTiles, for which per-pair exclusion flags are applied.
 *   2. All remaining tiles, taken either from the neighbor list (USE_CUTOFF, when the list did not
 *      overflow) or by enumerating every tile and skipping the ones handled in pass 1.
 *
 * Results are accumulated into global_value with atomicAdd as 64 bit fixed point numbers: each
 * contribution is multiplied by 0x100000000 (2^32) and converted to an integer before being added.
 *
 * NOTE(review): atom data is exchanged between the lanes of a warp through shared memory without
 * any explicit barrier, so the code depends on the TILE_SIZE lanes of a warp executing in lockstep.
 * Confirm this holds on the target architecture (GPUs with independent thread scheduling may
 * require __syncwarp()).
 *
 * @param posq                per-atom real4 values; the x, y, z components are used as the position
 * @param exclusions          exclusion flags; element pos*TILE_SIZE+tgx holds one bit per atom of the
 *                            other block in tile pos (a set bit means the pair is included)
 * @param exclusionTiles      (x, y) block indices of the tiles that have exclusions
 * @param global_value        per-atom fixed point accumulators that receive the results
 * @param tiles               (USE_CUTOFF) neighbor list entries; .x is the block index of the first
 *                            atoms and a nonzero .y selects the single periodic copy code path
 * @param interactionCount    (USE_CUTOFF) element 0 is the number of entries in tiles
 * @param periodicBoxSize     (USE_CUTOFF) periodic box dimensions
 * @param invPeriodicBoxSize  (USE_CUTOFF) reciprocals of the periodic box dimensions
 * @param maxTiles            (USE_CUTOFF) if interactionCount[0] exceeds this, the neighbor list is
 *                            ignored and all NUM_BLOCKS*(NUM_BLOCKS+1)/2 tiles are enumerated
 * @param blockCenter         (USE_CUTOFF) per-block center used to wrap atoms into one periodic copy
 * @param interactingAtoms    (USE_CUTOFF) for each neighbor list tile, the TILE_SIZE indices of the
 *                            second set of atoms
 * @param numTiles            (no cutoff) total number of tiles to enumerate in the second pass
 */
extern "C" __global__ void computeN2Value(const real4* __restrict__ posq, const unsigned int* __restrict__ exclusions,
        const ushort2* __restrict__ exclusionTiles, unsigned long long* __restrict__ global_value,
#ifdef USE_CUTOFF
        const ushort2* __restrict__ tiles, const unsigned int* __restrict__ interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize, unsigned int maxTiles, const real4* __restrict__ blockCenter, const unsigned int* __restrict__ interactingAtoms
#else
        unsigned int numTiles
#endif
        PARAMETER_ARGUMENTS) {
    const unsigned int totalWarps = (blockDim.x*gridDim.x)/TILE_SIZE;
    const unsigned int warp = (blockIdx.x*blockDim.x+threadIdx.x)/TILE_SIZE;
    const unsigned int tgx = threadIdx.x & (TILE_SIZE-1); // index of this thread within its warp
    const unsigned int tbx = threadIdx.x - tgx;           // index of the first thread of this warp within the block
    __shared__ AtomData localData[THREAD_BLOCK_SIZE];

    // First loop: process tiles that contain exclusions.

    // The products are formed in 64 bit arithmetic so that warp*count cannot wrap around when there
    // are many warps and many tiles.
    const unsigned int firstExclusionTile = FIRST_EXCLUSION_TILE+(unsigned int) (warp*(unsigned long long) (LAST_EXCLUSION_TILE-FIRST_EXCLUSION_TILE)/totalWarps);
    const unsigned int lastExclusionTile = FIRST_EXCLUSION_TILE+(unsigned int) ((warp+1)*(unsigned long long) (LAST_EXCLUSION_TILE-FIRST_EXCLUSION_TILE)/totalWarps);
    for (unsigned int pos = firstExclusionTile; pos < lastExclusionTile; pos++) {
        const ushort2 tileIndices = exclusionTiles[pos];
        const unsigned int x = tileIndices.x;
        const unsigned int y = tileIndices.y;
        real value = 0;
        unsigned int atom1 = x*TILE_SIZE + tgx;
        real4 posq1 = posq[atom1];
        LOAD_ATOM1_PARAMETERS
#ifdef USE_EXCLUSIONS
        unsigned int excl = exclusions[pos*TILE_SIZE+tgx];
#endif
        if (x == y) {
            // This tile is on the diagonal.

            const unsigned int localAtomIndex = threadIdx.x;
            localData[localAtomIndex].posq = posq1;
            LOAD_LOCAL_PARAMETERS_FROM_1
            for (unsigned int j = 0; j < TILE_SIZE; j++) {
                int atom2 = tbx+j;
                real4 posq2 = localData[atom2].posq;
                real3 delta = make_real3(posq2.x-posq1.x, posq2.y-posq1.y, posq2.z-posq1.z);
#ifdef USE_PERIODIC
                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                if (r2 < CUTOFF_SQUARED) {
#endif
                    real invR = RSQRT(r2);
                    real r = RECIP(invR);
                    LOAD_ATOM2_PARAMETERS
                    // From here on atom2 is the global atom index rather than the shared memory slot.
                    atom2 = y*TILE_SIZE+j;
                    real tempValue1 = 0;
                    real tempValue2 = 0;
#ifdef USE_EXCLUSIONS
                    bool isExcluded = (atom1 >= NUM_ATOMS || atom2 >= NUM_ATOMS || !(excl & 0x1));
                    if (!isExcluded && atom1 != atom2) {
#else
                    if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2) {
#endif
                        COMPUTE_VALUE
                    }
                    // Every ordered pair is visited on a diagonal tile, so only atom1's share is kept.
                    value += tempValue1;
#ifdef USE_CUTOFF
                }
#endif
#ifdef USE_EXCLUSIONS
                excl >>= 1;
#endif
            }
        }
        else {
            // This is an off-diagonal tile.

            const unsigned int localAtomIndex = threadIdx.x;
            unsigned int j = y*TILE_SIZE + tgx;
            localData[localAtomIndex].posq = posq[j];
            LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
            localData[localAtomIndex].value = 0;
#ifdef USE_EXCLUSIONS
            // Rotate the flags so that bit 0 corresponds to the first atom this thread visits (tj == tgx).
            excl = (excl >> tgx) | (excl << (TILE_SIZE - tgx));
#endif
            // Each thread starts at a different atom and wraps around, so that in any one iteration
            // the threads of a warp all address different shared memory entries.
            unsigned int tj = tgx;
            for (j = 0; j < TILE_SIZE; j++) {
                int atom2 = tbx+tj;
                real4 posq2 = localData[atom2].posq;
                real3 delta = make_real3(posq2.x-posq1.x, posq2.y-posq1.y, posq2.z-posq1.z);
#ifdef USE_PERIODIC
                delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                if (r2 < CUTOFF_SQUARED) {
#endif
                    real invR = RSQRT(r2);
                    real r = RECIP(invR);
                    LOAD_ATOM2_PARAMETERS
                    // From here on atom2 is the global atom index rather than the shared memory slot.
                    atom2 = y*TILE_SIZE+tj;
                    real tempValue1 = 0;
                    real tempValue2 = 0;
#ifdef USE_EXCLUSIONS
                    bool isExcluded = (atom1 >= NUM_ATOMS || atom2 >= NUM_ATOMS || !(excl & 0x1));
                    if (!isExcluded) {
#else
                    if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
#endif
                        COMPUTE_VALUE
                    }
                    value += tempValue1;
                    localData[tbx+tj].value += tempValue2;
#ifdef USE_CUTOFF
                }
#endif
#ifdef USE_EXCLUSIONS
                excl >>= 1;
#endif
                tj = (tj + 1) & (TILE_SIZE - 1);
            }
        }

        // Write results.

        unsigned int offset = x*TILE_SIZE + tgx;
        atomicAdd(&global_value[offset], static_cast<unsigned long long>((long long) (value*0x100000000)));
        if (x != y) {
            offset = y*TILE_SIZE + tgx;
            atomicAdd(&global_value[offset], static_cast<unsigned long long>((long long) (localData[threadIdx.x].value*0x100000000)));
        }
    }

    // Second loop: tiles without exclusions, either from the neighbor list (with cutoff) or just enumerating all
    // of them (no cutoff).

    // As above, the range of tiles assigned to this warp is computed with 64 bit intermediates so
    // that warp*tilesToProcess cannot wrap around.
#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
    const unsigned long long tilesToProcess = (numTiles > maxTiles ? NUM_BLOCKS*(NUM_BLOCKS+1)/2 : numTiles);
#else
    const unsigned long long tilesToProcess = numTiles;
#endif
    int pos = (int) (warp*tilesToProcess/totalWarps);
    int end = (int) ((warp+1)*tilesToProcess/totalWarps);
    int skipBase = 0;
    int currentSkipIndex = tbx;
    __shared__ int atomIndices[THREAD_BLOCK_SIZE];
    __shared__ int skipTiles[THREAD_BLOCK_SIZE];
    skipTiles[threadIdx.x] = -1;

    while (pos < end) {
        real value = 0;
        bool includeTile = true;

        // Extract the coordinates of this tile.

        unsigned int x, y;
        bool singlePeriodicCopy = false;
#ifdef USE_CUTOFF
        if (numTiles <= maxTiles) {
            // The neighbor list is valid.  Only x is needed here; the second set of atoms is read
            // from interactingAtoms below.
            ushort2 tileIndices = tiles[pos];
            x = tileIndices.x;
            singlePeriodicCopy = tileIndices.y;
        }
        else
#endif
        {
            // Convert the linear tile index into (x, y) block indices with x >= y.
            y = (unsigned int) floor(NUM_BLOCKS+0.5f-SQRT((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
            x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
            if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                y += (x < y ? -1 : 1);
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
            }

            // Skip over tiles that have exclusions, since they were already processed.

            // skipTiles caches the linear indices of the next TILE_SIZE exclusion tiles for this warp;
            // refill it until its last entry reaches the current position.
            while (skipTiles[tbx+TILE_SIZE-1] < pos) {
                if (skipBase+tgx < NUM_TILES_WITH_EXCLUSIONS) {
                    ushort2 tile = exclusionTiles[skipBase+tgx];
                    skipTiles[threadIdx.x] = tile.x + tile.y*NUM_BLOCKS - tile.y*(tile.y+1)/2;
                }
                else
                    skipTiles[threadIdx.x] = end; // sentinel: no more exclusion tiles in this warp's range
                skipBase += TILE_SIZE;
                currentSkipIndex = tbx;
            }
            while (skipTiles[currentSkipIndex] < pos)
                currentSkipIndex++;
            includeTile = (skipTiles[currentSkipIndex] != pos);
        }
        if (includeTile) {
            unsigned int atom1 = x*TILE_SIZE + tgx;

            // Load atom data for this tile.

            real4 posq1 = posq[atom1];
            LOAD_ATOM1_PARAMETERS
            const unsigned int localAtomIndex = threadIdx.x;
#ifdef USE_CUTOFF
            unsigned int j = (numTiles <= maxTiles ? interactingAtoms[pos*TILE_SIZE+tgx] : y*TILE_SIZE + tgx);
#else
            unsigned int j = y*TILE_SIZE + tgx;
#endif
            atomIndices[threadIdx.x] = j;
            if (j < PADDED_NUM_ATOMS) {
                localData[localAtomIndex].posq = posq[j];
                LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
                localData[localAtomIndex].value = 0;
            }
#ifdef USE_PERIODIC
            if (singlePeriodicCopy) {
                // The box is small enough that we can just translate all the atoms into a single periodic
                // box, then skip having to apply periodic boundary conditions later.

                real4 blockCenterX = blockCenter[x];
                posq1.x -= floor((posq1.x-blockCenterX.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                posq1.y -= floor((posq1.y-blockCenterX.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                posq1.z -= floor((posq1.z-blockCenterX.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
                localData[threadIdx.x].posq.x -= floor((localData[threadIdx.x].posq.x-blockCenterX.x)*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                localData[threadIdx.x].posq.y -= floor((localData[threadIdx.x].posq.y-blockCenterX.y)*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                localData[threadIdx.x].posq.z -= floor((localData[threadIdx.x].posq.z-blockCenterX.z)*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
                unsigned int tj = tgx;
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    int atom2 = tbx+tj;
                    real4 posq2 = localData[atom2].posq;
                    real3 delta = make_real3(posq2.x-posq1.x, posq2.y-posq1.y, posq2.z-posq1.z);
                    real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                    if (r2 < CUTOFF_SQUARED) {
                        real invR = RSQRT(r2);
                        real r = RECIP(invR);
                        LOAD_ATOM2_PARAMETERS
                        // From here on atom2 is the global atom index rather than the shared memory slot.
                        atom2 = atomIndices[tbx+tj];
                        real tempValue1 = 0;
                        real tempValue2 = 0;
                        if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
                            COMPUTE_VALUE
                        }
                        value += tempValue1;
                        localData[tbx+tj].value += tempValue2;
                    }
                    tj = (tj + 1) & (TILE_SIZE - 1);
                }
            }
            else
#endif
            {
                // We need to apply periodic boundary conditions separately for each interaction.

                unsigned int tj = tgx;
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    int atom2 = tbx+tj;
                    real4 posq2 = localData[atom2].posq;
                    real3 delta = make_real3(posq2.x-posq1.x, posq2.y-posq1.y, posq2.z-posq1.z);
#ifdef USE_PERIODIC
                    delta.x -= floor(delta.x*invPeriodicBoxSize.x+0.5f)*periodicBoxSize.x;
                    delta.y -= floor(delta.y*invPeriodicBoxSize.y+0.5f)*periodicBoxSize.y;
                    delta.z -= floor(delta.z*invPeriodicBoxSize.z+0.5f)*periodicBoxSize.z;
#endif
                    real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                    if (r2 < CUTOFF_SQUARED) {
#endif
                        real invR = RSQRT(r2);
                        real r = RECIP(invR);
                        LOAD_ATOM2_PARAMETERS
                        // From here on atom2 is the global atom index rather than the shared memory slot.
                        atom2 = atomIndices[tbx+tj];
                        real tempValue1 = 0;
                        real tempValue2 = 0;
                        if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
                            COMPUTE_VALUE
                        }
                        value += tempValue1;
                        localData[tbx+tj].value += tempValue2;
#ifdef USE_CUTOFF
                    }
#endif
                    tj = (tj + 1) & (TILE_SIZE - 1);
                }
            }

            // Write results.

            atomicAdd(&global_value[atom1], static_cast<unsigned long long>((long long) (value*0x100000000)));
#ifdef USE_CUTOFF
            unsigned int atom2 = atomIndices[threadIdx.x];
#else
            unsigned int atom2 = y*TILE_SIZE + tgx;
#endif
            if (atom2 < PADDED_NUM_ATOMS)
                atomicAdd(&global_value[atom2], static_cast<unsigned long long>((long long) (localData[threadIdx.x].value*0x100000000)));
        }
        pos++;
    }
}