CpuNeighborList.cpp 7.14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#include "CpuNeighborList.h"
#include <set>
#include <map>
#include <cmath>

using namespace std;

namespace OpenMM {

// Returns val1-val2 reduced by the nearest integer multiple of `period`
// (minimum-image difference along one periodic axis). The result lies
// approximately within [-period/2, period/2].
static float periodicDifference(float val1, float val2, float period) {
    const float raw = val1 - val2;
    const float images = floorf(raw / period + 0.5f);
    return raw - images * period;
}

// Squared distance between two 3-D points (only pos[0..2] are read).
// With usePeriodic, each component is reduced with periodicDifference()
// against the matching entry of periodicBoxSize; otherwise the plain
// coordinate difference is used and periodicBoxSize is not read.
static float compPairDistanceSquared(const float* pos1, const float* pos2, const float* periodicBoxSize, bool usePeriodic) {
    float delta[3];
    for (int axis = 0; axis < 3; ++axis) {
        delta[axis] = usePeriodic
                ? periodicDifference(pos2[axis], pos1[axis], periodicBoxSize[axis])
                : pos2[axis] - pos1[axis];
    }
    return delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2];
}

// Integer (x, y, z) coordinates of one voxel in the hash grid.
class VoxelIndex {
public:
    VoxelIndex(int xx, int yy, int zz) : x(xx), y(yy), z(zz) {
    }

    // Lexicographic ordering (x, then y, then z); required so VoxelIndex can be
    // used as a std::map key.
    bool operator<(const VoxelIndex& other) const {
        if (x != other.x)
            return x < other.x;
        if (y != other.y)
            return y < other.y;
        return z < other.z;
    }

    int x;
    int y;
    int z;
};

// One voxel entry: a non-owning pointer to the atom's coordinates paired with the atom index.
typedef std::pair<const float*, int> VoxelItem;
// A voxel is the list of entries that hashed into it.
typedef std::vector<VoxelItem> Voxel;

56
class VoxelHash {
57
58
59
60
61
62
63
64
65
66
public:
    VoxelHash(float vsx, float vsy, float vsz, const float* periodicBoxSize, bool usePeriodic) :
            voxelSizeX(vsx), voxelSizeY(vsy), voxelSizeZ(vsz), periodicBoxSize(periodicBoxSize), usePeriodic(usePeriodic) {
        if (usePeriodic) {
            nx = (int) floorf(periodicBoxSize[0]/voxelSizeX+0.5f);
            ny = (int) floorf(periodicBoxSize[1]/voxelSizeY+0.5f);
            nz = (int) floorf(periodicBoxSize[2]/voxelSizeZ+0.5f);
        }
    }

67
    void insert(const int& item, const float* location) {
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
        VoxelIndex voxelIndex = getVoxelIndex(location);
        if (voxelMap.find(voxelIndex) == voxelMap.end()) voxelMap[voxelIndex] = Voxel(); 
        Voxel& voxel = voxelMap.find(voxelIndex)->second;
        voxel.push_back(VoxelItem(location, item));
    }


    VoxelIndex getVoxelIndex(const float* location) const {
        float xperiodic, yperiodic, zperiodic;
        if (!usePeriodic) {
            xperiodic = location[0];
            yperiodic = location[1];
            zperiodic = location[2];
        }
        else {
            xperiodic = location[0]-periodicBoxSize[0]*floorf(location[0]/periodicBoxSize[0]);
            yperiodic = location[1]-periodicBoxSize[1]*floorf(location[1]/periodicBoxSize[1]);
            zperiodic = location[2]-periodicBoxSize[2]*floorf(location[2]/periodicBoxSize[2]);
        }
        int x = int(floorf(xperiodic / voxelSizeX));
        int y = int(floorf(yperiodic / voxelSizeY));
        int z = int(floorf(zperiodic / voxelSizeZ));
        
        return VoxelIndex(x, y, z);
    }

94
    void getNeighbors(vector<pair<int, int> >& neighbors, const VoxelItem& referencePoint, const vector<set<int> >& exclusions, float maxDistance) const {
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115

        // Loop over neighboring voxels
        // TODO use more clever selection of neighboring voxels

        const int atomI = referencePoint.second;
        const float* locationI = referencePoint.first;
        
        float maxDistanceSquared = maxDistance * maxDistance;

        int dIndexX = int(maxDistance / voxelSizeX) + 1; // How may voxels away do we have to look?
        int dIndexY = int(maxDistance / voxelSizeY) + 1;
        int dIndexZ = int(maxDistance / voxelSizeZ) + 1;
        VoxelIndex centerVoxelIndex = getVoxelIndex(locationI);
        int lastx = centerVoxelIndex.x+dIndexX;
        int lasty = centerVoxelIndex.y+dIndexY;
        int lastz = centerVoxelIndex.z+dIndexZ;
        if (usePeriodic) {
            lastx = min(lastx, centerVoxelIndex.x-dIndexX+nx-1);
            lasty = min(lasty, centerVoxelIndex.y-dIndexY+ny-1);
            lastz = min(lastz, centerVoxelIndex.z-dIndexZ+nz-1);
        }
116
117
118
        for (int x = centerVoxelIndex.x - dIndexX; x <= lastx; ++x) {
            for (int y = centerVoxelIndex.y - dIndexY; y <= lasty; ++y) {
                for (int z = centerVoxelIndex.z - dIndexZ; z <= lastz; ++z) {
119
120
121
122
123
124
                    VoxelIndex voxelIndex(x, y, z);
                    if (usePeriodic) {
                        voxelIndex.x = (x+nx)%nx;
                        voxelIndex.y = (y+ny)%ny;
                        voxelIndex.z = (z+nz)%nz;
                    }
125
126
127
128
                    const map<VoxelIndex, Voxel>::const_iterator voxelEntry = voxelMap.find(voxelIndex);
                    if (voxelEntry == voxelMap.end()) continue; // no such voxel; skip
                    const Voxel& voxel = voxelEntry->second;
                    for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter) {
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
                        const int atomJ = itemIter->second;
                        const float* locationJ = itemIter->first;
                        
                        // Ignore self hits
                        if (atomI == atomJ) continue;
                        
                        // Ignore exclusions.
                        if (exclusions[atomI].find(atomJ) != exclusions[atomI].end()) continue;
                        
                        float dSquared = compPairDistanceSquared(locationI, locationJ, periodicBoxSize, usePeriodic);
                        if (dSquared > maxDistanceSquared) continue;
                        
                        neighbors.push_back(make_pair(atomI, atomJ));
                    }
                }
            }
        }
    }

private:
    float voxelSizeX, voxelSizeY, voxelSizeZ;
    int nx, ny, nz;
    const float* periodicBoxSize;
    const bool usePeriodic;
153
    map<VoxelIndex, Voxel> voxelMap;
154
155
156
157
};


// O(n) neighbor list method using voxel hash data structure
158
159
void CpuNeighborList::computeNeighborList(int nAtoms, const vector<float>& atomLocations, const vector<set<int> >& exclusions,
            const float* periodicBoxSize, bool usePeriodic, float maxDistance) {
160
161
162
163
164
165
    neighbors.clear();

    float edgeSizeX, edgeSizeY, edgeSizeZ;
    if (!usePeriodic)
        edgeSizeX = edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed
    else {
166
167
168
        edgeSizeX = 0.5f*periodicBoxSize[0]/floorf(periodicBoxSize[0]/maxDistance);
        edgeSizeY = 0.5f*periodicBoxSize[1]/floorf(periodicBoxSize[1]/maxDistance);
        edgeSizeZ = 0.5f*periodicBoxSize[2]/floorf(periodicBoxSize[2]/maxDistance);
169
170
    }
    VoxelHash voxelHash(edgeSizeX, edgeSizeY, edgeSizeZ, periodicBoxSize, usePeriodic);
171
    for (int atomJ = 0; atomJ < (int) nAtoms; ++atomJ) { // use "j", because j > i for pairs
172
        // 1) Find other atoms that are close to this one
173
        const float* location = &atomLocations[4*atomJ];
174
175
176
177
        voxelHash.getNeighbors(
            neighbors, 
            VoxelItem(location, atomJ),
            exclusions,
178
            maxDistance);
179
180
181
182
183
184
185
186
187
188
189
            
        // 2) Add this atom to the voxelHash
        voxelHash.insert(atomJ, location);
    }
}

const vector<pair<int, int> >& CpuNeighborList::getNeighbors() {
    return neighbors;
}

} // namespace OpenMM