CpuNeighborList.cpp 32.4 KB
Newer Older
1
2
3
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
Evan Pretti's avatar
Evan Pretti committed
4
5
 * This is part of the OpenMM molecular simulation toolkit.                   *
 * See https://openmm.org/development.                                        *
6
 *                                                                            *
7
 * Portions copyright (c) 2013-2022 Stanford University and the Authors.      *
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,    *
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR      *
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE  *
 * USE OR OTHER DEALINGS IN THE SOFTWARE.                                     *
 * -------------------------------------------------------------------------- */

30
#include "CpuNeighborList.h"
31
#include "openmm/internal/hardware.h"
32
#include "openmm/internal/vectorize.h"
33
#include "hilbert.h"
34
#include <algorithm>
35
36
37
38
39
40
41
42
43
44
45
#include <set>
#include <map>
#include <cmath>

using namespace std;

namespace OpenMM {

/**
 * Coordinates of one voxel in the (y, z) grid used by CpuNeighborList::Voxels.
 * A default-constructed index refers to voxel (0, 0).
 */
class VoxelIndex {
public:
    int y = 0; // voxel index along the y axis
    int z = 0; // voxel index along the z axis

    VoxelIndex() = default;

    VoxelIndex(int yIndex, int zIndex) : y(yIndex), z(zIndex) {
    }
};

54
55
56
57
58
/**
 * This data structure organizes the particles spatially.  It divides them into bins along the y and z axes,
 * then sorts each bin along the x axis so ranges can be identified quickly with a binary search.
 */
class CpuNeighborList::Voxels {
public:
    /**
     * Create an empty voxel grid.
     *
     * @param blockSize    number of atoms per block (used when building per-neighbor exclusion masks)
     * @param vsy          requested voxel size along y; adjusted so a whole number of voxels spans the box/range
     * @param vsz          requested voxel size along z; adjusted likewise
     * @param miny, maxy   range of y coordinates (used to size the grid when not periodic)
     * @param minz, maxz   range of z coordinates (used to size the grid when not periodic)
     * @param boxVectors   the three periodic box vectors
     * @param usePeriodic  whether periodic boundary conditions are applied
     */
    Voxels(int blockSize, float vsy, float vsz, float miny, float maxy, float minz, float maxz, const Vec3* boxVectors, bool usePeriodic) :
            blockSize(blockSize), voxelSizeY(vsy), voxelSizeZ(vsz), miny(miny), maxy(maxy), minz(minz), maxz(maxz), usePeriodic(usePeriodic) {
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++) {
                // Copying to a volatile temporary variable is a workaround for
                // a bug in GCC9 on PPC.
                volatile float temp = (float) boxVectors[i][j];
                periodicBoxVectors[i][j] = temp;
            }
        periodicBoxSize[0] = (float) boxVectors[0][0];
        periodicBoxSize[1] = (float) boxVectors[1][1];
        periodicBoxSize[2] = (float) boxVectors[2][2];
        recipBoxSize[0] = (float) (1/boxVectors[0][0]);
        recipBoxSize[1] = (float) (1/boxVectors[1][1]);
        recipBoxSize[2] = (float) (1/boxVectors[2][2]);
        triclinic = (boxVectors[0][1] != 0.0 || boxVectors[0][2] != 0.0 ||
                     boxVectors[1][0] != 0.0 || boxVectors[1][2] != 0.0 ||
                     boxVectors[2][0] != 0.0 || boxVectors[2][1] != 0.0);
        if (usePeriodic) {
            // Round to a whole number of voxels and shrink/grow them so they exactly tile the box.
            ny = (int) floorf(boxVectors[1][1]/voxelSizeY+0.5f);
            nz = (int) floorf(boxVectors[2][2]/voxelSizeZ+0.5f);
            voxelSizeY = boxVectors[1][1]/ny;
            voxelSizeZ = boxVectors[2][2]/nz;
        }
        else {
            // Between 1 and 500 voxels per axis (the upper cap presumably bounds memory for very sparse systems).
            ny = max(1, min(500, (int) floorf((maxy-miny)/voxelSizeY+0.5f)));
            nz = max(1, min(500, (int) floorf((maxz-minz)/voxelSizeZ+0.5f)));
            if (maxy > miny)
                voxelSizeY = (maxy-miny)/ny;
            if (maxz > minz)
                voxelSizeZ = (maxz-minz)/nz;
        }
        bins.resize(ny);
        for (int i = 0; i < ny; i++)
            bins[i].resize(nz);
    }

    /**
     * Insert a particle into the voxel data structure.  The bin entry is keyed by the particle's x coordinate.
     */
    void insert(const int& atom, const float* location) {
        VoxelIndex voxelIndex = getVoxelIndex(location);
        bins[voxelIndex.y][voxelIndex.z].push_back(make_pair(location[0], atom));
    }

    /**
     * Sort the particles in each voxel by x coordinate.
     */
    void sortItems() {
        for (int i = 0; i < ny; i++)
            for (int j = 0; j < nz; j++)
                sort(bins[i][j].begin(), bins[i][j].end());
    }

    /**
     * Find the index of the first particle in voxel (y,z) whose x coordinate is >= the specified value,
     * searching only within [lower, upper).
     */
    int findLowerBound(int y, int z, double x, int lower, int upper) const {
        const vector<pair<float, int> >& bin = bins[y][z];
        while (lower < upper) {
            int middle = (lower+upper)/2;
            if (bin[middle].first < x)
                lower = middle+1;
            else
                upper = middle;
        }
        return lower;
    }

    /**
     * Find the index of the first particle in voxel (y,z) whose x coordinate is greater than the specified value,
     * searching only within [lower, upper).
     */
    int findUpperBound(int y, int z, double x, int lower, int upper) const {
        const vector<pair<float, int> >& bin = bins[y][z];
        while (lower < upper) {
            int middle = (lower+upper)/2;
            if (bin[middle].first > x)
                upper = middle;
            else
                lower = middle+1;
        }
        return upper;
    }

    /**
     * Get the voxel index containing a particular location.  With periodic boundaries the location is first
     * wrapped into the box; the result is always clamped to [0, ny) x [0, nz).
     */
    VoxelIndex getVoxelIndex(const float* location) const {
        float yperiodic, zperiodic;
        if (!usePeriodic) {
            yperiodic = location[1]-miny;
            zperiodic = location[2]-minz;
        }
        else {
            float scale2 = floorf(location[2]*recipBoxSize[2]);
            yperiodic = location[1]-periodicBoxVectors[2][1]*scale2;
            zperiodic = location[2]-periodicBoxVectors[2][2]*scale2;
            float scale1 = floorf(yperiodic*recipBoxSize[1]);
            yperiodic -= periodicBoxVectors[1][1]*scale1;
        }
        int y = max(0, min(ny-1, int(floorf(yperiodic / voxelSizeY))));
        int z = max(0, min(nz-1, int(floorf(zperiodic / voxelSizeZ))));
        return VoxelIndex(y, z);
    }

    /**
     * Collect the atoms that may be within maxDistance of any atom in one block.
     *
     * @param neighbors       receives the original (unsorted) indices of the neighbor atoms
     * @param blockIndex      index of the block being processed
     * @param blockCenter     center of the block's bounding box
     * @param blockWidth      half-widths of the block's bounding box
     * @param sortedAtoms     maps sorted index -> original atom index
     * @param exclusions      receives one mask per entry in neighbors; for neighbors inside this block the bits
     *                        of block atoms at or after the neighbor's slot are set so each pair appears once
     * @param maxDistance     interaction cutoff
     * @param blockAtoms      the atoms in this block (only its size is used here)
     * @param blockAtomX/Y/Z  coordinates of the block's atoms, read four at a time
     * @param sortedPositions positions in sorted order, 4 floats per atom
     * @param atomVoxelIndex  voxel containing each atom of the block
     */
    void getNeighbors(vector<int>& neighbors, int blockIndex, const fvec4& blockCenter, const fvec4& blockWidth, const vector<int>& sortedAtoms, vector<CpuNeighborList::BlockExclusionMask>& exclusions, float maxDistance, const vector<int>& blockAtoms, const vector<float>& blockAtomX, const vector<float>& blockAtomY, const vector<float>& blockAtomZ, const vector<float>& sortedPositions, const vector<VoxelIndex>& atomVoxelIndex) const {
        neighbors.resize(0);
        exclusions.resize(0);
        fvec4 boxSize(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2], 0);
        fvec4 invBoxSize(recipBoxSize[0], recipBoxSize[1], recipBoxSize[2], 0);
        fvec4 periodicBoxVec4[3];
        periodicBoxVec4[0] = fvec4(periodicBoxVectors[0][0], periodicBoxVectors[0][1], periodicBoxVectors[0][2], 0);
        periodicBoxVec4[1] = fvec4(periodicBoxVectors[1][0], periodicBoxVectors[1][1], periodicBoxVectors[1][2], 0);
        periodicBoxVec4[2] = fvec4(periodicBoxVectors[2][0], periodicBoxVectors[2][1], periodicBoxVectors[2][2], 0);

        float maxDistanceSquared = maxDistance * maxDistance;
        float refineCutoff = maxDistance-max(max(blockWidth[0], blockWidth[1]), blockWidth[2]);
        float refineCutoffSquared = refineCutoff*refineCutoff;

        int dIndexY = int((maxDistance+blockWidth[1])/voxelSizeY)+1; // How many voxels away do we have to look?
        int dIndexZ = int((maxDistance+blockWidth[2])/voxelSizeZ)+1;
        if (usePeriodic) {
            dIndexY = min(ny/2, dIndexY);
            dIndexZ = min(nz/2, dIndexZ);
        }
        float centerPos[4];
        blockCenter.store(centerPos);
        VoxelIndex centerVoxelIndex = getVoxelIndex(centerPos);

        // Loop over voxels along the z axis.

