// customGBValueN2.cu
/**
 * Per-thread scratch record held in shared memory while a tile is processed.
 *
 * One instance exists per thread of the block (see the localData array in
 * computeN2Value).  It caches the data of the "second" atom of a pair so the
 * other threads of the same tile can read it without going to global memory.
 */
typedef struct {
    real3 pos;      // Cartesian position of the cached atom.
    real value;     // Running sum of the contributions (tempValue2) credited to the cached atom.
    ATOM_PARAMETER_DATA   // Extra per-atom fields; the macro is expanded by the host before compilation.
#ifdef NEED_PADDING
    float padding;  // Only present when the host requests padding of the record.
#endif
} AtomData;

// Warp-level barrier used wherever one thread reads shared memory that another thread of
// the same tile has written.  Without it the code relies on implicit warp lockstep, which
// is not guaranteed under independent thread scheduling.  If the build already supplies
// SYNC_WARPS, that definition is used unchanged.
#ifndef SYNC_WARPS
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 9)
#define SYNC_WARPS __syncwarp()
#else
#define SYNC_WARPS
#endif
#endif

/**
 * Compute a value based on pair interactions.
 *
 * Atoms are grouped into blocks of TILE_SIZE; a "tile" is a pair of blocks (x, y).  Each
 * group of TILE_SIZE consecutive threads processes one tile at a time: every thread owns
 * one atom of block x (atom1) and loops over the TILE_SIZE atoms of block y, which are
 * staged in shared memory.  The work is done in two passes:
 *
 *   1. tiles listed in exclusionTiles, for which per-pair exclusion bits are honoured;
 *   2. all remaining tiles, taken from the neighbor list when USE_CUTOFF is defined and
 *      the list did not overflow (numTiles <= maxTiles), otherwise enumerated from the
 *      upper triangle of the block matrix, skipping those already handled in pass 1.
 *
 * Results are accumulated into global_value as 64-bit fixed point (scaled by 2^32) with
 * atomicAdd.
 *
 * NOTE(review): the SYNC_WARPS barriers assume TILE_SIZE equals the hardware warp size,
 * so that all control flow around them is uniform across the warp - confirm for any
 * configuration that uses a smaller tile.
 *
 * The identifiers atom1, atom2, pos1, j, localAtomIndex, tempPosq, r, r2, invR, delta,
 * tempValue1 and tempValue2 are presumably referenced by the host-substituted macros
 * (LOAD_*, COMPUTE_VALUE, APPLY_PERIODIC_*), so they must keep these names.
 *
 * @param posq             per-atom positions (x, y, z are read)
 * @param exclusions       one bit mask per thread per exclusion tile; a cleared bit marks an excluded pair
 * @param exclusionTiles   (x, y) block indices of the tiles that carry exclusion data
 * @param global_value     per-atom fixed-point accumulators that receive the result
 * @param tiles            (cutoff only) block index x of every neighbor-list tile
 * @param interactionCount (cutoff only) element 0 holds the number of neighbor-list tiles
 * @param periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX/Y/Z
 *                         (cutoff only) periodic box description; used here and presumably by the APPLY_PERIODIC_* macros
 * @param maxTiles         (cutoff only) capacity of the neighbor list; a larger count selects the full enumeration
 * @param blockCenter      (cutoff only) center of each atom block
 * @param blockSize        (cutoff only) extent of each atom block
 * @param interactingAtoms (cutoff only) for each neighbor-list tile, the TILE_SIZE atom indices to pair with block x
 * @param numTiles         (no cutoff) total number of tiles to enumerate in the second pass
 */
extern "C" __global__ void computeN2Value(const real4* __restrict__ posq, const unsigned int* __restrict__ exclusions,
        const ushort2* __restrict__ exclusionTiles, unsigned long long* __restrict__ global_value,
#ifdef USE_CUTOFF
        const int* __restrict__ tiles, const unsigned int* __restrict__ interactionCount, real4 periodicBoxSize, real4 invPeriodicBoxSize,
        real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ, unsigned int maxTiles, const real4* __restrict__ blockCenter,
        const real4* __restrict__ blockSize, const unsigned int* __restrict__ interactingAtoms
#else
        unsigned int numTiles
#endif
        PARAMETER_ARGUMENTS) {
    const unsigned int totalWarps = (blockDim.x*gridDim.x)/TILE_SIZE;
    const unsigned int warp = (blockIdx.x*blockDim.x+threadIdx.x)/TILE_SIZE;
    const unsigned int tgx = threadIdx.x & (TILE_SIZE-1);   // index of this thread within its tile
    const unsigned int tbx = threadIdx.x - tgx;             // index of the first thread of this tile within the block
    __shared__ AtomData localData[THREAD_BLOCK_SIZE];

    // First loop: process tiles that contain exclusions.

    const unsigned int firstExclusionTile = FIRST_EXCLUSION_TILE+warp*(LAST_EXCLUSION_TILE-FIRST_EXCLUSION_TILE)/totalWarps;
    const unsigned int lastExclusionTile = FIRST_EXCLUSION_TILE+(warp+1)*(LAST_EXCLUSION_TILE-FIRST_EXCLUSION_TILE)/totalWarps;
    for (int pos = firstExclusionTile; pos < lastExclusionTile; pos++) {
        const ushort2 tileIndices = exclusionTiles[pos];
        const unsigned int x = tileIndices.x;
        const unsigned int y = tileIndices.y;
        real value = 0;
        unsigned int atom1 = x*TILE_SIZE + tgx;
        real4 pos1 = posq[atom1];
        LOAD_ATOM1_PARAMETERS
#ifdef USE_EXCLUSIONS
        unsigned int excl = exclusions[pos*TILE_SIZE+tgx];
#endif
        if (x == y) {
            // This tile is on the diagonal.

            const unsigned int localAtomIndex = threadIdx.x;
            localData[localAtomIndex].pos = make_real3(pos1.x, pos1.y, pos1.z);
            LOAD_LOCAL_PARAMETERS_FROM_1
            SYNC_WARPS;   // make every lane's cached atom visible before anyone reads it
            for (unsigned int j = 0; j < TILE_SIZE; j++) {
                int atom2 = tbx+j;
                real3 pos2 = localData[atom2].pos;
                real3 delta = make_real3(pos2.x-pos1.x, pos2.y-pos1.y, pos2.z-pos1.z);
#ifdef USE_PERIODIC
                APPLY_PERIODIC_TO_DELTA(delta)
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                if (r2 < CUTOFF_SQUARED) {
#endif
                    real invR = RSQRT(r2);
                    real r = r2*invR;
                    LOAD_ATOM2_PARAMETERS
                    atom2 = y*TILE_SIZE+j;   // switch from shared-memory slot to global atom index
                    real tempValue1 = 0;
                    real tempValue2 = 0;
#ifdef USE_EXCLUSIONS
                    bool isExcluded = (atom1 >= NUM_ATOMS || atom2 >= NUM_ATOMS || !(excl & 0x1));
                    if (!isExcluded && atom1 != atom2) {
#else
                    if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS && atom1 != atom2) {
#endif
                        COMPUTE_VALUE
                    }
                    // Each pair of a diagonal tile is visited from both sides, so only
                    // the contribution to atom1 is kept here.
                    value += tempValue1;
#ifdef USE_CUTOFF
                }
#endif
#ifdef USE_EXCLUSIONS
                excl >>= 1;
#endif
            }
            SYNC_WARPS;   // all reads of localData are done before the next tile overwrites it
        }
        else {
            // This is an off-diagonal tile.

            const unsigned int localAtomIndex = threadIdx.x;
            unsigned int j = y*TILE_SIZE + tgx;
            real4 tempPosq = posq[j];
            localData[localAtomIndex].pos = make_real3(tempPosq.x, tempPosq.y, tempPosq.z);
            LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
            localData[localAtomIndex].value = 0;
#ifdef USE_EXCLUSIONS
            // Rotate the mask so that bit 0 matches the first atom this thread visits (tj == tgx).
            excl = (excl >> tgx) | (excl << (TILE_SIZE - tgx));
#endif
            SYNC_WARPS;   // make every lane's cached atom visible before anyone reads it
            unsigned int tj = tgx;
            for (j = 0; j < TILE_SIZE; j++) {
                int atom2 = tbx+tj;
                real3 pos2 = localData[atom2].pos;
                real3 delta = make_real3(pos2.x-pos1.x, pos2.y-pos1.y, pos2.z-pos1.z);
#ifdef USE_PERIODIC
                APPLY_PERIODIC_TO_DELTA(delta)
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                if (r2 < CUTOFF_SQUARED) {
#endif
                    real invR = RSQRT(r2);
                    real r = r2*invR;
                    LOAD_ATOM2_PARAMETERS
                    atom2 = y*TILE_SIZE+tj;   // switch from shared-memory slot to global atom index
                    real tempValue1 = 0;
                    real tempValue2 = 0;
#ifdef USE_EXCLUSIONS
                    bool isExcluded = (atom1 >= NUM_ATOMS || atom2 >= NUM_ATOMS || !(excl & 0x1));
                    if (!isExcluded) {
#else
                    if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
#endif
                        COMPUTE_VALUE
                    }
                    value += tempValue1;
                    localData[tbx+tj].value += tempValue2;
#ifdef USE_CUTOFF
                }
#endif
#ifdef USE_EXCLUSIONS
                excl >>= 1;
#endif
                // Every lane visits a different slot in each iteration; advance in lockstep.
                tj = (tj + 1) & (TILE_SIZE - 1);
                SYNC_WARPS;   // the slot just updated is touched by another lane next iteration
            }
        }

        // Write results.

        unsigned int offset = x*TILE_SIZE + tgx;
        atomicAdd(&global_value[offset], static_cast<unsigned long long>((long long) (value*0x100000000)));
        if (x != y) {
            offset = y*TILE_SIZE + tgx;
            atomicAdd(&global_value[offset], static_cast<unsigned long long>((long long) (localData[threadIdx.x].value*0x100000000)));
        }
    }

    // Second loop: tiles without exclusions, either from the neighbor list (with cutoff) or just enumerating all
    // of them (no cutoff).

#ifdef USE_CUTOFF
    unsigned int numTiles = interactionCount[0];
    // If the neighbor list overflowed, fall back to enumerating the full upper triangle of blocks.
    int pos = (int) (warp*(numTiles > maxTiles ? NUM_BLOCKS*((long long)NUM_BLOCKS+1)/2 : (long)numTiles)/totalWarps);
    int end = (int) ((warp+1)*(numTiles > maxTiles ? NUM_BLOCKS*((long long)NUM_BLOCKS+1)/2 : (long)numTiles)/totalWarps);
#else
    int pos = (int) (warp*(long long)numTiles/totalWarps);
    int end = (int) ((warp+1)*(long long)numTiles/totalWarps);
#endif
    int skipBase = 0;
    int currentSkipIndex = tbx;
    __shared__ int atomIndices[THREAD_BLOCK_SIZE];
    __shared__ volatile int skipTiles[THREAD_BLOCK_SIZE];
    skipTiles[threadIdx.x] = -1;
    SYNC_WARPS;   // the initial -1 must be visible before any lane tests the last slot

    while (pos < end) {
        real value = 0;
        bool includeTile = true;

        // Extract the coordinates of this tile.

        int x, y;
        bool singlePeriodicCopy = false;
#ifdef USE_CUTOFF
        if (numTiles <= maxTiles) {
            // Neighbor-list tile: only x is needed, the partner atoms come from interactingAtoms.
            x = tiles[pos];
            real4 blockSizeX = blockSize[x];
            singlePeriodicCopy = (0.5f*periodicBoxSize.x-blockSizeX.x >= CUTOFF &&
                                  0.5f*periodicBoxSize.y-blockSizeX.y >= CUTOFF &&
                                  0.5f*periodicBoxSize.z-blockSizeX.z >= CUTOFF);
        }
        else
#endif
        {
            // Invert the linear upper-triangle index pos into block coordinates (x, y).
            y = (int) floor(NUM_BLOCKS+0.5f-SQRT((NUM_BLOCKS+0.5f)*(NUM_BLOCKS+0.5f)-2*pos));
            x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
            if (x < y || x >= NUM_BLOCKS) { // Occasionally happens due to roundoff error.
                y += (x < y ? -1 : 1);
                x = (pos-y*NUM_BLOCKS+y*(y+1)/2);
            }

            // Skip over tiles that have exclusions, since they were already processed.

            // skipTiles caches the linear indices of the next TILE_SIZE exclusion tiles;
            // refill it until its last entry reaches the current position.
            while (skipTiles[tbx+TILE_SIZE-1] < pos) {
                SYNC_WARPS;   // every lane has evaluated the loop condition before the cache is overwritten
                if (skipBase+tgx < NUM_TILES_WITH_EXCLUSIONS) {
                    ushort2 tile = exclusionTiles[skipBase+tgx];
                    skipTiles[threadIdx.x] = tile.x + tile.y*NUM_BLOCKS - tile.y*(tile.y+1)/2;
                }
                else
                    skipTiles[threadIdx.x] = end;   // sentinel: no more exclusion tiles
                skipBase += TILE_SIZE;
                currentSkipIndex = tbx;
                SYNC_WARPS;   // refilled cache is visible before it is read again
            }
            while (skipTiles[currentSkipIndex] < pos)
                currentSkipIndex++;
            includeTile = (skipTiles[currentSkipIndex] != pos);
        }
        if (includeTile) {
            unsigned int atom1 = x*TILE_SIZE + tgx;

            // Load atom data for this tile.

            real4 pos1 = posq[atom1];
            LOAD_ATOM1_PARAMETERS
            const unsigned int localAtomIndex = threadIdx.x;
#ifdef USE_CUTOFF
            unsigned int j = (numTiles <= maxTiles ? interactingAtoms[pos*TILE_SIZE+tgx] : y*TILE_SIZE + tgx);
#else
            unsigned int j = y*TILE_SIZE + tgx;
#endif
            atomIndices[threadIdx.x] = j;
            if (j < PADDED_NUM_ATOMS) {
                real4 tempPosq = posq[j];
                localData[localAtomIndex].pos = make_real3(tempPosq.x, tempPosq.y, tempPosq.z);
                LOAD_LOCAL_PARAMETERS_FROM_GLOBAL
                localData[localAtomIndex].value = 0;
            }
            SYNC_WARPS;   // atomIndices and localData are complete before other lanes read them
#ifdef USE_PERIODIC
            if (singlePeriodicCopy) {
                // The box is small enough that we can just translate all the atoms into a single periodic
                // box, then skip having to apply periodic boundary conditions later.

                real4 blockCenterX = blockCenter[x];
                APPLY_PERIODIC_TO_POS_WITH_CENTER(pos1, blockCenterX)
                APPLY_PERIODIC_TO_POS_WITH_CENTER(localData[threadIdx.x].pos, blockCenterX)
                SYNC_WARPS;   // translated positions are visible before other lanes read them
                unsigned int tj = tgx;
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    int atom2 = tbx+tj;
                    real3 pos2 = localData[atom2].pos;
                    real3 delta = make_real3(pos2.x-pos1.x, pos2.y-pos1.y, pos2.z-pos1.z);
                    real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                    if (r2 < CUTOFF_SQUARED) {
                        real invR = RSQRT(r2);
                        real r = r2*invR;
                        LOAD_ATOM2_PARAMETERS
                        atom2 = atomIndices[tbx+tj];   // switch from shared-memory slot to global atom index
                        real tempValue1 = 0;
                        real tempValue2 = 0;
                        if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
                            COMPUTE_VALUE
                        }
                        value += tempValue1;
                        localData[tbx+tj].value += tempValue2;
                    }
                    tj = (tj + 1) & (TILE_SIZE - 1);
                    SYNC_WARPS;   // the slot just updated is touched by another lane next iteration
                }
            }
            else
#endif
            {
                // We need to apply periodic boundary conditions separately for each interaction.

                unsigned int tj = tgx;
                for (unsigned int j = 0; j < TILE_SIZE; j++) {
                    int atom2 = tbx+tj;
                    real3 pos2 = localData[atom2].pos;
                    real3 delta = make_real3(pos2.x-pos1.x, pos2.y-pos1.y, pos2.z-pos1.z);
#ifdef USE_PERIODIC
                    APPLY_PERIODIC_TO_DELTA(delta)
#endif
                    real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
#ifdef USE_CUTOFF
                    if (r2 < CUTOFF_SQUARED) {
#endif
                        real invR = RSQRT(r2);
                        real r = r2*invR;
                        LOAD_ATOM2_PARAMETERS
                        atom2 = atomIndices[tbx+tj];   // switch from shared-memory slot to global atom index
                        real tempValue1 = 0;
                        real tempValue2 = 0;
                        if (atom1 < NUM_ATOMS && atom2 < NUM_ATOMS) {
                            COMPUTE_VALUE
                        }
                        value += tempValue1;
                        localData[tbx+tj].value += tempValue2;
#ifdef USE_CUTOFF
                    }
#endif
                    tj = (tj + 1) & (TILE_SIZE - 1);
                    SYNC_WARPS;   // the slot just updated is touched by another lane next iteration
                }
            }

            // Write results.

            atomicAdd(&global_value[atom1], static_cast<unsigned long long>((long long) (value*0x100000000)));
#ifdef USE_CUTOFF
            unsigned int atom2 = atomIndices[threadIdx.x];
#else
            unsigned int atom2 = y*TILE_SIZE + tgx;
#endif
            if (atom2 < PADDED_NUM_ATOMS)
                atomicAdd(&global_value[atom2], static_cast<unsigned long long>((long long) (localData[threadIdx.x].value*0x100000000)));
        }
        pos++;
    }
}