CpuNeighborList.cpp 7.88 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
#include "CpuNeighborList.h"
#include <algorithm>
#include <cmath>
#include <map>
#include <set>
#include <utility>
#include <vector>

using namespace std;

namespace OpenMM {

// Minimum-image difference (val1 - val2) along one axis of length `period`:
// the raw difference is shifted by the whole number of periods nearest to it,
// so the result lies (up to rounding) in [-period/2, period/2).
static float periodicDifference(float val1, float val2, float period) {
    const float rawDelta = val1 - val2;
    const float nearestImageCount = floorf(rawDelta / period + 0.5f);
    return rawDelta - nearestImageCount * period;
}

// squared distance between two points
static float compPairDistanceSquared(const float* pos1, const float* pos2, const float* periodicBoxSize, bool usePeriodic) {
    float dx, dy, dz;
    if (!usePeriodic) {
        dx = pos2[0] - pos1[0];
        dy = pos2[1] - pos1[1];
        dz = pos2[2] - pos1[2];
    }
    else {
        dx = periodicDifference(pos2[0], pos1[0], periodicBoxSize[0]);
        dy = periodicDifference(pos2[1], pos1[1], periodicBoxSize[1]);
        dz = periodicDifference(pos2[2], pos1[2], periodicBoxSize[2]);
    }
    return dx*dx + dy*dy + dz*dz;
}

// Integer (x, y, z) coordinates of one cell of the voxel grid.
class VoxelIndex
{
public:
    VoxelIndex(int ix, int iy, int iz) : x(ix), y(iy), z(iz) {}

    // Lexicographic ordering (x first, then y, then z). This is a strict weak
    // ordering, which lets VoxelIndex be used as a std::map key.
    bool operator<(const VoxelIndex& other) const {
        if (x != other.x)
            return x < other.x;
        if (y != other.y)
            return y < other.y;
        return z < other.z;
    }

    int x;
    int y;
    int z;
};

typedef std::pair<const float*, int> VoxelItem;
typedef std::vector< VoxelItem > Voxel;

class VoxelHash
{
public:
    VoxelHash(float vsx, float vsy, float vsz, const float* periodicBoxSize, bool usePeriodic) :
            voxelSizeX(vsx), voxelSizeY(vsy), voxelSizeZ(vsz), periodicBoxSize(periodicBoxSize), usePeriodic(usePeriodic) {
        if (usePeriodic) {
            nx = (int) floorf(periodicBoxSize[0]/voxelSizeX+0.5f);
            ny = (int) floorf(periodicBoxSize[1]/voxelSizeY+0.5f);
            nz = (int) floorf(periodicBoxSize[2]/voxelSizeZ+0.5f);
        }
    }

    void insert(const int& item, const float* location)
    {
        VoxelIndex voxelIndex = getVoxelIndex(location);
        if (voxelMap.find(voxelIndex) == voxelMap.end()) voxelMap[voxelIndex] = Voxel(); 
        Voxel& voxel = voxelMap.find(voxelIndex)->second;
        voxel.push_back(VoxelItem(location, item));
    }


    VoxelIndex getVoxelIndex(const float* location) const {
        float xperiodic, yperiodic, zperiodic;
        if (!usePeriodic) {
            xperiodic = location[0];
            yperiodic = location[1];
            zperiodic = location[2];
        }
        else {
            xperiodic = location[0]-periodicBoxSize[0]*floorf(location[0]/periodicBoxSize[0]);
            yperiodic = location[1]-periodicBoxSize[1]*floorf(location[1]/periodicBoxSize[1]);
            zperiodic = location[2]-periodicBoxSize[2]*floorf(location[2]/periodicBoxSize[2]);
        }
        int x = int(floorf(xperiodic / voxelSizeX));
        int y = int(floorf(yperiodic / voxelSizeY));
        int z = int(floorf(zperiodic / voxelSizeZ));
        
        return VoxelIndex(x, y, z);
    }

