customNonbondedGroups.cu 8.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
typedef struct {
    real x, y, z;
    real q;
    real fx, fy, fz;
    ATOM_PARAMETER_DATA
#ifndef PARAMETER_SIZE_IS_EVEN
    real padding;
#endif
} AtomData;

11
12
13
14
15
16
17
18
19
20
21
22
// Find the maximum of a value across all threads in a warp, and return that to
// every thread.  This is only needed on Volta and later.  On earlier architectures,
// we can just return the value that was passed in.
__device__ int reduceMax(int val) {
#if __CUDA_ARCH__ >= 700
  for (int mask = 16; mask > 0; mask /= 2) 
    val = max(val, __shfl_xor(val, mask));
#endif
  return val;
}


23
extern "C" __global__ void computeInteractionGroups(
24
        unsigned long long* __restrict__ forceBuffers, mixed* __restrict__ energyBuffer, const real4* __restrict__ posq, const int4* __restrict__ groupData,
25
        const int* __restrict__ numGroupTiles, bool useNeighborList,
26
        real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ
27
28
29
30
31
        PARAMETER_ARGUMENTS) {
    const unsigned int totalWarps = (blockDim.x*gridDim.x)/TILE_SIZE;
    const unsigned int warp = (blockIdx.x*blockDim.x+threadIdx.x)/TILE_SIZE; // global warpIndex
    const unsigned int tgx = threadIdx.x & (TILE_SIZE-1); // index within the warp
    const unsigned int tbx = threadIdx.x - tgx;           // block warpIndex
32
    mixed energy = 0;
33
    INIT_DERIVATIVES
34
    __shared__ AtomData localData[LOCAL_MEMORY_SIZE];
35

36
37
    const unsigned int startTile = (useNeighborList ? warp*numGroupTiles[0]/totalWarps : FIRST_TILE+warp*(LAST_TILE-FIRST_TILE)/totalWarps);
    const unsigned int endTile = (useNeighborList ? (warp+1)*numGroupTiles[0]/totalWarps : FIRST_TILE+(warp+1)*(LAST_TILE-FIRST_TILE)/totalWarps);
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
    for (int tile = startTile; tile < endTile; tile++) {
        const int4 atomData = groupData[TILE_SIZE*tile+tgx];
        const int atom1 = atomData.x;
        const int atom2 = atomData.y;
        const int rangeStart = atomData.z&0xFFFF;
        const int rangeEnd = (atomData.z>>16)&0xFFFF;
        const int exclusions = atomData.w;
        real4 posq1 = posq[atom1];
        LOAD_ATOM1_PARAMETERS
        real3 force = make_real3(0);
        real4 posq2 = posq[atom2];
        localData[threadIdx.x].x = posq2.x;
        localData[threadIdx.x].y = posq2.y;
        localData[threadIdx.x].z = posq2.z;
        localData[threadIdx.x].q = posq2.w;
        LOAD_LOCAL_PARAMETERS
        localData[threadIdx.x].fx = 0.0f;
        localData[threadIdx.x].fy = 0.0f;
        localData[threadIdx.x].fz = 0.0f;
        int tj = tgx;
58
59
60
61
62
63
64
65
        int rangeStop = rangeStart + reduceMax(rangeEnd-rangeStart);
        SYNC_WARPS;
        for (int j = rangeStart; j < rangeStop; j++) {
            if (j < rangeEnd) {
                bool isExcluded = (((exclusions>>tj)&1) == 0);
                int localIndex = tbx+tj;
                posq2 = make_real4(localData[localIndex].x, localData[localIndex].y, localData[localIndex].z, localData[localIndex].q);
                real3 delta = make_real3(posq2.x-posq1.x, posq2.y-posq1.y, posq2.z-posq1.z);
66
#ifdef USE_PERIODIC
67
                APPLY_PERIODIC_TO_DELTA(delta)
68
#endif
69
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
70
#ifdef USE_CUTOFF
71
                if (!isExcluded && r2 < CUTOFF_SQUARED) {
72
#endif
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
                    real invR = RSQRT(r2);
                    real r = r2*invR;
                    LOAD_ATOM2_PARAMETERS
                    real dEdR = 0.0f;
                    real tempEnergy = 0.0f;
                    const real interactionScale = 1.0f;
                    COMPUTE_INTERACTION
                    energy += tempEnergy;
                    delta *= dEdR;
                    force.x -= delta.x;
                    force.y -= delta.y;
                    force.z -= delta.z;
                    localData[localIndex].fx += delta.x;
                    localData[localIndex].fy += delta.y;
                    localData[localIndex].fz += delta.z;
88
#ifdef USE_CUTOFF
89
                }
90
#endif
91
92
93
                tj = (tj == rangeEnd-1 ? rangeStart : tj+1);
            }
            SYNC_WARPS;
94
95
96
97
98
99
        }
        if (exclusions != 0) {
            atomicAdd(&forceBuffers[atom1], static_cast<unsigned long long>((long long) (force.x*0x100000000)));
            atomicAdd(&forceBuffers[atom1+PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (force.y*0x100000000)));
            atomicAdd(&forceBuffers[atom1+2*PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (force.z*0x100000000)));
        }
peastman's avatar
peastman committed
100
101
102
        atomicAdd(&forceBuffers[atom2], static_cast<unsigned long long>((long long) (localData[threadIdx.x].fx*0x100000000)));
        atomicAdd(&forceBuffers[atom2+PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (localData[threadIdx.x].fy*0x100000000)));
        atomicAdd(&forceBuffers[atom2+2*PADDED_NUM_ATOMS], static_cast<unsigned long long>((long long) (localData[threadIdx.x].fz*0x100000000)));
103
        SYNC_WARPS;
104
105
    }
    energyBuffer[blockIdx.x*blockDim.x+threadIdx.x] += energy;
106
    SAVE_DERIVATIVES
107
}
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154

/**
 * If the neighbor list needs to be rebuilt, reset the number of tiles to 0.  This is
 * executed by a single thread.
 */
extern "C" __global__  void prepareToBuildNeighborList(int* __restrict__ rebuildNeighborList, int* __restrict__ numGroupTiles) {
    if (rebuildNeighborList[0] == 1)
        numGroupTiles[0] = 0;
}

/**
 * Filter the list of tiles to include only ones that have interactions within the
 * padded cutoff.
 */
extern "C" __global__  void buildNeighborList(int* __restrict__ rebuildNeighborList, int* __restrict__ numGroupTiles,
        const real4* __restrict__ posq, const int4* __restrict__ groupData, int4* __restrict__ filteredGroupData,
        real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ) {
    
    // If the neighbor list doesn't need to be rebuilt on this step, return immediately.
    
    if (rebuildNeighborList[0] == 0)
        return;

    const unsigned int totalWarps = (blockDim.x*gridDim.x)/TILE_SIZE;
    const unsigned int warp = (blockIdx.x*blockDim.x+threadIdx.x)/TILE_SIZE; // global warpIndex
    const unsigned int local_warp = threadIdx.x/TILE_SIZE; // local warpIndex
    const unsigned int tgx = threadIdx.x & (TILE_SIZE-1); // index within the warp
    const unsigned int tbx = threadIdx.x - tgx;           // block warpIndex

    __shared__ real4 localPos[LOCAL_MEMORY_SIZE];
    __shared__ volatile bool anyInteraction[WARPS_IN_BLOCK];
    __shared__ volatile int tileIndex[WARPS_IN_BLOCK];

    const unsigned int startTile = warp*NUM_TILES/totalWarps;
    const unsigned int endTile = (warp+1)*NUM_TILES/totalWarps;
    for (int tile = startTile; tile < endTile; tile++) {
        const int4 atomData = groupData[TILE_SIZE*tile+tgx];
        const int atom1 = atomData.x;
        const int atom2 = atomData.y;
        const int rangeStart = atomData.z&0xFFFF;
        const int rangeEnd = (atomData.z>>16)&0xFFFF;
        const int exclusions = atomData.w;
        real4 posq1 = posq[atom1];
        localPos[threadIdx.x] = posq[atom2];
        if (tgx == 0)
            anyInteraction[local_warp] = false;
        int tj = tgx;
155
        int rangeStop = rangeStart + reduceMax(rangeEnd-rangeStart);
156
        SYNC_WARPS;
157
158
        for (int j = rangeStart; j < rangeStop && !anyInteraction[local_warp]; j++) {
            if (j < rangeEnd && tj < rangeEnd) {
159
160
161
162
163
164
165
166
167
168
169
170
171
172
                bool isExcluded = (((exclusions>>tj)&1) == 0);
                int localIndex = tbx+tj;
                real3 delta = make_real3(localPos[localIndex].x-posq1.x, localPos[localIndex].y-posq1.y, localPos[localIndex].z-posq1.z);
#ifdef USE_PERIODIC
                APPLY_PERIODIC_TO_DELTA(delta)
#endif
                real r2 = delta.x*delta.x + delta.y*delta.y + delta.z*delta.z;
                if (!isExcluded && r2 < PADDED_CUTOFF_SQUARED)
                    anyInteraction[local_warp] = true;
            }
            tj = (tj == rangeEnd-1 ? rangeStart : tj+1);
            SYNC_WARPS;
        }
        if (anyInteraction[local_warp]) {
peastman's avatar
peastman committed
173
            SYNC_WARPS;
174
175
176
177
178
179
180
            if (tgx == 0)
                tileIndex[local_warp] = atomicAdd(numGroupTiles, 1);
            SYNC_WARPS;
            filteredGroupData[TILE_SIZE*tileIndex[local_warp]+tgx] = atomData;
        }
    }
}