Commit 6ad86974 authored by peastman
Browse files

Minor optimizations to neighbor list construction

parent 3dea0a62
...@@ -27,6 +27,14 @@ public: ...@@ -27,6 +27,14 @@ public:
*/ */
void runThread(int index); void runThread(int index);
private: private:
/**
* This is called by the worker threads to wait until the master thread instructs them to advance.
*/
void threadWait();
/**
* This is called by the master thread to instruct all the worker threads to advance.
*/
void advanceThreads();
bool isDeleted; bool isDeleted;
int numThreads, waitCount; int numThreads, waitCount;
std::vector<int> sortedAtoms; std::vector<int> sortedAtoms;
...@@ -37,6 +45,8 @@ private: ...@@ -37,6 +45,8 @@ private:
pthread_cond_t startCondition, endCondition; pthread_cond_t startCondition, endCondition;
pthread_mutex_t lock; pthread_mutex_t lock;
// The following variables are used to make information accessible to the individual threads. // The following variables are used to make information accessible to the individual threads.
float minx, maxx, miny, maxy, minz, maxz;
std::vector<std::pair<int, int> > atomBins;
Voxels* voxels; Voxels* voxels;
const std::vector<std::set<int> >* exclusions; const std::vector<std::set<int> >* exclusions;
const float* atomLocations; const float* atomLocations;
......
...@@ -310,38 +310,36 @@ void CpuNeighborList::computeNeighborList(int numAtoms, const vector<float>& ato ...@@ -310,38 +310,36 @@ void CpuNeighborList::computeNeighborList(int numAtoms, const vector<float>& ato
blockExclusions.resize(numBlocks); blockExclusions.resize(numBlocks);
sortedAtoms.resize(numAtoms); sortedAtoms.resize(numAtoms);
// Sort the atoms based on a Hilbert curve. // Record the parameters for the threads.
float minx = atomLocations[0], maxx = atomLocations[0]; this->exclusions = &exclusions;
float miny = atomLocations[1], maxy = atomLocations[1]; this->atomLocations = &atomLocations[0];
float minz = atomLocations[2], maxz = atomLocations[2]; this->periodicBoxSize = periodicBoxSize;
for (int i = 0; i < numAtoms; i++) { this->numAtoms = numAtoms;
const float* pos = &atomLocations[4*i]; this->usePeriodic = usePeriodic;
if (pos[0] < minx) this->maxDistance = maxDistance;
minx = pos[0];
if (pos[1] < miny) // Identify the range of atom positions along each axis.
miny = pos[1];
if (pos[2] < minz) fvec4 minPos(&atomLocations[0]);
minz = pos[2]; fvec4 maxPos = minPos;
if (pos[0] > maxx)
maxx = pos[0];
if (pos[1] > maxy)
maxy = pos[1];
if (pos[2] > maxz)
maxz = pos[2];
}
float binWidth = max(max(maxx-minx, maxy-miny), maxz-minz)/255.0f;
float invBinWidth = 1.0f/binWidth;
vector<pair<int, int> > atomBins(numAtoms);
bitmask_t coords[3];
for (int i = 0; i < numAtoms; i++) { for (int i = 0; i < numAtoms; i++) {
const float* pos = &atomLocations[4*i]; fvec4 pos(&atomLocations[4*i]);
coords[0] = (bitmask_t) ((pos[0]-minx)*invBinWidth); minPos = min(minPos, pos);
coords[1] = (bitmask_t) ((pos[1]-miny)*invBinWidth); maxPos = max(maxPos, pos);
coords[2] = (bitmask_t) ((pos[2]-minz)*invBinWidth);
int bin = (int) hilbert_c2i(3, 8, coords);
atomBins[i] = pair<int, int>(bin, i);
} }
minx = minPos[0];
maxx = maxPos[0];
miny = minPos[1];
maxy = maxPos[1];
minz = minPos[2];
maxz = maxPos[2];
// Sort the atoms based on a Hilbert curve.
atomBins.resize(numAtoms);
pthread_mutex_lock(&lock);
advanceThreads();
sort(atomBins.begin(), atomBins.end()); sort(atomBins.begin(), atomBins.end());
// Build the voxel hash. // Build the voxel hash.
...@@ -360,24 +358,11 @@ void CpuNeighborList::computeNeighborList(int numAtoms, const vector<float>& ato ...@@ -360,24 +358,11 @@ void CpuNeighborList::computeNeighborList(int numAtoms, const vector<float>& ato
voxels.insert(i, &atomLocations[4*atomIndex]); voxels.insert(i, &atomLocations[4*atomIndex]);
} }
voxels.sortItems(); voxels.sortItems();
// Record the parameters for the threads.
this->voxels = &voxels; this->voxels = &voxels;
this->exclusions = &exclusions;
this->atomLocations = &atomLocations[0];
this->periodicBoxSize = periodicBoxSize;
this->numAtoms = numAtoms;
this->usePeriodic = usePeriodic;
this->maxDistance = maxDistance;
// Signal the threads to start running and wait for them to finish. // Signal the threads to start running and wait for them to finish.
pthread_mutex_lock(&lock); advanceThreads();
waitCount = 0;
pthread_cond_broadcast(&startCondition);
while (waitCount < numThreads)
pthread_cond_wait(&endCondition, &lock);
pthread_mutex_unlock(&lock); pthread_mutex_unlock(&lock);
// Add padding atoms to fill up the last block. // Add padding atoms to fill up the last block.
...@@ -414,14 +399,25 @@ void CpuNeighborList::runThread(int index) { ...@@ -414,14 +399,25 @@ void CpuNeighborList::runThread(int index) {
while (true) { while (true) {
// Wait for the signal to start running. // Wait for the signal to start running.
pthread_mutex_lock(&lock); threadWait();
waitCount++;
pthread_cond_signal(&endCondition);
pthread_cond_wait(&startCondition, &lock);
pthread_mutex_unlock(&lock);
if (isDeleted) if (isDeleted)
break; break;
// Compute the positions of atoms along the Hilbert curve.
float binWidth = max(max(maxx-minx, maxy-miny), maxz-minz)/255.0f;
float invBinWidth = 1.0f/binWidth;
bitmask_t coords[3];
for (int i = index; i < numAtoms; i += numThreads) {
const float* pos = &atomLocations[4*i];
coords[0] = (bitmask_t) ((pos[0]-minx)*invBinWidth);
coords[1] = (bitmask_t) ((pos[1]-miny)*invBinWidth);
coords[2] = (bitmask_t) ((pos[2]-minz)*invBinWidth);
int bin = (int) hilbert_c2i(3, 8, coords);
atomBins[i] = pair<int, int>(bin, i);
}
threadWait();
// Compute this thread's subset of neighbors. // Compute this thread's subset of neighbors.
int numBlocks = blockNeighbors.size(); int numBlocks = blockNeighbors.size();
...@@ -462,4 +458,20 @@ void CpuNeighborList::runThread(int index) { ...@@ -462,4 +458,20 @@ void CpuNeighborList::runThread(int index) {
} }
} }
/**
 * Called by a worker thread to park itself until the master thread tells the
 * workers to advance.
 *
 * While holding `lock`, the thread increments waitCount, signals endCondition,
 * and then waits on startCondition. The mutex is released again before returning.
 *
 * NOTE(review): the wait on startCondition is not wrapped in a predicate loop.
 * POSIX allows pthread_cond_wait() to return spuriously, in which case this
 * function would return before the master broadcast. Guarding against that
 * would need extra shared state (e.g. a generation counter) -- confirm whether
 * this matters on the supported platforms.
 */
void CpuNeighborList::threadWait() {
pthread_mutex_lock(&lock);
// Record that this thread has reached the barrier and wake whoever is waiting on endCondition.
waitCount++;
pthread_cond_signal(&endCondition);
// Atomically releases `lock` while sleeping; the mutex is re-acquired before this call returns.
pthread_cond_wait(&startCondition, &lock);
pthread_mutex_unlock(&lock);
}
/**
 * Called by the master thread to release every worker thread and then block
 * until waitCount has climbed back up to numThreads.
 *
 * This function does not lock or unlock `lock` itself, yet it passes `lock` to
 * pthread_cond_wait(), so the caller must already hold the mutex when calling it.
 */
void CpuNeighborList::advanceThreads() {
// Reset the counter before waking the workers so that only increments made after this point are counted.
waitCount = 0;
pthread_cond_broadcast(&startCondition);
// Re-check the counter after every wakeup; each wait releases `lock` while sleeping
// and re-acquires it before returning.
while (waitCount < numThreads) {
pthread_cond_wait(&endCondition, &lock);
}
}
} // namespace OpenMM } // namespace OpenMM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment