customNonbondedGroups.cc 9.73 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
/**
 * Per-thread record staged in local memory while a tile is processed: the
 * coordinates and fourth posq component of the thread's atom2, the force
 * accumulated on it, and any custom per-atom parameters.
 */
typedef struct {
    // Coordinates copied from posq.
    real x;
    real y;
    real z;
    // Fourth (w) component of posq.
    real q;
    // Force accumulated on this atom during the tile.
    real fx;
    real fy;
    real fz;
    // Custom per-atom parameter fields, expanded from a generated macro.
    ATOM_PARAMETER_DATA
#ifndef PARAMETER_SIZE_IS_EVEN
    // One extra real when the parameter data has an odd size (presumably to keep the record size even).
    real padding;
#endif
} AtomData;

/**
 * Find the maximum of a value across all threads in a warp, and return that to
 * every thread.
 *
 * val   this thread's contribution
 * temp  local-memory scratch indexed by LOCAL_ID. It is only touched by the
 *       generic fallback path and must have one entry per thread of the block.
 *
 * The reduction covers groups of 32 consecutive threads. All 32 threads of the
 * group must call this together: the shuffle path uses a full 0xffffffff mask,
 * and the fallback path contains SYNC_WARPS barriers.
 *
 * The fallback path has no barrier after its final read of temp. A caller that
 * reuses temp, including by calling this again, must SYNC_WARPS in between.
 */
DEVICE int reduceMax(int val, LOCAL_ARG int* temp) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
    // CUDA lets us do this slightly more efficiently by using shuffle operations.
    for (int mask = 16; mask > 0; mask /= 2)
        val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
    return val;
#elif defined(USE_HIP)
    // Butterfly reduction restricted to 32-lane subgroups (width argument 32).
    for (int mask = 16; mask > 0; mask /= 2)
        val = max(val, __shfl_xor(val, mask, 32));
    return val;
#else
    // Generic path: tree reduction through local memory.
    int indexInWarp = LOCAL_ID%32;
    temp[LOCAL_ID] = val;
    SYNC_WARPS;
    for (int offset = 16; offset > 0; offset /= 2) {
        if (indexInWarp < offset)
            temp[LOCAL_ID] = max(temp[LOCAL_ID], temp[LOCAL_ID+offset]);
        SYNC_WARPS;
    }
    // The first slot of each 32-thread group now holds that group's maximum.
    return temp[LOCAL_ID-indexInWarp];
#endif
}

#ifndef SUPPORTS_64_BIT_ATOMICS
/**
 * This function is used on devices that don't support 64 bit atomics.  Multiple threads within
 * a single tile might have computed forces on the same atom.  This loops over them and makes sure
 * that only one thread updates the force on any given atom.
 *
 * forceBuffers  real4 force buffers. This warp writes at offset warp*PADDED_NUM_ATOMS.
 * localData     tile-local records. On entry, .fx/.fy/.fz of this thread's entry must hold
 *               its force contribution. The .x field is overwritten with atomIndex.
 * atomIndex     atom that this thread's contribution applies to
 *
 * Must be called by every thread of the warp, because it contains SYNC_WARPS barriers.
 */
DEVICE void writeForces(GLOBAL real4* forceBuffers, LOCAL AtomData* localData, int atomIndex) {
    // Publish which atom this thread contributes to, reusing the x field as scratch.
    // NOTE(review): the index round-trips through a real. If real is float, this is only
    // exact for indices below 2^24. Confirm atom counts stay below that.
    localData[LOCAL_ID].x = atomIndex;
    SYNC_WARPS;
    real4 forceSum = make_real4(0);
    int start = (LOCAL_ID/TILE_SIZE)*TILE_SIZE;
    int end = start+TILE_SIZE;
    bool isFirst = true;
    for (int i = start; i < end; i++)
        if (localData[i].x == atomIndex) {
            // make_real4 is portable. The (real4) (a, b, c, d) vector literal is OpenCL-only syntax.
            forceSum += make_real4(localData[i].fx, localData[i].fy, localData[i].fz, 0);
            // Only the lowest-indexed thread holding this atom performs the write.
            isFirst &= (i >= LOCAL_ID);
        }
    const unsigned int warp = GLOBAL_ID/TILE_SIZE;
    unsigned int offset = atomIndex + warp*PADDED_NUM_ATOMS;
    if (isFirst)
        forceBuffers[offset] += forceSum;
    // Keep localData intact until every thread of the warp has finished reading it.
    SYNC_WARPS;
}
#endif

/**
 * Compute the interactions between the atom pairs described by the interaction-group tiles.
 *
 * Each tile has TILE_SIZE int4 entries in groupData, one per thread of a warp:
 *   x = atom1, the atom whose pair interactions this thread computes
 *   y = atom2, the atom this thread stages into localData for the rest of the warp
 *   z = range of tile slots atom1 interacts with: start in bits 0-15, end (exclusive) in bits 16-31
 *   w = bitmask over tile slots. A 0 bit marks the pair as excluded.
 *
 * Tiles are divided evenly between warps. When useNeighborList is set, tiles
 * [0, numGroupTiles[0]) are processed. Otherwise tiles [FIRST_TILE, LAST_TILE) are processed.
 *
 * Outputs:
 * - Energy is added to energyBuffer[GLOBAL_ID].
 * - With 64 bit atomics, forces are accumulated in fixed point into three planes of
 *   forceBuffers (x, y and z, each PADDED_NUM_ATOMS long).
 * - Otherwise, forces are written through writeForces.
 *
 * Every thread of a warp must reach the SYNC_WARPS barriers, so the inner loop runs a
 * warp-uniform number of iterations.
 */
KERNEL void computeInteractionGroups(
#ifdef SUPPORTS_64_BIT_ATOMICS
        GLOBAL mm_ulong* RESTRICT forceBuffers,
#else
        GLOBAL real4* RESTRICT forceBuffers,
#endif
        GLOBAL mixed* RESTRICT energyBuffer, GLOBAL const real4* RESTRICT posq, GLOBAL const int4* RESTRICT groupData,
        GLOBAL const int* RESTRICT numGroupTiles, int useNeighborList,
        real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ
        PARAMETER_ARGUMENTS) {
    const unsigned int totalWarps = GLOBAL_SIZE/TILE_SIZE;
    const unsigned int warp = GLOBAL_ID/TILE_SIZE; // global warpIndex
    const unsigned int tgx = LOCAL_ID & (TILE_SIZE-1); // index within the warp
    const unsigned int tbx = LOCAL_ID - tgx;           // block warpIndex
    mixed energy = 0;
    INIT_DERIVATIVES
    LOCAL AtomData localData[LOCAL_MEMORY_SIZE];
    LOCAL int reductionBuffer[LOCAL_MEMORY_SIZE];

    // Divide the tiles evenly between warps.
    const unsigned int startTile = (useNeighborList ? warp*numGroupTiles[0]/totalWarps : FIRST_TILE+warp*(LAST_TILE-FIRST_TILE)/totalWarps);
    const unsigned int endTile = (useNeighborList ? (warp+1)*numGroupTiles[0]/totalWarps : FIRST_TILE+(warp+1)*(LAST_TILE-FIRST_TILE)/totalWarps);
    for (int tile = startTile; tile < endTile; tile++) {
        const int4 atomData = groupData[TILE_SIZE*tile+tgx];
        const int atom1 = atomData.x;
        const int atom2 = atomData.y;
        const int rangeStart = atomData.z&0xFFFF;
        const int rangeEnd = (atomData.z>>16)&0xFFFF;
        const int exclusions = atomData.w;
        real4 posq1 = posq[atom1];
        LOAD_ATOM1_PARAMETERS
        real3 force = make_real3(0);

        // Stage atom2 in local memory so the other threads of the warp can read it.
        real4 posq2 = posq[atom2];
        localData[LOCAL_ID].x = posq2.x;
        localData[LOCAL_ID].y = posq2.y;
        localData[LOCAL_ID].z = posq2.z;
        localData[LOCAL_ID].q = posq2.w;
        LOAD_LOCAL_PARAMETERS
        localData[LOCAL_ID].fx = 0.0f;
        localData[LOCAL_ID].fy = 0.0f;
        localData[LOCAL_ID].fz = 0.0f;

        // Every thread runs max(rangeEnd-rangeStart) iterations over the warp, so the
        // SYNC_WARPS in the loop stays uniform. Threads with shorter ranges idle at the end.
        int tj = tgx;
        int rangeStop = rangeStart + reduceMax(rangeEnd-rangeStart, reductionBuffer);
        SYNC_WARPS;
        for (int j = rangeStart; j < rangeStop; j++) {
            if (j < rangeEnd) {
                bool isExcluded = (((exclusions>>tj)&1) == 0);
                int localIndex = tbx+tj;
                posq2 = make_real4(localData[localIndex].x, localData[localIndex].y, localData[localIndex].z, localData[localIndex].q);
                real3 delta = make_real3(posq2.x-posq1.x, posq2.y-posq1.y, posq2.z-posq1.z);
#ifdef USE_PERIODIC
                APPLY_PERIODIC_TO_DELTA(delta)
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                // Excluded pairs must be skipped whether or not a cutoff is in use.
                // Previously the exclusion test only existed under USE_CUTOFF.
#ifdef USE_CUTOFF
                if (!isExcluded && r2 < CUTOFF_SQUARED) {
#else
                if (!isExcluded) {
#endif
                    real invR = RSQRT(r2);
                    real r = r2*invR;
                    LOAD_ATOM2_PARAMETERS
                    real dEdR = 0.0f;
                    real tempEnergy = 0.0f;
                    const real interactionScale = 1.0f;
                    COMPUTE_INTERACTION
                    energy += tempEnergy;
                    delta *= dEdR;
                    force.x -= delta.x;
                    force.y -= delta.y;
                    force.z -= delta.z;
                    localData[localIndex].fx += delta.x;
                    localData[localIndex].fy += delta.y;
                    localData[localIndex].fz += delta.z;
                }
                // Cycle through [rangeStart, rangeEnd). Presumably this is so threads sharing a
                // range update distinct localData slots on each step.
                tj = (tj == rangeEnd-1 ? rangeStart : tj+1);
            }
            SYNC_WARPS;
        }
#ifdef SUPPORTS_64_BIT_ATOMICS
        // With an all-zero mask every pair is excluded, so atom1's force is zero and is not written.
        if (exclusions != 0) {
            ATOMIC_ADD(&forceBuffers[atom1], (mm_ulong) realToFixedPoint(force.x));
            ATOMIC_ADD(&forceBuffers[atom1+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(force.y));
            ATOMIC_ADD(&forceBuffers[atom1+2*PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(force.z));
        }
        ATOMIC_ADD(&forceBuffers[atom2], (mm_ulong) realToFixedPoint(localData[LOCAL_ID].fx));
        ATOMIC_ADD(&forceBuffers[atom2+PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(localData[LOCAL_ID].fy));
        ATOMIC_ADD(&forceBuffers[atom2+2*PADDED_NUM_ATOMS], (mm_ulong) realToFixedPoint(localData[LOCAL_ID].fz));
        SYNC_WARPS;
#else
        // Write the forces on atom2, then reuse localData to write the forces on atom1.
        writeForces(forceBuffers, localData, atom2);
        localData[LOCAL_ID].fx = force.x;
        localData[LOCAL_ID].fy = force.y;
        localData[LOCAL_ID].fz = force.z;
        writeForces(forceBuffers, localData, atom1);
#endif
    }
    energyBuffer[GLOBAL_ID] += energy;
    SAVE_DERIVATIVES
}

/**
 * If the neighbor list needs to be rebuilt, reset the number of tiles to 0.  This is
 * executed by a single thread.
 *
 * The reset is triggered by any nonzero value in rebuildNeighborList[0]. This matches
 * buildNeighborList, which rebuilds whenever the flag is nonzero, so the tile count is
 * never appended to without first being reset.
 */
KERNEL void prepareToBuildNeighborList(GLOBAL int* RESTRICT rebuildNeighborList, GLOBAL int* RESTRICT numGroupTiles) {
    if (rebuildNeighborList[0] != 0)
        numGroupTiles[0] = 0;
}

/**
 * Filter the list of tiles to include only ones that have interactions within the
 * padded cutoff.
 *
 * Each warp scans a contiguous share of the NUM_TILES tiles in groupData. The groupData
 * layout is the same as for computeInteractionGroups: atom1, atom2, packed slot range,
 * exclusion mask.
 *
 * A tile is kept if any non-excluded pair has a squared distance below
 * PADDED_CUTOFF_SQUARED. Each kept tile is copied to filteredGroupData at a slot
 * reserved by atomically incrementing numGroupTiles[0]. Because of that atomic, the
 * order of the filtered tiles is not deterministic.
 *
 * Does nothing when rebuildNeighborList[0] is 0. Presumably numGroupTiles[0] has been
 * reset beforehand (by prepareToBuildNeighborList); confirm against the host code.
 */
KERNEL void buildNeighborList(GLOBAL int* RESTRICT rebuildNeighborList, GLOBAL int* RESTRICT numGroupTiles,
        GLOBAL const real4* RESTRICT posq, GLOBAL const int4* RESTRICT groupData, GLOBAL int4* RESTRICT filteredGroupData,
        real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ) {

    // If the neighbor list doesn't need to be rebuilt on this step, return immediately.

    if (rebuildNeighborList[0] == 0)
        return;

    const unsigned int totalWarps = GLOBAL_SIZE/TILE_SIZE;
    const unsigned int warp = GLOBAL_ID/TILE_SIZE; // global warpIndex
    const unsigned int local_warp = LOCAL_ID/TILE_SIZE; // local warpIndex
    const unsigned int tgx = LOCAL_ID & (TILE_SIZE-1); // index within the warp
    const unsigned int tbx = LOCAL_ID - tgx;           // block warpIndex
    LOCAL real4 localPos[LOCAL_MEMORY_SIZE];
    LOCAL volatile bool anyInteraction[WARPS_IN_BLOCK];
    LOCAL volatile int tileIndex[WARPS_IN_BLOCK];
    LOCAL int reductionBuffer[LOCAL_MEMORY_SIZE];

    // Divide the tiles evenly between warps.
    const unsigned int startTile = warp*NUM_TILES/totalWarps;
    const unsigned int endTile = (warp+1)*NUM_TILES/totalWarps;
    for (int tile = startTile; tile < endTile; tile++) {
        const int4 atomData = groupData[TILE_SIZE*tile+tgx];
        const int atom1 = atomData.x;
        const int atom2 = atomData.y;
        const int rangeStart = atomData.z&0xFFFF;
        const int rangeEnd = (atomData.z>>16)&0xFFFF;
        const int exclusions = atomData.w;
        real4 posq1 = posq[atom1];

        // Stage atom2's position so the other threads of the warp can read it.
        localPos[LOCAL_ID] = posq[atom2];
        if (tgx == 0)
            anyInteraction[local_warp] = false;
        int tj = tgx;
        int rangeStop = rangeStart + reduceMax(rangeEnd-rangeStart, reductionBuffer);
        SYNC_WARPS;

        // Scan until any thread of the warp finds an interacting pair. The flag is read
        // only after a SYNC_WARPS, so every thread leaves the loop on the same iteration.
        for (int j = rangeStart; j < rangeStop && !anyInteraction[local_warp]; j++) {
            SYNC_WARPS;
            if (j < rangeEnd && tj < rangeEnd) {
                bool isExcluded = (((exclusions>>tj)&1) == 0);
                int localIndex = tbx+tj;
                real3 delta = make_real3(localPos[localIndex].x-posq1.x, localPos[localIndex].y-posq1.y, localPos[localIndex].z-posq1.z);
#ifdef USE_PERIODIC
                APPLY_PERIODIC_TO_DELTA(delta)
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                if (!isExcluded && r2 < PADDED_CUTOFF_SQUARED)
                    anyInteraction[local_warp] = true;
            }
            tj = (tj == rangeEnd-1 ? rangeStart : tj+1);
            SYNC_WARPS;
        }
        if (anyInteraction[local_warp]) {
            // One thread reserves a slot in the filtered list. Then every thread copies its entry of the tile.
            SYNC_WARPS;
            if (tgx == 0)
                tileIndex[local_warp] = ATOMIC_ADD(numGroupTiles, 1);
            SYNC_WARPS;
            filteredGroupData[TILE_SIZE*tileIndex[local_warp]+tgx] = atomData;
        }
    }
}