    void getNeighbors(
            vector<pair<int, int> >& neighbors, 
            const VoxelItem& referencePoint, 
            const vector<set<int> >& exclusions,
            bool reportSymmetricPairs,
            float maxDistance, 
            float minDistance) const 
    {

        // Loop over neighboring voxels
        // TODO use more clever selection of neighboring voxels

        const int atomI = referencePoint.second;
        const float* locationI = referencePoint.first;
        
        float maxDistanceSquared = maxDistance * maxDistance;
        float minDistanceSquared = minDistance * minDistance;

        int dIndexX = int(maxDistance / voxelSizeX) + 1; // How may voxels away do we have to look?
        int dIndexY = int(maxDistance / voxelSizeY) + 1;
        int dIndexZ = int(maxDistance / voxelSizeZ) + 1;
        VoxelIndex centerVoxelIndex = getVoxelIndex(locationI);
        int lastx = centerVoxelIndex.x+dIndexX;
        int lasty = centerVoxelIndex.y+dIndexY;
        int lastz = centerVoxelIndex.z+dIndexZ;
        if (usePeriodic) {
            lastx = min(lastx, centerVoxelIndex.x-dIndexX+nx-1);
            lasty = min(lasty, centerVoxelIndex.y-dIndexY+ny-1);
            lastz = min(lastz, centerVoxelIndex.z-dIndexZ+nz-1);
        }
        for (int x = centerVoxelIndex.x - dIndexX; x <= lastx; ++x)
        {
            for (int y = centerVoxelIndex.y - dIndexY; y <= lasty; ++y)
            {
                for (int z = centerVoxelIndex.z - dIndexZ; z <= lastz; ++z)
                {
                    VoxelIndex voxelIndex(x, y, z);
                    if (usePeriodic) {
                        voxelIndex.x = (x+nx)%nx;
                        voxelIndex.y = (y+ny)%ny;
                        voxelIndex.z = (z+nz)%nz;
                    }
                    if (voxelMap.find(voxelIndex) == voxelMap.end()) continue; // no such voxel; skip
                    const Voxel& voxel = voxelMap.find(voxelIndex)->second;
                    for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter)
                    {
                        const int atomJ = itemIter->second;
                        const float* locationJ = itemIter->first;
                        
                        // Ignore self hits
                        if (atomI == atomJ) continue;
                        
                        // Ignore exclusions.
                        if (exclusions[atomI].find(atomJ) != exclusions[atomI].end()) continue;
                        
                        float dSquared = compPairDistanceSquared(locationI, locationJ, periodicBoxSize, usePeriodic);
                        if (dSquared > maxDistanceSquared) continue;
                        if (dSquared < minDistanceSquared) continue;
                        
                        neighbors.push_back(make_pair(atomI, atomJ));
                        if (reportSymmetricPairs)
                            neighbors.push_back(make_pair(atomJ, atomI));
                    }
                }
            }
        }
    }

private:
    float voxelSizeX, voxelSizeY, voxelSizeZ;
    int nx, ny, nz;
    const float* periodicBoxSize;
    const bool usePeriodic;
    std::map<VoxelIndex, Voxel> voxelMap;
};


// O(n) neighbor list method using voxel hash data structure.
//
// Rebuilds `neighbors` with every pair of particles whose separation d satisfies
// minDistance <= d <= maxDistance. atomLocations stores 4 floats per particle, of
// which the first three (x, y, z) are used.
// Particles are processed in index order, and each one is compared only against
// particles inserted before it. Every pair is therefore reported once, as
// (higher index, lower index), plus (lower, higher) when reportSymmetricPairs is
// true. A pair is skipped when exclusions[higher index] contains the lower index.
// NOTE(review): this presumes exclusions are stored symmetrically — confirm.
// With usePeriodic, distances use the minimum-image convention in a rectangular
// box with edge lengths periodicBoxSize[0..2].
void CpuNeighborList::computeNeighborList(
                              int nAtoms,
                              const vector<float>& atomLocations,
                              const vector<set<int> >& exclusions,
                              const float* periodicBoxSize,
                              bool usePeriodic,
                              float maxDistance,
                              float minDistance,
                              bool reportSymmetricPairs)
{
    neighbors.clear();

    // Voxel edge lengths. In the periodic case, use the smallest edge >= maxDistance
    // that divides the box evenly, so the voxel grid tiles the box exactly. Use at
    // least one voxel per axis, so a box smaller than the cutoff does not cause a
    // division by zero.
    float edgeSizeX, edgeSizeY, edgeSizeZ;
    if (!usePeriodic)
        edgeSizeX = edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed
    else {
        edgeSizeX = periodicBoxSize[0]/std::max(1.0f, floorf(periodicBoxSize[0]/maxDistance));
        edgeSizeY = periodicBoxSize[1]/std::max(1.0f, floorf(periodicBoxSize[1]/maxDistance));
        edgeSizeZ = periodicBoxSize[2]/std::max(1.0f, floorf(periodicBoxSize[2]/maxDistance));
    }
    VoxelHash voxelHash(edgeSizeX, edgeSizeY, edgeSizeZ, periodicBoxSize, usePeriodic);
    for (int atomJ = 0; atomJ < nAtoms; ++atomJ) // use "j", because j > i for pairs
    {
        // The hash stores this pointer and dereferences it on later iterations, so
        // it must point into atomLocations (which outlives voxelHash). A loop-local
        // copy of the coordinates would leave the hash holding dangling pointers.
        const float* location = &atomLocations[4*atomJ];

        // 1) Find earlier atoms that are close to this one
        voxelHash.getNeighbors(
            neighbors,
            VoxelItem(location, atomJ),
            exclusions,
            reportSymmetricPairs,
            maxDistance,
            minDistance);

        // 2) Add this atom to the voxelHash
        voxelHash.insert(atomJ, location);
    }
}

// Returns the pair list built by the most recent computeNeighborList() call.
// The reference is to this object's own container, which computeNeighborList()
// clears and refills, so its contents change on the next call.
const vector<pair<int, int> >& CpuNeighborList::getNeighbors() {
    return this->neighbors;
}

} // namespace OpenMM