        int startz = centerVoxelIndex.z-dIndexZ;
        int endz = centerVoxelIndex.z+dIndexZ;
        if (usePeriodic)
            endz = min(endz, startz+nz-1);
        else {
            startz = max(startz, 0);
            endz = min(endz, nz-1);
        }
        int lastSortedIndex = blockSize*(blockIndex+1);
        VoxelIndex voxelIndex(0, 0);
        for (int z = startz; z <= endz; ++z) {
            // z may run outside [0, nz) when periodic; voxelIndex.z is the wrapped voxel actually visited.
            voxelIndex.z = z;
            if (usePeriodic)
                voxelIndex.z = (z < 0 ? z+nz : (z >= nz ? z-nz : z));

            // Loop over voxels along the y axis.

            float boxz = floor((float) z/nz);
            int starty = centerVoxelIndex.y-dIndexY;
            int endy = centerVoxelIndex.y+dIndexY;
            float yoffset = (float) (usePeriodic ? boxz*periodicBoxVectors[2][1] : 0);
            if (usePeriodic) {
                starty -= (int) ceil(yoffset/voxelSizeY);
                endy -= (int) floor(yoffset/voxelSizeY);
                endy = min(endy, starty+ny-1);
            }
            else {
                starty = max(starty, 0);
                endy = min(endy, ny-1);
            }
            for (int y = starty; y <= endy; ++y) {
                voxelIndex.y = y;
                if (usePeriodic)
                    voxelIndex.y = (y < 0 ? y+ny : (y >= ny ? y-ny : y));
                float boxy = floor((float) y/ny);

                // Identify the range of atoms within this bin we need to search.  When using periodic boundary
                // conditions, there may be two separate ranges.

                float minx = centerPos[0];
                float maxx = centerPos[0];
                if (usePeriodic && triclinic) {
                    for (int k = 0; k < (int) blockAtoms.size(); k++) {
                        const float* atomPos = &sortedPositions[4*(blockSize*blockIndex+k)];
                        fvec4 delta1(0, voxelSizeY*voxelIndex.y-atomPos[1], voxelSizeZ*voxelIndex.z-atomPos[2], 0);
                        fvec4 delta2 = delta1+fvec4(0, voxelSizeY, 0, 0);
                        fvec4 delta3 = delta1+fvec4(0, 0, voxelSizeZ, 0);
                        fvec4 delta4 = delta1+fvec4(0, voxelSizeY, voxelSizeZ, 0);
                        delta1 -= periodicBoxVec4[2]*floorf(delta1[2]*recipBoxSize[2]+0.5f);
                        delta1 -= periodicBoxVec4[1]*floorf(delta1[1]*recipBoxSize[1]+0.5f);
                        delta1 -= periodicBoxVec4[0]*floorf(delta1[0]*recipBoxSize[0]+0.5f);
                        delta2 -= periodicBoxVec4[2]*floorf(delta2[2]*recipBoxSize[2]+0.5f);
                        delta2 -= periodicBoxVec4[1]*floorf(delta2[1]*recipBoxSize[1]+0.5f);
                        delta2 -= periodicBoxVec4[0]*floorf(delta2[0]*recipBoxSize[0]+0.5f);
                        delta3 -= periodicBoxVec4[2]*floorf(delta3[2]*recipBoxSize[2]+0.5f);
                        delta3 -= periodicBoxVec4[1]*floorf(delta3[1]*recipBoxSize[1]+0.5f);
                        delta3 -= periodicBoxVec4[0]*floorf(delta3[0]*recipBoxSize[0]+0.5f);
                        delta4 -= periodicBoxVec4[2]*floorf(delta4[2]*recipBoxSize[2]+0.5f);
                        delta4 -= periodicBoxVec4[1]*floorf(delta4[1]*recipBoxSize[1]+0.5f);
                        delta4 -= periodicBoxVec4[0]*floorf(delta4[0]*recipBoxSize[0]+0.5f);
                        if (delta1[1] < 0 && delta1[1]+voxelSizeY > 0)
                            delta1 = fvec4(delta1[0], 0, delta1[2], 0);
                        if (delta1[2] < 0 && delta1[2]+voxelSizeZ > 0)
                            delta1 = fvec4(delta1[0], delta1[1], 0, 0);
                        if (delta3[1] < 0 && delta3[1]+voxelSizeY > 0)
                            delta3 = fvec4(delta3[0], 0, delta3[2], 0);
                        if (delta2[2] < 0 && delta2[2]+voxelSizeZ > 0)
                            delta2 = fvec4(delta2[0], delta2[1], 0, 0);
                        fvec4 delta = min(min(min(abs(delta1), abs(delta2)), abs(delta3)), abs(delta4));
                        float dy = (voxelIndex.y == atomVoxelIndex[k].y ? 0.0f : delta[1]);
                        float dz = (voxelIndex.z == atomVoxelIndex[k].z ? 0.0f : delta[2]);
                        float dist2 = maxDistanceSquared-dy*dy-dz*dz;
                        if (dist2 > 0) {
                            float dist = sqrtf(dist2);
                            minx = min(minx, atomPos[0]-dist-max(max(max(delta1[0], delta2[0]), delta3[0]), delta4[0]));
                            maxx = max(maxx, atomPos[0]+dist-min(min(min(delta1[0], delta2[0]), delta3[0]), delta4[0]));
                        }
                    }
                }
                else {
                    float xoffset = (float) (usePeriodic ? boxy*periodicBoxVectors[1][0]+boxz*periodicBoxVectors[2][0] : 0);
                    fvec4 offset(-xoffset, -yoffset+voxelSizeY*y+(usePeriodic ? 0.0f : miny), voxelSizeZ*z+(usePeriodic ? 0.0f : minz), 0);
                    for (int k = 0; k < (int) blockAtoms.size(); k++) {
                        const float* atomPos = &sortedPositions[4*(blockSize*blockIndex+k)];
                        fvec4 posVec(atomPos);
                        fvec4 delta1 = offset-posVec;
                        fvec4 delta2 = delta1+fvec4(0, voxelSizeY, voxelSizeZ, 0);
                        if (usePeriodic) {
                            delta1 -= round(delta1*invBoxSize)*boxSize;
                            delta2 -= round(delta2*invBoxSize)*boxSize;
                        }
                        fvec4 delta = min(abs(delta1), abs(delta2));
                        // Compare the wrapped voxel index: the loop counters y and z can be -1 or ny/nz when the
                        // search wraps across a periodic boundary, while atomVoxelIndex is always in range.
                        // Comparing the raw counters would give an atom's own voxel a nonzero distance and
                        // narrow the x search window, missing neighbors.
                        float dy = (voxelIndex.y == atomVoxelIndex[k].y ? 0.0f : delta[1]);
                        float dz = (voxelIndex.z == atomVoxelIndex[k].z ? 0.0f : delta[2]);
                        float dist2 = maxDistanceSquared-dy*dy-dz*dz;
                        if (dist2 > 0) {
                            float dist = sqrtf(dist2);
                            minx = min(minx, atomPos[0]-dist-xoffset);
                            maxx = max(maxx, atomPos[0]+dist-xoffset);
                        }
                    }
                }
                if (minx == maxx)
                    continue;
                bool needPeriodic = usePeriodic && (centerPos[1]-blockWidth[1] < maxDistance || centerPos[1]+blockWidth[1] > periodicBoxSize[1]-maxDistance ||
                                                    centerPos[2]-blockWidth[2] < maxDistance || centerPos[2]+blockWidth[2] > periodicBoxSize[2]-maxDistance ||
                                                    minx < 0.0f || maxx > periodicBoxVectors[0][0]);
                int numRanges;
                int rangeStart[2];
                int rangeEnd[2];
                int binSize = bins[voxelIndex.y][voxelIndex.z].size();
                rangeStart[0] = findLowerBound(voxelIndex.y, voxelIndex.z, minx, 0, binSize);
                if (needPeriodic) {
                    numRanges = 2;
                    rangeEnd[0] = findUpperBound(voxelIndex.y, voxelIndex.z, maxx, rangeStart[0], binSize);
                    if (rangeStart[0] > 0 && rangeEnd[0] < binSize)
                        numRanges = 1;
                    else if (rangeStart[0] > 0) {
                        // The window runs past the top of the bin; wrap around to its start.
                        rangeStart[1] = 0;
                        rangeEnd[1] = min(findUpperBound(voxelIndex.y, voxelIndex.z, maxx-periodicBoxSize[0], 0, rangeStart[0]), rangeStart[0]);
                    }
                    else {
                        // The window starts below the bottom of the bin; wrap around to its end.
                        rangeStart[1] = max(findLowerBound(voxelIndex.y, voxelIndex.z, minx+periodicBoxSize[0], rangeEnd[0], binSize), rangeEnd[0]);
                        rangeEnd[1] = bins[voxelIndex.y][voxelIndex.z].size();
                    }
                }
                else {
                    numRanges = 1;
                    rangeEnd[0] = findUpperBound(voxelIndex.y, voxelIndex.z, maxx, rangeStart[0], binSize);
                }
                bool periodicRectangular = (needPeriodic && !triclinic);

                // Loop over atoms and check to see if they are neighbors of this block.

                const vector<pair<float, int> >& voxelBins = bins[voxelIndex.y][voxelIndex.z];
                for (int range = 0; range < numRanges; range++) {
                    for (int item = rangeStart[range]; item < rangeEnd[range]; item++) {
                        const int sortedIndex = voxelBins[item].second;

                        // Avoid duplicate entries.
                        if (sortedIndex >= lastSortedIndex)
                            continue;

                        fvec4 atomPos(&sortedPositions[4*sortedIndex]);
                        fvec4 delta = atomPos-blockCenter;
                        if (periodicRectangular)
                            delta -= round(delta*invBoxSize)*boxSize;
                        else if (needPeriodic) {
                            delta -= periodicBoxVec4[2]*floorf(delta[2]*recipBoxSize[2]+0.5f);
                            delta -= periodicBoxVec4[1]*floorf(delta[1]*recipBoxSize[1]+0.5f);
                            delta -= periodicBoxVec4[0]*floorf(delta[0]*recipBoxSize[0]+0.5f);
                        }
                        delta = max(0.0f, abs(delta)-blockWidth);
                        float dSquared = dot3(delta, delta);
                        if (dSquared > maxDistanceSquared)
                            continue;

                        if (dSquared > refineCutoffSquared) {
                            // The distance is large enough that there might not be any actual interactions.
                            // Check individual atom pairs to be sure.

                            bool anyInteraction = false;
                            for (int k = 0; k < (int) blockAtoms.size(); k += 4) {
                                fvec4 dx = fvec4(&blockAtomX[k])-atomPos[0];
                                fvec4 dy = fvec4(&blockAtomY[k])-atomPos[1];
                                fvec4 dz = fvec4(&blockAtomZ[k])-atomPos[2];
                                if (periodicRectangular) {
                                    dx -= round(dx*invBoxSize[0])*boxSize[0];
                                    dy -= round(dy*invBoxSize[1])*boxSize[1];
                                    dz -= round(dz*invBoxSize[2])*boxSize[2];
                                }
                                else if (needPeriodic) {
                                    fvec4 scale3 = floor(dz*recipBoxSize[2]+0.5f);
                                    dx -= scale3*periodicBoxVectors[2][0];
                                    dy -= scale3*periodicBoxVectors[2][1];
                                    dz -= scale3*periodicBoxVectors[2][2];
                                    fvec4 scale2 = floor(dy*recipBoxSize[1]+0.5f);
                                    dx -= scale2*periodicBoxVectors[1][0];
                                    dy -= scale2*periodicBoxVectors[1][1];
                                    fvec4 scale1 = floor(dx*recipBoxSize[0]+0.5f);
                                    dx -= scale1*periodicBoxVectors[0][0];
                                }
                                fvec4 r2 = dx*dx + dy*dy + dz*dz;
                                if (any(r2 < maxDistanceSquared)) {
                                    anyInteraction = true;
                                    break;
                                }
                            }
                            if (!anyInteraction)
                                continue;
                        }

                        // Add this atom to the list of neighbors.

                        neighbors.push_back(sortedAtoms[sortedIndex]);
                        if (sortedIndex < blockSize*blockIndex)
                            exclusions.push_back(0);
                        else {
                            // The neighbor belongs to this block: exclude it from block atoms at or after its own
                            // slot so each intra-block pair is produced only once (and never with itself).
                            int mask = (1<<blockSize)-1;
                            exclusions.push_back(mask & (mask<<(sortedIndex-blockSize*blockIndex)));
                        }
                    }
                }
            }
        }
    }

private:
    int blockSize;
    float voxelSizeY, voxelSizeZ;
    float miny, maxy, minz, maxz;
    int ny, nz;
    float periodicBoxSize[3], recipBoxSize[3];
    bool triclinic;
    float periodicBoxVectors[3][3];
    const bool usePeriodic;
    // bins[y][z] holds (x coordinate, sorted atom index) pairs, sorted by x after sortItems().
    vector<vector<vector<pair<float, int> > > > bins;
};

411
/**
 * Create a neighbor list that groups atoms into blocks of blockSize atoms.
 * No list is built until computeNeighborList() or createDenseNeighborList() is called.
 */
CpuNeighborList::CpuNeighborList(int blockSize)
        : blockSize(blockSize) {
    // Nothing else to initialize here.
}
413

Evan Pretti's avatar
Evan Pretti committed
414
415
416
417
418
419
420
421
422
423
424
/**
 * Build the sparse neighbor list, optionally through an index table.
 *
 * When indices is null, atom i's position is read from atomLocations at slot i.  Otherwise atom i's position is
 * read at slot (*indices)[i].  Only the vector's data pointer is stored, so the vector must stay alive for the
 * duration of this call.
 */
void CpuNeighborList::computeNeighborList(int numAtoms, const AlignedArray<float>& atomLocations, const vector<set<int> >& exclusions,
            const Vec3* periodicBoxVectors, bool usePeriodic, float maxDistance, ThreadPool& threads, const std::vector<int>* indices) {
    if (indices == nullptr) {
        computeNeighborList<false>(numAtoms, atomLocations, exclusions, periodicBoxVectors, usePeriodic, maxDistance, threads);
        return;
    }

    // Publish the index table for the templated worker before dispatching to it.
    this->indices = indices->data();
    computeNeighborList<true>(numAtoms, atomLocations, exclusions, periodicBoxVectors, usePeriodic, maxDistance, threads);
}

template<bool USE_INDICES>
425
void CpuNeighborList::computeNeighborList(int numAtoms, const AlignedArray<float>& atomLocations, const vector<set<int> >& exclusions,
peastman's avatar
peastman committed
426
            const Vec3* periodicBoxVectors, bool usePeriodic, float maxDistance, ThreadPool& threads) {
427
    dense = false;
428
    int numBlocks = (numAtoms+blockSize-1)/blockSize;
429
430
431
    blockNeighbors.resize(numBlocks);
    blockExclusions.resize(numBlocks);
    sortedAtoms.resize(numAtoms);
432
    sortedPositions.resize(4*numAtoms);
433
    
434
    // Record the parameters for the threads.
435
    
436
437
    this->exclusions = &exclusions;
    this->atomLocations = &atomLocations[0];
438
439
440
    this->periodicBoxVectors[0] = periodicBoxVectors[0];
    this->periodicBoxVectors[1] = periodicBoxVectors[1];
    this->periodicBoxVectors[2] = periodicBoxVectors[2];
441
442
443
444
445
446
    this->numAtoms = numAtoms;
    this->usePeriodic = usePeriodic;
    this->maxDistance = maxDistance;
    
    // Identify the range of atom positions along each axis.
    
Evan Pretti's avatar
Evan Pretti committed
447
448
    int minPosIndex = USE_INDICES ? indices[0] : 0;
    fvec4 minPos(&atomLocations[4*minPosIndex]);
449
    fvec4 maxPos = minPos;
Evan Pretti's avatar
Evan Pretti committed
450
451
452
    for (int i = 1; i < numAtoms; i++) {
        int posIndex = USE_INDICES ? indices[i] : i;
        fvec4 pos(&atomLocations[4*posIndex]);
453
454
        minPos = min(minPos, pos);
        maxPos = max(maxPos, pos);
455
    }
456
457
458
459
460
461
462
463
464
465
    minx = minPos[0];
    maxx = maxPos[0];
    miny = minPos[1];
    maxy = maxPos[1];
    minz = minPos[2];
    maxz = maxPos[2];
    
    // Sort the atoms based on a Hilbert curve.
    
    atomBins.resize(numAtoms);
Evan Pretti's avatar
Evan Pretti committed
466
    threads.execute([&] (ThreadPool& threads, int threadIndex) { threadComputeNeighborList<USE_INDICES>(threads, threadIndex); });
467
    threads.waitForThreads();
468
469
    sort(atomBins.begin(), atomBins.end());

470
    // Build the voxel hash.
471

472
    float edgeSizeY, edgeSizeZ;
473
    if (!usePeriodic)
474
        edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed
475
    else {
476
477
        edgeSizeY = 0.6f*periodicBoxVectors[1][1]/floorf(periodicBoxVectors[1][1]/maxDistance);
        edgeSizeZ = 0.6f*periodicBoxVectors[2][2]/floorf(periodicBoxVectors[2][2]/maxDistance);
478
    }
479
    Voxels voxels(blockSize, edgeSizeY, edgeSizeZ, miny, maxy, minz, maxz, periodicBoxVectors, usePeriodic);
480
481
482
    for (int i = 0; i < numAtoms; i++) {
        int atomIndex = atomBins[i].second;
        sortedAtoms[i] = atomIndex;
Evan Pretti's avatar
Evan Pretti committed
483
484
        int atomPosIndex = USE_INDICES ? indices[atomIndex] : atomIndex;
        fvec4 atomPos(&atomLocations[4*atomPosIndex]);
485
        atomPos.store(&sortedPositions[4*i]);
Evan Pretti's avatar
Evan Pretti committed
486
        voxels.insert(i, &atomLocations[4*atomPosIndex]);
487
    }
488
489
    voxels.sortItems();
    this->voxels = &voxels;
490

491
492
    // Signal the threads to start running and wait for them to finish.
    
peastman's avatar
peastman committed
493
    atomicCounter = 0;
494
495
    threads.resumeThreads();
    threads.waitForThreads();
496
    
497
    // Add padding atoms to fill up the last block.
498
    
499
    int numPadding = numBlocks*blockSize-numAtoms;
500
    if (numPadding > 0) {
501
        const BlockExclusionMask mask = (~0) << (blockSize - numPadding);
502
503
        for (int i = 0; i < numPadding; i++)
            sortedAtoms.push_back(0);
504
        auto& exc = blockExclusions[blockExclusions.size()-1];
505
506
507
        for (int i = 0; i < (int) exc.size(); i++)
            exc[i] |= mask;
    }
508
509
}

510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
/**
 * Build a dense neighbor list in which every block interacts with every atom at or after its first atom.
 *
 * Atoms keep their original order (sortedAtoms[i] == i).  For each block, blockExclusionIndices/blockExclusions
 * list only the atoms whose mask is nonzero: the block's own atoms (to avoid self and double counting) and any
 * explicitly excluded partners.  Bit j of a mask refers to atom j of the block.
 *
 * @param numAtoms    number of atoms
 * @param exclusions  exclusions[i] holds the atoms excluded from interacting with atom i
 */
void CpuNeighborList::createDenseNeighborList(int numAtoms, const vector<set<int> >& exclusions) {
    dense = true;
    this->numAtoms = numAtoms;
    int numBlocks = (numAtoms+blockSize-1)/blockSize;
    blockExclusionIndices.resize(numBlocks);
    blockExclusions.resize(numBlocks);
    sortedAtoms.resize(numAtoms);
    for (int i = 0; i < numAtoms; i++)
        sortedAtoms[i] = i;
    for (int i = 0; i < numBlocks; i++) {
        // Build the list of exclusions for this block.

        int firstIndex = blockSize*i;
        int atomsInBlock = min(blockSize, numAtoms-firstIndex);
        map<int, int> exclusionMap;

        // Atom j of the block is excluded from block atoms 0..j (bits 0..j, including itself), so each
        // intra-block pair is computed only once.
        for (int j = 0; j < atomsInBlock; j++)
            exclusionMap[firstIndex+j] = (1<<(j+1))-1;

        // Merge in the explicit exclusions.  Atoms before the block are skipped since the dense iteration
        // starts at the block's first atom.
        for (int j = 0; j < atomsInBlock; j++) {
            const set<int>& atomExclusions = exclusions[firstIndex+j];
            const BlockExclusionMask mask = 1<<j;
            for (int exclusion : atomExclusions) {
                if (firstIndex <= exclusion) {
                    auto thisAtomFlags = exclusionMap.find(exclusion);
                    if (thisAtomFlags == exclusionMap.end())
                        exclusionMap[exclusion] = mask;
                    else
                        thisAtomFlags->second |= mask;
                }
            }
        }
        blockExclusionIndices[i].clear();
        blockExclusions[i].clear();
        for (const auto& flags : exclusionMap) {
            blockExclusionIndices[i].push_back(flags.first);
            blockExclusions[i].push_back(flags.second);
        }
    }

    // Add padding atoms to fill up the last block.

    int numPadding = numBlocks*blockSize-numAtoms;
    if (numPadding > 0) {
        // Set the bits of the padding slots so they never interact.  Shift an unsigned all-ones value:
        // left-shifting the negative int ~0 is undefined behavior before C++20.
        const BlockExclusionMask mask = static_cast<BlockExclusionMask>(~0u << (blockSize - numPadding));
        for (int i = 0; i < numPadding; i++)
            sortedAtoms.push_back(0);
        auto& exc = blockExclusions.back();
        for (int i = 0; i < (int) exc.size(); i++)
            exc[i] |= mask;
    }
}

562
int CpuNeighborList::getNumBlocks() const {
563
    return sortedAtoms.size()/blockSize;
564
565
}

566
567
568
569
int CpuNeighborList::getBlockSize() const {
    return blockSize;
}

Daniel Towner's avatar
Daniel Towner committed
570
const std::vector<int32_t>& CpuNeighborList::getSortedAtoms() const {
571
572
573
574
575
576
577
    return sortedAtoms;
}

/**
 * Get the neighbor atoms stored for one block.  Only the sparse builder (computeNeighborList) resizes
 * blockNeighbors; createDenseNeighborList leaves it untouched.
 */
const std::vector<int>& CpuNeighborList::getBlockNeighbors(int blockIndex) const {
    const std::vector<int>& neighborsOfBlock = blockNeighbors[blockIndex];
    return neighborsOfBlock;
}

578
const std::vector<CpuNeighborList::BlockExclusionMask>& CpuNeighborList::getBlockExclusions(int blockIndex) const {
579
580
581
582
    return blockExclusions[blockIndex];
    
}

583
584
585
586
587
588
589
/**
 * Create an iterator over the neighbors of one block.  In sparse mode it walks the stored neighbor list; in dense
 * mode it walks every atom from the block's first atom up to numAtoms, using the block's exclusion index/mask lists.
 */
CpuNeighborList::NeighborIterator CpuNeighborList::getNeighborIterator(int blockIndex) const {
    if (!dense)
        return NeighborIterator(blockNeighbors[blockIndex], blockExclusions[blockIndex]);
    const int firstAtom = blockIndex*blockSize;
    return NeighborIterator(firstAtom, numAtoms, blockExclusionIndices[blockIndex], blockExclusions[blockIndex]);
}

Evan Pretti's avatar
Evan Pretti committed
590
template<bool USE_INDICES>
591
592
void CpuNeighborList::threadComputeNeighborList(ThreadPool& threads, int threadIndex) {
    // Compute the positions of atoms along the Hilbert curve.
593

594
595
596
597
598
    float binWidth = max(max(maxx-minx, maxy-miny), maxz-minz)/255.0f;
    float invBinWidth = 1.0f/binWidth;
    bitmask_t coords[3];
    int numThreads = threads.getNumThreads();
    for (int i = threadIndex; i < numAtoms; i += numThreads) {
Evan Pretti's avatar
Evan Pretti committed
599
600
        int posIndex = USE_INDICES ? indices[i] : i;
        const float* pos = &atomLocations[4*posIndex];
601
602
603
604
605
606
607
        coords[0] = (bitmask_t) ((pos[0]-minx)*invBinWidth);
        coords[1] = (bitmask_t) ((pos[1]-miny)*invBinWidth);
        coords[2] = (bitmask_t) ((pos[2]-minz)*invBinWidth);
        int bin = (int) hilbert_c2i(3, 8, coords);
        atomBins[i] = pair<int, int>(bin, i);
    }
    threads.syncThreads();
608

609
610
611
612
    // Compute this thread's subset of neighbors.

    int numBlocks = blockNeighbors.size();
    vector<int> blockAtoms;
613
    vector<float> blockAtomX(blockSize), blockAtomY(blockSize), blockAtomZ(blockSize);
614
    vector<VoxelIndex> atomVoxelIndex;
615
    while (true) {
peastman's avatar
peastman committed
616
        int i = atomicCounter++;
617
618
619
        if (i >= numBlocks)
            break;

620
621
        // Find the atoms in this block and compute their bounding box.
        
622
623
        int firstIndex = blockSize*i;
        int atomsInBlock = min(blockSize, numAtoms-firstIndex);
624
        blockAtoms.resize(atomsInBlock);
625
626
        atomVoxelIndex.resize(atomsInBlock);
        for (int j = 0; j < atomsInBlock; j++) {
627
            blockAtoms[j] = sortedAtoms[firstIndex+j];
Evan Pretti's avatar
Evan Pretti committed
628
629
            int posIndex = USE_INDICES ? indices[blockAtoms[j]] : blockAtoms[j];
            atomVoxelIndex[j] = voxels->getVoxelIndex(&atomLocations[4*posIndex]);
630
        }
631
        fvec4 minPos(&sortedPositions[4*firstIndex]);
632
633
        fvec4 maxPos = minPos;
        for (int j = 1; j < atomsInBlock; j++) {
634
            fvec4 pos(&sortedPositions[4*(firstIndex+j)]);
635
636
            minPos = min(minPos, pos);
            maxPos = max(maxPos, pos);
637
        }
638
639
640
641
642
643
644
645
646
647
648
        for (int j = 0; j < atomsInBlock; j++) {
            blockAtomX[j] = sortedPositions[4*(firstIndex+j)];
            blockAtomY[j] = sortedPositions[4*(firstIndex+j)+1];
            blockAtomZ[j] = sortedPositions[4*(firstIndex+j)+2];
        }
        for (int j = atomsInBlock; j < blockSize; j++) {
            blockAtomX[j] = 1e10;
            blockAtomY[j] = 1e10;
            blockAtomZ[j] = 1e10;
        }
        voxels->getNeighbors(blockNeighbors[i], i, (maxPos+minPos)*0.5f, (maxPos-minPos)*0.5f, sortedAtoms, blockExclusions[i], maxDistance, blockAtoms, blockAtomX, blockAtomY, blockAtomZ, sortedPositions, atomVoxelIndex);
649

650
        // Record the exclusions for this block.
651

652
        map<int, BlockExclusionMask> atomFlags;
653
654
        for (int j = 0; j < atomsInBlock; j++) {
            const set<int>& atomExclusions = (*exclusions)[sortedAtoms[firstIndex+j]];
655
            const BlockExclusionMask mask = 1<<j;
peastman's avatar
peastman committed
656
            for (int exclusion : atomExclusions) {
657
                const auto thisAtomFlags = atomFlags.find(exclusion);
658
                if (thisAtomFlags == atomFlags.end())
peastman's avatar
peastman committed
659
                    atomFlags[exclusion] = mask;
660
661
                else
                    thisAtomFlags->second |= mask;
662
663
            }
        }
664
665
666
        int numNeighbors = blockNeighbors[i].size();
        for (int k = 0; k < numNeighbors; k++) {
            int atomIndex = blockNeighbors[i][k];
667
            auto thisAtomFlags = atomFlags.find(atomIndex);
668
669
670
            if (thisAtomFlags != atomFlags.end())
                blockExclusions[i][k] |= thisAtomFlags->second;
        }
671
672
673
    }
}

674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
CpuNeighborList::NeighborIterator::NeighborIterator(const vector<int>& neighbors, const vector<BlockExclusionMask>& exclusions) :
        dense(false), neighbors(&neighbors), exclusions(&exclusions), currentIndex(-1) {
}

/**
 * Creates an iterator that treats every atom in [firstAtom, lastAtom) as a
 * neighbor (dense mode).
 *
 * @param firstAtom         first atom index to visit
 * @param lastAtom          one past the last atom index to visit
 * @param exclusionIndices  atoms that have exclusion bits, consumed in order by next().
 *                          Only its address is stored.
 * @param exclusions        mask for each entry of exclusionIndices (parallel vector).
 *                          Only its address is stored.
 *
 * currentAtom starts one before firstAtom, so the first next() lands on firstAtom.
 */
CpuNeighborList::NeighborIterator::NeighborIterator(int firstAtom, int lastAtom, const vector<int>& exclusionIndices, const vector<BlockExclusionMask>& exclusions)
        : dense(true),
          currentAtom(firstAtom-1),
          lastAtom(lastAtom),
          exclusionIndices(&exclusionIndices),
          exclusions(&exclusions),
          currentIndex(0) {
}

/**
 * Moves to the next neighbor.
 *
 * Returns false once there are no more neighbors. Otherwise it updates the values
 * reported by getNeighbor() and getExclusions() and returns true.
 */
bool CpuNeighborList::NeighborIterator::next() {
    if (!dense) {
        // Sparse mode: step through the stored neighbor list and its parallel masks.
        ++currentIndex;
        if (currentIndex >= neighbors->size())
            return false;
        currentAtom = (*neighbors)[currentIndex];
        currentExclusions = (*exclusions)[currentIndex];
        return true;
    }

    // Dense mode: visit consecutive atoms until lastAtom is reached.
    ++currentAtom;
    if (currentAtom >= lastAtom)
        return false;

    // exclusionIndices is consumed front to back and only matches the current atom,
    // so it is assumed to be sorted ascending (TODO confirm with the list builder).
    const bool atomHasExclusions = (currentIndex < exclusionIndices->size() && (*exclusionIndices)[currentIndex] == currentAtom);
    if (atomHasExclusions) {
        currentExclusions = (*exclusions)[currentIndex];
        ++currentIndex;
    }
    else {
        currentExclusions = 0;
    }
    return true;
}

/**
 * Returns the atom index of the neighbor that the last successful next() moved to.
 */
int CpuNeighborList::NeighborIterator::getNeighbor() const {
    const int neighborAtom = currentAtom;
    return neighborAtom;
}

/**
 * Returns the exclusion mask for the current neighbor. Bit j is set when the
 * block's j-th atom is excluded from interacting with getNeighbor().
 */
CpuNeighborList::BlockExclusionMask CpuNeighborList::NeighborIterator::getExclusions() const {
    return currentExclusions;
}
} // namespace OpenMM