// CpuNeighborList.cpp
#include "CpuNeighborList.h"
#include "openmm/internal/hardware.h"
#include <algorithm>
#include <cmath>
#include <map>
#include <set>
#include <utility>
#include <vector>
#include <smmintrin.h>
7
8
9
10
11
12
13
14

using namespace std;

namespace OpenMM {

/**
 * Integer coordinates (x, y, z) of one voxel in the spatial grid.
 *
 * Instances are used as keys of a std::map, so the class provides a strict
 * weak ordering through operator<().
 */
class VoxelIndex
{
public:
    /**
     * Create an index from its three integer grid coordinates.
     */
    VoxelIndex(int xx, int yy, int zz) : x(xx), y(yy), z(zz) {
    }

    /**
     * Lexicographic comparison on (x, y, z); needed so VoxelIndex can be a map key.
     *
     * @param other  the index to compare against
     * @return true if this index sorts strictly before other
     */
    bool operator<(const VoxelIndex& other) const {
        if (x != other.x)
            return x < other.x;
        if (y != other.y)
            return y < other.y;
        return z < other.z;
    }

    int x;
    int y;
    int z;
};

33
34
// One entry stored in a voxel: (pointer to the atom's coordinates, atom index).
typedef std::pair<const float*, int> VoxelItem;
// All of the entries that fall inside a single voxel.
typedef std::vector<VoxelItem> Voxel;
35

36
/**
 * Spatial hash that bins atoms into axis-aligned voxels so that neighbor
 * searches only have to examine atoms in nearby voxels.
 *
 * The hash stores pointers to atom coordinates rather than copies, and keeps
 * the periodicBoxSize pointer it was constructed with; both must remain valid
 * for as long as the hash is used.
 */
class CpuNeighborList::VoxelHash {
public:
    /**
     * Create an empty hash.
     *
     * @param vsx, vsy, vsz    edge lengths of a voxel along each axis
     * @param periodicBoxSize  the three periodic box lengths; only read when usePeriodic is true
     * @param usePeriodic      whether positions and voxel indices wrap around the periodic box
     */
    VoxelHash(float vsx, float vsy, float vsz, const float* periodicBoxSize, bool usePeriodic) :
            voxelSizeX(vsx), voxelSizeY(vsy), voxelSizeZ(vsz), nx(0), ny(0), nz(0),
            periodicBoxSize(periodicBoxSize), usePeriodic(usePeriodic) {
        if (usePeriodic) {
            // Number of voxels spanning the box along each axis, rounded to the nearest
            // integer.  Keep it at least 1 so index wrapping never divides by zero.
            nx = max(1, (int) floorf(periodicBoxSize[0]/voxelSizeX+0.5f));
            ny = max(1, (int) floorf(periodicBoxSize[1]/voxelSizeY+0.5f));
            nz = max(1, (int) floorf(periodicBoxSize[2]/voxelSizeZ+0.5f));
        }
    }

    /**
     * Add an atom to the voxel that contains it.
     *
     * @param item      the atom index
     * @param location  pointer to the atom's x, y, z coordinates; the pointer itself is stored
     */
    void insert(const int& item, const float* location) {
        // operator[] creates the voxel on first use, so a single lookup is enough.
        voxelMap[getVoxelIndex(location)].push_back(VoxelItem(location, item));
    }

    /**
     * Compute the index of the voxel containing a position.
     *
     * With periodic boundaries the position is first wrapped into the box and the
     * resulting index is always within [0, nx) x [0, ny) x [0, nz).
     *
     * @param location  pointer to x, y, z coordinates
     */
    VoxelIndex getVoxelIndex(const float* location) const {
        if (!usePeriodic)
            return VoxelIndex(int(floorf(location[0] / voxelSizeX)),
                              int(floorf(location[1] / voxelSizeY)),
                              int(floorf(location[2] / voxelSizeZ)));
        return VoxelIndex(periodicCell(location[0], periodicBoxSize[0], voxelSizeX, nx),
                          periodicCell(location[1], periodicBoxSize[1], voxelSizeY, ny),
                          periodicCell(location[2], periodicBoxSize[2], voxelSizeZ, nz));
    }

    /**
     * Append to neighbors every pair (atomI, atomJ) for which atomJ is stored in the
     * hash, atomJ > atomI, atomJ is not in exclusions[atomI], and the (minimum image,
     * if periodic) distance between the two atoms is at most maxDistance.
     *
     * @param neighbors       output list; pairs are appended, existing content is kept
     * @param referencePoint  (coordinates, index) of atomI; the coordinates are read with a
     *                        four-float load, so four floats must be readable
     * @param exclusions      exclusions[i] is the set of atoms excluded from interacting with atom i
     * @param maxDistance     the cutoff distance
     */
    void getNeighbors(vector<pair<int, int> >& neighbors, const VoxelItem& referencePoint, const vector<set<int> >& exclusions, float maxDistance) const {
        const int atomI = referencePoint.second;
        const float* locationI = referencePoint.first;
        const set<int>& excludedI = exclusions[atomI];
        const __m128 posI = _mm_loadu_ps(locationI);
        const __m128 half = _mm_set1_ps(0.5f);

        // The box vectors are only needed (and only guaranteed to be valid) for
        // periodic systems, so do not touch periodicBoxSize otherwise.
        __m128 boxSize = _mm_setzero_ps();
        __m128 invBoxSize = _mm_setzero_ps();
        if (usePeriodic) {
            boxSize = _mm_set_ps(0, periodicBoxSize[2], periodicBoxSize[1], periodicBoxSize[0]);
            invBoxSize = _mm_set_ps(0, (1/periodicBoxSize[2]), (1/periodicBoxSize[1]), (1/periodicBoxSize[0]));
        }
        const float maxDistanceSquared = maxDistance * maxDistance;

        // Loop over neighboring voxels
        // TODO use more clever selection of neighboring voxels

        const int dIndexX = int(maxDistance / voxelSizeX) + 1; // How many voxels away do we have to look?
        const int dIndexY = int(maxDistance / voxelSizeY) + 1;
        const int dIndexZ = int(maxDistance / voxelSizeZ) + 1;
        const VoxelIndex centerVoxelIndex = getVoxelIndex(locationI);
        const int firstx = centerVoxelIndex.x-dIndexX;
        const int firsty = centerVoxelIndex.y-dIndexY;
        const int firstz = centerVoxelIndex.z-dIndexZ;
        int lastx = centerVoxelIndex.x+dIndexX;
        int lasty = centerVoxelIndex.y+dIndexY;
        int lastz = centerVoxelIndex.z+dIndexZ;
        if (usePeriodic) {
            // Never scan more than one full period, so no voxel is visited twice.
            lastx = min(lastx, firstx+nx-1);
            lasty = min(lasty, firsty+ny-1);
            lastz = min(lastz, firstz+nz-1);
        }
        VoxelIndex voxelIndex(0, 0, 0);
        for (int x = firstx; x <= lastx; ++x) {
            voxelIndex.x = (usePeriodic ? wrapIndex(x, nx) : x);
            for (int y = firsty; y <= lasty; ++y) {
                voxelIndex.y = (usePeriodic ? wrapIndex(y, ny) : y);
                for (int z = firstz; z <= lastz; ++z) {
                    voxelIndex.z = (usePeriodic ? wrapIndex(z, nz) : z);
                    const map<VoxelIndex, Voxel>::const_iterator voxelEntry = voxelMap.find(voxelIndex);
                    if (voxelEntry == voxelMap.end())
                        continue; // no such voxel; skip
                    const Voxel& voxel = voxelEntry->second;
                    for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter) {
                        const int atomJ = itemIter->second;

                        // Ignore self hits, and report each pair only once (atomI < atomJ).
                        if (atomI >= atomJ)
                            continue;

                        // Ignore exclusions.
                        if (excludedI.find(atomJ) != excludedI.end())
                            continue;

                        __m128 delta = _mm_sub_ps(_mm_loadu_ps(itemIter->first), posI);
                        if (usePeriodic) {
                            // Minimum image convention: subtract the nearest whole number of box lengths.
                            __m128 base = _mm_mul_ps(_mm_floor_ps(_mm_add_ps(_mm_mul_ps(delta, invBoxSize), half)), boxSize);
                            delta = _mm_sub_ps(delta, base);
                        }
                        // Mask 0x71: dot product over the x, y, z lanes only, result in lane 0.
                        const float dSquared = _mm_cvtss_f32(_mm_dp_ps(delta, delta, 0x71));
                        if (dSquared > maxDistanceSquared)
                            continue;

                        neighbors.push_back(make_pair(atomI, atomJ));
                    }
                }
            }
        }
    }

private:
    /**
     * Wrap one coordinate into [0, boxLength) and convert it to a voxel index.
     * The result is clamped to [0, numVoxels-1] because floating point rounding can
     * otherwise yield an index just outside the grid, which the periodic search
     * in getNeighbors() would never visit.
     */
    static int periodicCell(float coord, float boxLength, float voxelSize, int numVoxels) {
        const float wrapped = coord-boxLength*floorf(coord/boxLength);
        const int cell = int(floorf(wrapped / voxelSize));
        return max(0, min(numVoxels-1, cell));
    }

    /**
     * Map an arbitrary integer onto [0, n) with periodic wrapping.  Requires n > 0.
     */
    static int wrapIndex(int i, int n) {
        const int r = i % n;
        return (r < 0 ? r+n : r);
    }

    float voxelSizeX, voxelSizeY, voxelSizeZ;
    int nx, ny, nz;
    const float* periodicBoxSize;
    const bool usePeriodic;
    map<VoxelIndex, Voxel> voxelMap;
};

152
153
154
155
156
157
158
159
/**
 * Per-thread record handed to each worker: which worker it is, the neighbor
 * list it belongs to, and the buffer that receives the pairs it finds.
 */
class CpuNeighborList::ThreadData {
public:
    // Position of this worker among the owner's threads.
    int index;
    // The neighbor list whose runThread() this worker executes.
    CpuNeighborList& owner;
    // Pairs found by this worker.
    vector<pair<int, int> > threadNeighbors;

    ThreadData(int threadIndex, CpuNeighborList& list) : index(threadIndex), owner(list) {
    }
};
160

161
162
163
164
165
166
167
168
169
static void* threadBody(void* args) {
    CpuNeighborList::ThreadData& data = *reinterpret_cast<CpuNeighborList::ThreadData*>(args);
    data.owner.runThread(data.index, data.threadNeighbors);
    delete &data;
    return 0;
}

/**
 * Create the neighbor list and start one worker thread per processor reported
 * by getNumProcessors().
 *
 * The constructor does not return until every worker is parked waiting on
 * startCondition.  A worker increments waitCount and begins waiting without
 * releasing the lock in between, so once waitCount reaches numThreads all
 * workers are guaranteed to see the next broadcast.  Without this rendezvous, a
 * computeNeighborList() or destructor call made right after construction could
 * broadcast before a worker started waiting; that worker would miss the
 * wake-up, yielding an empty result or a hang in pthread_join().
 */
CpuNeighborList::CpuNeighborList() {
    isDeleted = false;
    waitCount = 0; // workers increment this as soon as they start, so it must be initialized first
    numThreads = getNumProcessors();
    pthread_cond_init(&startCondition, NULL);
    pthread_cond_init(&endCondition, NULL);
    pthread_mutex_init(&lock, NULL);
    thread.resize(numThreads);
    for (int i = 0; i < numThreads; i++) {
        // Each ThreadData is freed by threadBody() when its worker exits.
        ThreadData* data = new ThreadData(i, *this);
        threadData.push_back(data);
        pthread_create(&thread[i], NULL, threadBody, data);
    }

    // Wait until all workers have reached their wait on startCondition.

    pthread_mutex_lock(&lock);
    while (waitCount < numThreads)
        pthread_cond_wait(&endCondition, &lock);
    pthread_mutex_unlock(&lock);
}

/**
 * Shut down the worker threads and release the synchronization objects.
 */
CpuNeighborList::~CpuNeighborList() {
    // Raise the shutdown flag, then wake the workers so they can observe it and exit.
    isDeleted = true;
    pthread_mutex_lock(&lock);
    pthread_cond_broadcast(&startCondition);
    pthread_mutex_unlock(&lock);

    // Wait for every worker to finish before tearing anything down.
    int remaining = 0;
    while (remaining < (int) thread.size()) {
        pthread_join(thread[remaining], NULL);
        ++remaining;
    }

    pthread_mutex_destroy(&lock);
    pthread_cond_destroy(&startCondition);
    pthread_cond_destroy(&endCondition);
}
193

194
195
196
197
/**
 * Rebuild the neighbor list for the given atom positions.  The result is
 * available from getNeighbors() once this method returns.
 *
 * @param numAtoms         number of atoms
 * @param atomLocations    atom coordinates, four floats per atom (x, y, z and one padding value)
 * @param exclusions       exclusions[i] is the set of atoms that must not be paired with atom i
 * @param periodicBoxSize  the three periodic box lengths; only read here when usePeriodic is true
 * @param usePeriodic      whether to apply periodic boundary conditions
 * @param maxDistance      the cutoff distance
 */
void CpuNeighborList::computeNeighborList(int numAtoms, const vector<float>& atomLocations, const vector<set<int> >& exclusions,
            const float* periodicBoxSize, bool usePeriodic, float maxDistance) {
    // Build the voxel hash.

    float edgeSizeX, edgeSizeY, edgeSizeZ;
    if (!usePeriodic)
        edgeSizeX = edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed
    else {
        // Choose voxels about half the cutoff wide that divide the box evenly.  Use at
        // least one cutoff length per box edge so that a box shorter than the cutoff
        // does not cause a division by zero.
        edgeSizeX = 0.5f*periodicBoxSize[0]/max(1.0f, floorf(periodicBoxSize[0]/maxDistance));
        edgeSizeY = 0.5f*periodicBoxSize[1]/max(1.0f, floorf(periodicBoxSize[1]/maxDistance));
        edgeSizeZ = 0.5f*periodicBoxSize[2]/max(1.0f, floorf(periodicBoxSize[2]/maxDistance));
    }
    VoxelHash voxelHash(edgeSizeX, edgeSizeY, edgeSizeZ, periodicBoxSize, usePeriodic);
    // Indexing an empty vector is undefined, so only take the data pointer when there is data.
    const float* positions = (atomLocations.empty() ? NULL : &atomLocations[0]);
    for (int i = 0; i < numAtoms; i++)
        voxelHash.insert(i, positions+4*i);

    // Record the parameters for the threads.  The hash lives on this stack frame,
    // which is safe because this method does not return until every worker is done.

    this->voxelHash = &voxelHash;
    this->exclusions = &exclusions;
    this->atomLocations = positions;
    this->periodicBoxSize = periodicBoxSize;
    this->numAtoms = numAtoms;
    this->usePeriodic = usePeriodic;
    this->maxDistance = maxDistance;

    // Signal the threads to start running and wait for them to finish.

    pthread_mutex_lock(&lock);
    waitCount = 0;
    pthread_cond_broadcast(&startCondition);
    while (waitCount < numThreads)
        pthread_cond_wait(&endCondition, &lock);
    pthread_mutex_unlock(&lock);

    // Combine the results from all the threads.

    vector<pair<int, int> >::size_type totalPairs = 0;
    for (int i = 0; i < numThreads; i++)
        totalPairs += threadData[i]->threadNeighbors.size();
    neighbors.clear();
    neighbors.reserve(totalPairs);
    for (int i = 0; i < numThreads; i++)
        neighbors.insert(neighbors.end(), threadData[i]->threadNeighbors.begin(), threadData[i]->threadNeighbors.end());
}

const vector<pair<int, int> >& CpuNeighborList::getNeighbors() {
    return neighbors;
}

240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
void CpuNeighborList::runThread(int index, vector<pair<int, int> >& threadNeighbors) {
    while (true) {
        // Wait for the signal to start running.
        
        pthread_mutex_lock(&lock);
        waitCount++;
        pthread_cond_signal(&endCondition);
        pthread_cond_wait(&startCondition, &lock);
        pthread_mutex_unlock(&lock);
        if (isDeleted)
            break;
        
        // Compute this thread's subset of neighbors.
        
        threadNeighbors.clear();
        for (int i = index; i < numAtoms; i += numThreads)
            voxelHash->getNeighbors(threadNeighbors, VoxelItem(&atomLocations[4*i], i), *exclusions, maxDistance);
    }
}

260
} // namespace OpenMM