Commit d4d9ce1f authored by peastman's avatar peastman
Browse files

Optimizations to neighbor list construction

parent d4a343e2
...@@ -15,9 +15,7 @@ public: ...@@ -15,9 +15,7 @@ public:
const std::vector<std::set<int> >& exclusions, const std::vector<std::set<int> >& exclusions,
const float* periodicBoxSize, const float* periodicBoxSize,
bool usePeriodic, bool usePeriodic,
float maxDistance, float maxDistance);
float minDistance = 0.0f,
bool reportSymmetricPairs = false);
const std::vector<std::pair<int, int> >& getNeighbors(); const std::vector<std::pair<int, int> >& getNeighbors();
private: private:
std::vector<std::pair<int, int> > neighbors; std::vector<std::pair<int, int> > neighbors;
......
...@@ -155,7 +155,7 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -155,7 +155,7 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
posq[4*i+1] = (float) posData[i][1]; posq[4*i+1] = (float) posData[i][1];
posq[4*i+2] = (float) posData[i][2]; posq[4*i+2] = (float) posData[i][2];
} }
neighborList.computeNeighborList(numParticles, posq, exclusions, floatBoxSize, periodic || ewald || pme, nonbondedCutoff, 0.0); neighborList.computeNeighborList(numParticles, posq, exclusions, floatBoxSize, periodic || ewald || pme, nonbondedCutoff);
// if (nonbondedMethod != NoCutoff) { // if (nonbondedMethod != NoCutoff) {
// computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, extractBoxSize(context), periodic || ewald || pme, nonbondedCutoff, 0.0); // computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, extractBoxSize(context), periodic || ewald || pme, nonbondedCutoff, 0.0);
// clj.setUseCutoff(nonbondedCutoff, *neighborList, rfDielectric); // clj.setUseCutoff(nonbondedCutoff, *neighborList, rfDielectric);
......
...@@ -32,7 +32,8 @@ static float compPairDistanceSquared(const float* pos1, const float* pos2, const ...@@ -32,7 +32,8 @@ static float compPairDistanceSquared(const float* pos1, const float* pos2, const
class VoxelIndex class VoxelIndex
{ {
public: public:
VoxelIndex(int xx, int yy, int zz) : x(xx), y(yy), z(zz) {} VoxelIndex(int xx, int yy, int zz) : x(xx), y(yy), z(zz) {
}
// operator<() needed for map // operator<() needed for map
bool operator<(const VoxelIndex& other) const { bool operator<(const VoxelIndex& other) const {
...@@ -49,11 +50,10 @@ public: ...@@ -49,11 +50,10 @@ public:
int z; int z;
}; };
typedef std::pair<const float*, int> VoxelItem; typedef pair<const float*, int> VoxelItem;
typedef std::vector< VoxelItem > Voxel; typedef vector< VoxelItem > Voxel;
class VoxelHash class VoxelHash {
{
public: public:
VoxelHash(float vsx, float vsy, float vsz, const float* periodicBoxSize, bool usePeriodic) : VoxelHash(float vsx, float vsy, float vsz, const float* periodicBoxSize, bool usePeriodic) :
voxelSizeX(vsx), voxelSizeY(vsy), voxelSizeZ(vsz), periodicBoxSize(periodicBoxSize), usePeriodic(usePeriodic) { voxelSizeX(vsx), voxelSizeY(vsy), voxelSizeZ(vsz), periodicBoxSize(periodicBoxSize), usePeriodic(usePeriodic) {
...@@ -64,8 +64,7 @@ public: ...@@ -64,8 +64,7 @@ public:
} }
} }
void insert(const int& item, const float* location) void insert(const int& item, const float* location) {
{
VoxelIndex voxelIndex = getVoxelIndex(location); VoxelIndex voxelIndex = getVoxelIndex(location);
if (voxelMap.find(voxelIndex) == voxelMap.end()) voxelMap[voxelIndex] = Voxel(); if (voxelMap.find(voxelIndex) == voxelMap.end()) voxelMap[voxelIndex] = Voxel();
Voxel& voxel = voxelMap.find(voxelIndex)->second; Voxel& voxel = voxelMap.find(voxelIndex)->second;
...@@ -92,14 +91,7 @@ public: ...@@ -92,14 +91,7 @@ public:
return VoxelIndex(x, y, z); return VoxelIndex(x, y, z);
} }
void getNeighbors( void getNeighbors(vector<pair<int, int> >& neighbors, const VoxelItem& referencePoint, const vector<set<int> >& exclusions, float maxDistance) const {
vector<pair<int, int> >& neighbors,
const VoxelItem& referencePoint,
const vector<set<int> >& exclusions,
bool reportSymmetricPairs,
float maxDistance,
float minDistance) const
{
// Loop over neighboring voxels // Loop over neighboring voxels
// TODO use more clever selection of neighboring voxels // TODO use more clever selection of neighboring voxels
...@@ -108,7 +100,6 @@ public: ...@@ -108,7 +100,6 @@ public:
const float* locationI = referencePoint.first; const float* locationI = referencePoint.first;
float maxDistanceSquared = maxDistance * maxDistance; float maxDistanceSquared = maxDistance * maxDistance;
float minDistanceSquared = minDistance * minDistance;
int dIndexX = int(maxDistance / voxelSizeX) + 1; // How may voxels away do we have to look? int dIndexX = int(maxDistance / voxelSizeX) + 1; // How may voxels away do we have to look?
int dIndexY = int(maxDistance / voxelSizeY) + 1; int dIndexY = int(maxDistance / voxelSizeY) + 1;
...@@ -122,22 +113,19 @@ public: ...@@ -122,22 +113,19 @@ public:
lasty = min(lasty, centerVoxelIndex.y-dIndexY+ny-1); lasty = min(lasty, centerVoxelIndex.y-dIndexY+ny-1);
lastz = min(lastz, centerVoxelIndex.z-dIndexZ+nz-1); lastz = min(lastz, centerVoxelIndex.z-dIndexZ+nz-1);
} }
for (int x = centerVoxelIndex.x - dIndexX; x <= lastx; ++x) for (int x = centerVoxelIndex.x - dIndexX; x <= lastx; ++x) {
{ for (int y = centerVoxelIndex.y - dIndexY; y <= lasty; ++y) {
for (int y = centerVoxelIndex.y - dIndexY; y <= lasty; ++y) for (int z = centerVoxelIndex.z - dIndexZ; z <= lastz; ++z) {
{
for (int z = centerVoxelIndex.z - dIndexZ; z <= lastz; ++z)
{
VoxelIndex voxelIndex(x, y, z); VoxelIndex voxelIndex(x, y, z);
if (usePeriodic) { if (usePeriodic) {
voxelIndex.x = (x+nx)%nx; voxelIndex.x = (x+nx)%nx;
voxelIndex.y = (y+ny)%ny; voxelIndex.y = (y+ny)%ny;
voxelIndex.z = (z+nz)%nz; voxelIndex.z = (z+nz)%nz;
} }
if (voxelMap.find(voxelIndex) == voxelMap.end()) continue; // no such voxel; skip const map<VoxelIndex, Voxel>::const_iterator voxelEntry = voxelMap.find(voxelIndex);
const Voxel& voxel = voxelMap.find(voxelIndex)->second; if (voxelEntry == voxelMap.end()) continue; // no such voxel; skip
for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter) const Voxel& voxel = voxelEntry->second;
{ for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter) {
const int atomJ = itemIter->second; const int atomJ = itemIter->second;
const float* locationJ = itemIter->first; const float* locationJ = itemIter->first;
...@@ -149,11 +137,8 @@ public: ...@@ -149,11 +137,8 @@ public:
float dSquared = compPairDistanceSquared(locationI, locationJ, periodicBoxSize, usePeriodic); float dSquared = compPairDistanceSquared(locationI, locationJ, periodicBoxSize, usePeriodic);
if (dSquared > maxDistanceSquared) continue; if (dSquared > maxDistanceSquared) continue;
if (dSquared < minDistanceSquared) continue;
neighbors.push_back(make_pair(atomI, atomJ)); neighbors.push_back(make_pair(atomI, atomJ));
if (reportSymmetricPairs)
neighbors.push_back(make_pair(atomJ, atomI));
} }
} }
} }
...@@ -165,43 +150,32 @@ private: ...@@ -165,43 +150,32 @@ private:
int nx, ny, nz; int nx, ny, nz;
const float* periodicBoxSize; const float* periodicBoxSize;
const bool usePeriodic; const bool usePeriodic;
std::map<VoxelIndex, Voxel> voxelMap; map<VoxelIndex, Voxel> voxelMap;
}; };
// O(n) neighbor list method using voxel hash data structure // O(n) neighbor list method using voxel hash data structure
void CpuNeighborList::computeNeighborList( void CpuNeighborList::computeNeighborList(int nAtoms, const vector<float>& atomLocations, const vector<set<int> >& exclusions,
int nAtoms, const float* periodicBoxSize, bool usePeriodic, float maxDistance) {
const vector<float>& atomLocations,
const vector<set<int> >& exclusions,
const float* periodicBoxSize,
bool usePeriodic,
float maxDistance,
float minDistance,
bool reportSymmetricPairs)
{
neighbors.clear(); neighbors.clear();
float edgeSizeX, edgeSizeY, edgeSizeZ; float edgeSizeX, edgeSizeY, edgeSizeZ;
if (!usePeriodic) if (!usePeriodic)
edgeSizeX = edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed edgeSizeX = edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed
else { else {
edgeSizeX = periodicBoxSize[0]/floorf(periodicBoxSize[0]/maxDistance); edgeSizeX = 0.5f*periodicBoxSize[0]/floorf(periodicBoxSize[0]/maxDistance);
edgeSizeY = periodicBoxSize[1]/floorf(periodicBoxSize[1]/maxDistance); edgeSizeY = 0.5f*periodicBoxSize[1]/floorf(periodicBoxSize[1]/maxDistance);
edgeSizeZ = periodicBoxSize[2]/floorf(periodicBoxSize[2]/maxDistance); edgeSizeZ = 0.5f*periodicBoxSize[2]/floorf(periodicBoxSize[2]/maxDistance);
} }
VoxelHash voxelHash(edgeSizeX, edgeSizeY, edgeSizeZ, periodicBoxSize, usePeriodic); VoxelHash voxelHash(edgeSizeX, edgeSizeY, edgeSizeZ, periodicBoxSize, usePeriodic);
for (int atomJ = 0; atomJ < (int) nAtoms; ++atomJ) // use "j", because j > i for pairs for (int atomJ = 0; atomJ < (int) nAtoms; ++atomJ) { // use "j", because j > i for pairs
{
// 1) Find other atoms that are close to this one // 1) Find other atoms that are close to this one
const float location[3] = {atomLocations[4*atomJ], atomLocations[4*atomJ+1], atomLocations[4*atomJ+2]}; const float* location = &atomLocations[4*atomJ];
voxelHash.getNeighbors( voxelHash.getNeighbors(
neighbors, neighbors,
VoxelItem(location, atomJ), VoxelItem(location, atomJ),
exclusions, exclusions,
reportSymmetricPairs, maxDistance);
maxDistance,
minDistance);
// 2) Add this atom to the voxelHash // 2) Add this atom to the voxelHash
voxelHash.insert(atomJ, location); voxelHash.insert(atomJ, location);
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
/**
* This tests all the CPU implementation of neighbor list construction.
*/
#include "openmm/internal/AssertionUtilities.h"
#include "CpuNeighborList.h"
#include "sfmt/SFMT.h"
#include <iostream>
#include <set>
#include <vector>
using namespace OpenMM;
using namespace std;
void testNeighborList(bool periodic) {
const int numParticles = 500;
const float cutoff = 2.0f;
const float boxSize = 20.0f;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
vector<float> positions(4*numParticles);
for (int i = 0; i < 4*numParticles; i++)
positions[i] = boxSize*genrand_real2(sfmt);
vector<set<int> > exclusions(numParticles);
for (int i = 0; i < numParticles; i++) {
int num = min(i+1, 10);
for (int j = 0; j < num; j++) {
exclusions[i].insert(i-j);
exclusions[i-j].insert(i);
}
}
CpuNeighborList neighborList;
float box[3] = {boxSize, boxSize, boxSize};
neighborList.computeNeighborList(numParticles, positions, exclusions, box, periodic, cutoff);
// Convert the neighbor list to a set for faster lookup.
set<pair<int, int> > neighbors;
for (int i = 0; i < (int) neighborList.getNeighbors().size(); i++) {
pair<int, int> entry = neighborList.getNeighbors()[i];
ASSERT(neighbors.find(entry) == neighbors.end() && neighbors.find(make_pair(entry.second, entry.first)) == neighbors.end()); // No duplicates
neighbors.insert(entry);
}
// Check each particle pair and figure out whether they should be in the neighbor list.
for (int i = 0; i < numParticles; i++)
for (int j = 0; j <= i; j++) {
bool shouldInclude = (exclusions[i].find(j) == exclusions[i].end());
float dx = positions[4*i]-positions[4*j];
float dy = positions[4*i+1]-positions[4*j+1];
float dz = positions[4*i+2]-positions[4*j+2];
if (periodic) {
dx -= floor(dx/boxSize+0.5f)*boxSize;
dy -= floor(dy/boxSize+0.5f)*boxSize;
dz -= floor(dz/boxSize+0.5f)*boxSize;
}
if (dx*dx + dy*dy + dz*dz > cutoff*cutoff)
shouldInclude = false;
bool isIncluded = (neighbors.find(make_pair(i, j)) != neighbors.end() || neighbors.find(make_pair(j, i)) != neighbors.end());
ASSERT_EQUAL(shouldInclude, isIncluded);
}
}
int main() {
try {
testNeighborList(false);
testNeighborList(true);
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
...@@ -180,8 +180,9 @@ public: ...@@ -180,8 +180,9 @@ public:
voxelIndex.y = (y+ny)%ny; voxelIndex.y = (y+ny)%ny;
voxelIndex.z = (z+nz)%nz; voxelIndex.z = (z+nz)%nz;
} }
if (voxelMap.find(voxelIndex) == voxelMap.end()) continue; // no such voxel; skip const map<VoxelIndex, Voxel>::const_iterator voxelEntry = voxelMap.find(voxelIndex);
const Voxel& voxel = voxelMap.find(voxelIndex)->second; if (voxelEntry == voxelMap.end()) continue; // no such voxel; skip
const Voxel& voxel = voxelEntry->second;
for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter) for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter)
{ {
const AtomIndex atomJ = itemIter->second; const AtomIndex atomJ = itemIter->second;
...@@ -234,9 +235,9 @@ void OPENMM_EXPORT computeNeighborListVoxelHash( ...@@ -234,9 +235,9 @@ void OPENMM_EXPORT computeNeighborListVoxelHash(
if (!usePeriodic) if (!usePeriodic)
edgeSizeX = edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed edgeSizeX = edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed
else { else {
edgeSizeX = periodicBoxSize[0]/floor(periodicBoxSize[0]/maxDistance); edgeSizeX = 0.5*periodicBoxSize[0]/floor(periodicBoxSize[0]/maxDistance);
edgeSizeY = periodicBoxSize[1]/floor(periodicBoxSize[1]/maxDistance); edgeSizeY = 0.5*periodicBoxSize[1]/floor(periodicBoxSize[1]/maxDistance);
edgeSizeZ = periodicBoxSize[2]/floor(periodicBoxSize[2]/maxDistance); edgeSizeZ = 0.5*periodicBoxSize[2]/floor(periodicBoxSize[2]/maxDistance);
} }
VoxelHash voxelHash(edgeSizeX, edgeSizeY, edgeSizeZ, periodicBoxSize, usePeriodic); VoxelHash voxelHash(edgeSizeX, edgeSizeY, edgeSizeZ, periodicBoxSize, usePeriodic);
for (AtomIndex atomJ = 0; atomJ < (AtomIndex) nAtoms; ++atomJ) // use "j", because j > i for pairs for (AtomIndex atomJ = 0; atomJ < (AtomIndex) nAtoms; ++atomJ) // use "j", because j > i for pairs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment