CpuNeighborList.cpp 9.61 KB
Newer Older
1
2
3
4
#include "CpuNeighborList.h"
#include <smmintrin.h>
#include <algorithm>
#include <cmath>
#include <map>
#include <set>
#include <utility>
#include <vector>
6
7
8
9
10
11
12
13

using namespace std;

namespace OpenMM {

/**
 * Integer (x, y, z) coordinates identifying one voxel of the spatial hash.
 *
 * Instances are used as keys of a map, so a strict weak ordering is provided
 * through operator<().
 */
class VoxelIndex
{
public:
    /**
     * Create an index from its three integer components.
     */
    VoxelIndex(int xx, int yy, int zz) : x(xx), y(yy), z(zz) {
    }

    /**
     * Lexicographic ordering: compare x first, then y, then z.
     * Needed so VoxelIndex can be used as a map key.
     */
    bool operator<(const VoxelIndex& other) const {
        if (x != other.x)
            return x < other.x;
        if (y != other.y)
            return y < other.y;
        return z < other.z;
    }

    int x;
    int y;
    int z;
};

32
33
// One atom stored in the hash: a pointer to its position paired with its atom index.
typedef std::pair<const float*, int> VoxelItem;
// The contents of a single voxel: every atom that was binned into it.
typedef std::vector<VoxelItem> Voxel;
34

35
/**
 * Spatial hash that bins atoms into axis-aligned voxels so that a neighbor
 * search only has to look at atoms stored in nearby voxels.
 *
 * The hash stores pointers to the atom positions rather than copies, so the
 * position data (and the periodic box array) must stay valid for as long as
 * the hash is queried.
 */
class CpuNeighborList::VoxelHash {
public:
    /**
     * @param vsx              voxel edge length along x
     * @param vsy              voxel edge length along y
     * @param vsz              voxel edge length along z
     * @param periodicBoxSize  pointer to the three box edge lengths; only read when usePeriodic is true
     * @param usePeriodic      whether positions are wrapped into the periodic box
     */
    VoxelHash(float vsx, float vsy, float vsz, const float* periodicBoxSize, bool usePeriodic) :
            voxelSizeX(vsx), voxelSizeY(vsy), voxelSizeZ(vsz), nx(0), ny(0), nz(0),
            periodicBoxSize(periodicBoxSize), usePeriodic(usePeriodic) {
        if (usePeriodic) {
            // Round to the nearest whole number of voxels per box edge.  Keep at
            // least one voxel so the index wrapping below never divides by zero.
            nx = max(1, (int) floorf(periodicBoxSize[0]/voxelSizeX+0.5f));
            ny = max(1, (int) floorf(periodicBoxSize[1]/voxelSizeY+0.5f));
            nz = max(1, (int) floorf(periodicBoxSize[2]/voxelSizeZ+0.5f));
        }
    }

    /**
     * Add an atom to the voxel that contains it.
     *
     * @param item      the atom index
     * @param location  pointer to the atom position; the pointer itself is stored
     */
    void insert(const int& item, const float* location) {
        // operator[] creates the empty voxel on first use, so one lookup is enough.
        voxelMap[getVoxelIndex(location)].push_back(VoxelItem(location, item));
    }

    /**
     * Compute the index of the voxel containing a position.
     *
     * In periodic mode the position is first wrapped into the box and the
     * resulting index is clamped to [0, n-1] on every axis.
     */
    VoxelIndex getVoxelIndex(const float* location) const {
        if (!usePeriodic) {
            return VoxelIndex(int(floorf(location[0] / voxelSizeX)),
                              int(floorf(location[1] / voxelSizeY)),
                              int(floorf(location[2] / voxelSizeZ)));
        }
        return VoxelIndex(periodicVoxelCoordinate(location[0], periodicBoxSize[0], voxelSizeX, nx),
                          periodicVoxelCoordinate(location[1], periodicBoxSize[1], voxelSizeY, ny),
                          periodicVoxelCoordinate(location[2], periodicBoxSize[2], voxelSizeZ, nz));
    }

    /**
     * Append to `neighbors` every pair (atomI, atomJ) where atomI is the reference
     * atom, atomJ > atomI, atomJ is not in exclusions[atomI], and the squared
     * distance between them does not exceed maxDistance squared.
     *
     * Positions are read four floats at a time with _mm_loadu_ps, so each
     * position pointer must have four readable floats; only the first three
     * contribute to the distance.
     *
     * @param neighbors       output list; pairs are appended, existing content is kept
     * @param referencePoint  position pointer and index of the reference atom
     * @param exclusions      exclusions[i] is the set of atoms excluded from atom i
     * @param maxDistance     the cutoff distance
     */
    void getNeighbors(vector<pair<int, int> >& neighbors, const VoxelItem& referencePoint, const vector<set<int> >& exclusions, float maxDistance) const {

        // Loop over neighboring voxels
        // TODO use more clever selection of neighboring voxels

        const int atomI = referencePoint.second;
        const float* locationI = referencePoint.first;
        const __m128 posI = _mm_loadu_ps(locationI);

        // The box vectors are only needed (and only read) in periodic mode.
        __m128 boxSize = _mm_setzero_ps();
        __m128 invBoxSize = _mm_setzero_ps();
        if (usePeriodic) {
            boxSize = _mm_set_ps(0, periodicBoxSize[2], periodicBoxSize[1], periodicBoxSize[0]);
            invBoxSize = _mm_set_ps(0, (1/periodicBoxSize[2]), (1/periodicBoxSize[1]), (1/periodicBoxSize[0]));
        }
        const __m128 half = _mm_set1_ps(0.5f);
        const float maxDistanceSquared = maxDistance * maxDistance;

        // How many voxels away do we have to look along each axis?
        const int rangeX = int(maxDistance / voxelSizeX) + 1;
        const int rangeY = int(maxDistance / voxelSizeY) + 1;
        const int rangeZ = int(maxDistance / voxelSizeZ) + 1;
        const VoxelIndex center = getVoxelIndex(locationI);
        const int firstX = center.x - rangeX;
        const int firstY = center.y - rangeY;
        const int firstZ = center.z - rangeZ;
        int lastX = center.x + rangeX;
        int lastY = center.y + rangeY;
        int lastZ = center.z + rangeZ;
        if (usePeriodic) {
            // Never scan more than one full period, otherwise the same voxel
            // would be visited twice after wrapping.
            lastX = min(lastX, firstX+nx-1);
            lastY = min(lastY, firstY+ny-1);
            lastZ = min(lastZ, firstZ+nz-1);
        }

        VoxelIndex voxelIndex(0, 0, 0);
        for (int x = firstX; x <= lastX; ++x) {
            voxelIndex.x = (usePeriodic ? wrapIndex(x, nx) : x);
            for (int y = firstY; y <= lastY; ++y) {
                voxelIndex.y = (usePeriodic ? wrapIndex(y, ny) : y);
                for (int z = firstZ; z <= lastZ; ++z) {
                    voxelIndex.z = (usePeriodic ? wrapIndex(z, nz) : z);
                    const map<VoxelIndex, Voxel>::const_iterator voxelEntry = voxelMap.find(voxelIndex);
                    if (voxelEntry == voxelMap.end())
                        continue; // no such voxel; skip
                    const Voxel& voxel = voxelEntry->second;
                    for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter) {
                        const int atomJ = itemIter->second;

                        // Ignore self hits, and report each pair only once (from the lower index).
                        if (atomI >= atomJ)
                            continue;

                        // Ignore exclusions.
                        const set<int>& excluded = exclusions[atomI];
                        if (excluded.find(atomJ) != excluded.end())
                            continue;

                        __m128 delta = _mm_sub_ps(_mm_loadu_ps(itemIter->first), posI);
                        if (usePeriodic) {
                            // Shift each component by the nearest whole number of box lengths.
                            const __m128 base = _mm_mul_ps(_mm_floor_ps(_mm_add_ps(_mm_mul_ps(delta, invBoxSize), half)), boxSize);
                            delta = _mm_sub_ps(delta, base);
                        }
                        // Mask 0x71: multiply lanes 0-2 (x, y, z) and store the sum in lane 0.
                        const float dSquared = _mm_cvtss_f32(_mm_dp_ps(delta, delta, 0x71));
                        if (dSquared > maxDistanceSquared)
                            continue;

                        neighbors.push_back(make_pair(atomI, atomJ));
                    }
                }
            }
        }
    }

private:
    /**
     * Wrap one coordinate into [0, boxSize) and convert it to a voxel index.
     * Floating point rounding can leave the wrapped coordinate equal to boxSize
     * (for example for a tiny negative position), which would give an index of
     * numVoxels; the result is therefore clamped to [0, numVoxels-1].
     */
    static int periodicVoxelCoordinate(float position, float boxSize, float voxelSize, int numVoxels) {
        const float wrapped = position-boxSize*floorf(position/boxSize);
        const int index = int(floorf(wrapped / voxelSize));
        return max(0, min(numVoxels-1, index));
    }

    /**
     * Map any integer voxel coordinate into [0, count).  Unlike a single
     * add/subtract of count, this is correct however far outside the range the
     * coordinate is.  count must be positive.
     */
    static int wrapIndex(int index, int count) {
        const int wrapped = index % count;
        return (wrapped < 0 ? wrapped+count : wrapped);
    }

    float voxelSizeX, voxelSizeY, voxelSizeZ;
    int nx, ny, nz; // number of voxels per box edge; only meaningful in periodic mode
    const float* periodicBoxSize;
    const bool usePeriodic;
    map<VoxelIndex, Voxel> voxelMap;
};

151
152
153
154
155
156
157
158
/**
 * Per-thread bundle handed to each worker thread: which worker it is, the
 * neighbor list it works for, and the buffer its results are written to.
 */
class CpuNeighborList::ThreadData {
public:
    ThreadData(int threadIndex, CpuNeighborList& parent) : index(threadIndex), owner(parent) {
    }
    // Position of this worker among the owner's threads.
    int index;
    // The neighbor list whose runThread() this worker executes.
    CpuNeighborList& owner;
    // Pairs found by this worker during the most recent computation.
    vector<pair<int, int> > threadNeighbors;
};
159

160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
static void* threadBody(void* args) {
    CpuNeighborList::ThreadData& data = *reinterpret_cast<CpuNeighborList::ThreadData*>(args);
    data.owner.runThread(data.index, data.threadNeighbors);
    delete &data;
    return 0;
}

/**
 * Create the neighbor list and start its worker threads.
 *
 * The constructor does not return until every worker has checked in and is
 * blocked waiting on startCondition.  Without that, a broadcast sent before a
 * worker started waiting would be lost, leaving that worker idle for the
 * round (or never woken for shutdown).
 */
CpuNeighborList::CpuNeighborList() {
    isDeleted = false;
    numThreads = 4;
    pthread_cond_init(&startCondition, NULL);
    pthread_cond_init(&endCondition, NULL);
    pthread_mutex_init(&lock, NULL);
    thread.resize(numThreads);

    // Hold the lock while starting the workers: each worker increments
    // waitCount under this lock and then releases it atomically when it begins
    // waiting on startCondition, so once waitCount reaches numThreads every
    // worker is guaranteed to be waiting.
    pthread_mutex_lock(&lock);
    waitCount = 0;
    for (int i = 0; i < numThreads; i++) {
        ThreadData* data = new ThreadData(i, *this);
        threadData.push_back(data);
        // NOTE(review): the return value is not checked; a failed thread
        // creation would leave this constructor waiting forever below.
        pthread_create(&thread[i], NULL, threadBody, data);
    }
    while (waitCount < numThreads)
        pthread_cond_wait(&endCondition, &lock);
    pthread_mutex_unlock(&lock);
}

/**
 * Shut down the worker threads and release the synchronization primitives.
 */
CpuNeighborList::~CpuNeighborList() {
    // Raise the shutdown flag, then wake the workers so they can observe it.
    isDeleted = true;
    pthread_mutex_lock(&lock);
    pthread_cond_broadcast(&startCondition);
    pthread_mutex_unlock(&lock);

    // Wait for every worker to exit before tearing anything down.
    const int threadCount = (int) thread.size();
    int joined = 0;
    while (joined < threadCount) {
        pthread_join(thread[joined], NULL);
        ++joined;
    }

    pthread_mutex_destroy(&lock);
    pthread_cond_destroy(&startCondition);
    pthread_cond_destroy(&endCondition);
}
192

193
194
195
196
/**
 * Rebuild the neighbor list for the given atom positions.
 *
 * The atoms are binned into a voxel hash, the worker threads each search for
 * the neighbors of their share of the atoms, and the per-thread results are
 * concatenated into the list returned by getNeighbors().
 *
 * @param numAtoms         the number of atoms
 * @param atomLocations    atom positions, four floats per atom (atom i starts at element 4*i)
 * @param exclusions       exclusions[i] is the set of atoms excluded from atom i
 * @param periodicBoxSize  pointer to the three periodic box edge lengths; read when usePeriodic is true
 * @param usePeriodic      whether periodic boundary conditions apply
 * @param maxDistance      the cutoff distance
 */
void CpuNeighborList::computeNeighborList(int numAtoms, const vector<float>& atomLocations, const vector<set<int> >& exclusions,
            const float* periodicBoxSize, bool usePeriodic, float maxDistance) {
    // Build the voxel hash.

    float edgeSizeX, edgeSizeY, edgeSizeZ;
    if (!usePeriodic)
        edgeSizeX = edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed
    else {
        // Choose an edge length that divides each box edge into a whole number of voxels.
        edgeSizeX = 0.5f*periodicBoxSize[0]/floorf(periodicBoxSize[0]/maxDistance);
        edgeSizeY = 0.5f*periodicBoxSize[1]/floorf(periodicBoxSize[1]/maxDistance);
        edgeSizeZ = 0.5f*periodicBoxSize[2]/floorf(periodicBoxSize[2]/maxDistance);
    }
    VoxelHash voxelHash(edgeSizeX, edgeSizeY, edgeSizeZ, periodicBoxSize, usePeriodic);
    for (int i = 0; i < numAtoms; i++)
        voxelHash.insert(i, &atomLocations[4*i]);

    // Record the parameters for the threads.  The hash is a local object, so
    // the stored pointer is only valid until this function returns.

    this->voxelHash = &voxelHash;
    this->exclusions = &exclusions;
    // Indexing element 0 of an empty vector is undefined, so guard it.
    this->atomLocations = NULL;
    if (!atomLocations.empty())
        this->atomLocations = &atomLocations[0];
    this->periodicBoxSize = periodicBoxSize;
    this->numAtoms = numAtoms;
    this->usePeriodic = usePeriodic;
    this->maxDistance = maxDistance;

    // Signal the threads to start running and wait for them to finish.

    pthread_mutex_lock(&lock);
    waitCount = 0;
    pthread_cond_broadcast(&startCondition);
    while (waitCount < numThreads)
        pthread_cond_wait(&endCondition, &lock);
    pthread_mutex_unlock(&lock);

    // Combine the results from all the threads, sizing the output once up front.

    size_t totalPairs = 0;
    for (int i = 0; i < numThreads; i++)
        totalPairs += threadData[i]->threadNeighbors.size();
    neighbors.clear();
    neighbors.reserve(totalPairs);
    for (int i = 0; i < numThreads; i++)
        neighbors.insert(neighbors.end(), threadData[i]->threadNeighbors.begin(), threadData[i]->threadNeighbors.end());
}

const vector<pair<int, int> >& CpuNeighborList::getNeighbors() {
    return neighbors;
}

239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
void CpuNeighborList::runThread(int index, vector<pair<int, int> >& threadNeighbors) {
    while (true) {
        // Wait for the signal to start running.
        
        pthread_mutex_lock(&lock);
        waitCount++;
        pthread_cond_signal(&endCondition);
        pthread_cond_wait(&startCondition, &lock);
        pthread_mutex_unlock(&lock);
        if (isDeleted)
            break;
        
        // Compute this thread's subset of neighbors.
        
        threadNeighbors.clear();
        for (int i = index; i < numAtoms; i += numThreads)
            voxelHash->getNeighbors(threadNeighbors, VoxelItem(&atomLocations[4*i], i), *exclusions, maxDistance);
    }
}

259
} // namespace OpenMM