Commit d4d9ce1f authored by peastman's avatar peastman
Browse files

Optimizations to neighbor list construction

parent d4a343e2
......@@ -15,9 +15,7 @@ public:
const std::vector<std::set<int> >& exclusions,
const float* periodicBoxSize,
bool usePeriodic,
float maxDistance,
float minDistance = 0.0f,
bool reportSymmetricPairs = false);
float maxDistance);
const std::vector<std::pair<int, int> >& getNeighbors();
private:
std::vector<std::pair<int, int> > neighbors;
......
......@@ -155,7 +155,7 @@ double CpuCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeFo
posq[4*i+1] = (float) posData[i][1];
posq[4*i+2] = (float) posData[i][2];
}
neighborList.computeNeighborList(numParticles, posq, exclusions, floatBoxSize, periodic || ewald || pme, nonbondedCutoff, 0.0);
neighborList.computeNeighborList(numParticles, posq, exclusions, floatBoxSize, periodic || ewald || pme, nonbondedCutoff);
// if (nonbondedMethod != NoCutoff) {
// computeNeighborListVoxelHash(*neighborList, numParticles, posData, exclusions, extractBoxSize(context), periodic || ewald || pme, nonbondedCutoff, 0.0);
// clj.setUseCutoff(nonbondedCutoff, *neighborList, rfDielectric);
......
......@@ -32,7 +32,8 @@ static float compPairDistanceSquared(const float* pos1, const float* pos2, const
class VoxelIndex
{
public:
VoxelIndex(int xx, int yy, int zz) : x(xx), y(yy), z(zz) {}
VoxelIndex(int xx, int yy, int zz) : x(xx), y(yy), z(zz) {
}
// operator<() needed for map
bool operator<(const VoxelIndex& other) const {
......@@ -43,17 +44,16 @@ public:
else if (z < other.z) return true;
else return false;
}
int x;
int y;
int z;
};
typedef std::pair<const float*, int> VoxelItem;
typedef std::vector< VoxelItem > Voxel;
typedef pair<const float*, int> VoxelItem;
typedef vector< VoxelItem > Voxel;
class VoxelHash
{
class VoxelHash {
public:
VoxelHash(float vsx, float vsy, float vsz, const float* periodicBoxSize, bool usePeriodic) :
voxelSizeX(vsx), voxelSizeY(vsy), voxelSizeZ(vsz), periodicBoxSize(periodicBoxSize), usePeriodic(usePeriodic) {
......@@ -64,8 +64,7 @@ public:
}
}
void insert(const int& item, const float* location)
{
void insert(const int& item, const float* location) {
VoxelIndex voxelIndex = getVoxelIndex(location);
if (voxelMap.find(voxelIndex) == voxelMap.end()) voxelMap[voxelIndex] = Voxel();
Voxel& voxel = voxelMap.find(voxelIndex)->second;
......@@ -92,14 +91,7 @@ public:
return VoxelIndex(x, y, z);
}
void getNeighbors(
vector<pair<int, int> >& neighbors,
const VoxelItem& referencePoint,
const vector<set<int> >& exclusions,
bool reportSymmetricPairs,
float maxDistance,
float minDistance) const
{
void getNeighbors(vector<pair<int, int> >& neighbors, const VoxelItem& referencePoint, const vector<set<int> >& exclusions, float maxDistance) const {
// Loop over neighboring voxels
// TODO use more clever selection of neighboring voxels
......@@ -108,7 +100,6 @@ public:
const float* locationI = referencePoint.first;
float maxDistanceSquared = maxDistance * maxDistance;
float minDistanceSquared = minDistance * minDistance;
int dIndexX = int(maxDistance / voxelSizeX) + 1; // How may voxels away do we have to look?
int dIndexY = int(maxDistance / voxelSizeY) + 1;
......@@ -122,22 +113,19 @@ public:
lasty = min(lasty, centerVoxelIndex.y-dIndexY+ny-1);
lastz = min(lastz, centerVoxelIndex.z-dIndexZ+nz-1);
}
for (int x = centerVoxelIndex.x - dIndexX; x <= lastx; ++x)
{
for (int y = centerVoxelIndex.y - dIndexY; y <= lasty; ++y)
{
for (int z = centerVoxelIndex.z - dIndexZ; z <= lastz; ++z)
{
for (int x = centerVoxelIndex.x - dIndexX; x <= lastx; ++x) {
for (int y = centerVoxelIndex.y - dIndexY; y <= lasty; ++y) {
for (int z = centerVoxelIndex.z - dIndexZ; z <= lastz; ++z) {
VoxelIndex voxelIndex(x, y, z);
if (usePeriodic) {
voxelIndex.x = (x+nx)%nx;
voxelIndex.y = (y+ny)%ny;
voxelIndex.z = (z+nz)%nz;
}
if (voxelMap.find(voxelIndex) == voxelMap.end()) continue; // no such voxel; skip
const Voxel& voxel = voxelMap.find(voxelIndex)->second;
for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter)
{
const map<VoxelIndex, Voxel>::const_iterator voxelEntry = voxelMap.find(voxelIndex);
if (voxelEntry == voxelMap.end()) continue; // no such voxel; skip
const Voxel& voxel = voxelEntry->second;
for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter) {
const int atomJ = itemIter->second;
const float* locationJ = itemIter->first;
......@@ -149,11 +137,8 @@ public:
float dSquared = compPairDistanceSquared(locationI, locationJ, periodicBoxSize, usePeriodic);
if (dSquared > maxDistanceSquared) continue;
if (dSquared < minDistanceSquared) continue;
neighbors.push_back(make_pair(atomI, atomJ));
if (reportSymmetricPairs)
neighbors.push_back(make_pair(atomJ, atomI));
}
}
}
......@@ -165,43 +150,32 @@ private:
int nx, ny, nz;
const float* periodicBoxSize;
const bool usePeriodic;
std::map<VoxelIndex, Voxel> voxelMap;
map<VoxelIndex, Voxel> voxelMap;
};
// O(n) neighbor list method using voxel hash data structure
void CpuNeighborList::computeNeighborList(
int nAtoms,
const vector<float>& atomLocations,
const vector<set<int> >& exclusions,
const float* periodicBoxSize,
bool usePeriodic,
float maxDistance,
float minDistance,
bool reportSymmetricPairs)
{
void CpuNeighborList::computeNeighborList(int nAtoms, const vector<float>& atomLocations, const vector<set<int> >& exclusions,
const float* periodicBoxSize, bool usePeriodic, float maxDistance) {
neighbors.clear();
float edgeSizeX, edgeSizeY, edgeSizeZ;
if (!usePeriodic)
edgeSizeX = edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed
else {
edgeSizeX = periodicBoxSize[0]/floorf(periodicBoxSize[0]/maxDistance);
edgeSizeY = periodicBoxSize[1]/floorf(periodicBoxSize[1]/maxDistance);
edgeSizeZ = periodicBoxSize[2]/floorf(periodicBoxSize[2]/maxDistance);
edgeSizeX = 0.5f*periodicBoxSize[0]/floorf(periodicBoxSize[0]/maxDistance);
edgeSizeY = 0.5f*periodicBoxSize[1]/floorf(periodicBoxSize[1]/maxDistance);
edgeSizeZ = 0.5f*periodicBoxSize[2]/floorf(periodicBoxSize[2]/maxDistance);
}
VoxelHash voxelHash(edgeSizeX, edgeSizeY, edgeSizeZ, periodicBoxSize, usePeriodic);
for (int atomJ = 0; atomJ < (int) nAtoms; ++atomJ) // use "j", because j > i for pairs
{
for (int atomJ = 0; atomJ < (int) nAtoms; ++atomJ) { // use "j", because j > i for pairs
// 1) Find other atoms that are close to this one
const float location[3] = {atomLocations[4*atomJ], atomLocations[4*atomJ+1], atomLocations[4*atomJ+2]};
const float* location = &atomLocations[4*atomJ];
voxelHash.getNeighbors(
neighbors,
VoxelItem(location, atomJ),
exclusions,
reportSymmetricPairs,
maxDistance,
minDistance);
maxDistance);
// 2) Add this atom to the voxelHash
voxelHash.insert(atomJ, location);
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
/**
* This tests all the CPU implementation of neighbor list construction.
*/
#include "openmm/internal/AssertionUtilities.h"
#include "CpuNeighborList.h"
#include "sfmt/SFMT.h"
#include <iostream>
#include <set>
#include <vector>
using namespace OpenMM;
using namespace std;
void testNeighborList(bool periodic) {
const int numParticles = 500;
const float cutoff = 2.0f;
const float boxSize = 20.0f;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
vector<float> positions(4*numParticles);
for (int i = 0; i < 4*numParticles; i++)
positions[i] = boxSize*genrand_real2(sfmt);
vector<set<int> > exclusions(numParticles);
for (int i = 0; i < numParticles; i++) {
int num = min(i+1, 10);
for (int j = 0; j < num; j++) {
exclusions[i].insert(i-j);
exclusions[i-j].insert(i);
}
}
CpuNeighborList neighborList;
float box[3] = {boxSize, boxSize, boxSize};
neighborList.computeNeighborList(numParticles, positions, exclusions, box, periodic, cutoff);
// Convert the neighbor list to a set for faster lookup.
set<pair<int, int> > neighbors;
for (int i = 0; i < (int) neighborList.getNeighbors().size(); i++) {
pair<int, int> entry = neighborList.getNeighbors()[i];
ASSERT(neighbors.find(entry) == neighbors.end() && neighbors.find(make_pair(entry.second, entry.first)) == neighbors.end()); // No duplicates
neighbors.insert(entry);
}
// Check each particle pair and figure out whether they should be in the neighbor list.
for (int i = 0; i < numParticles; i++)
for (int j = 0; j <= i; j++) {
bool shouldInclude = (exclusions[i].find(j) == exclusions[i].end());
float dx = positions[4*i]-positions[4*j];
float dy = positions[4*i+1]-positions[4*j+1];
float dz = positions[4*i+2]-positions[4*j+2];
if (periodic) {
dx -= floor(dx/boxSize+0.5f)*boxSize;
dy -= floor(dy/boxSize+0.5f)*boxSize;
dz -= floor(dz/boxSize+0.5f)*boxSize;
}
if (dx*dx + dy*dy + dz*dz > cutoff*cutoff)
shouldInclude = false;
bool isIncluded = (neighbors.find(make_pair(i, j)) != neighbors.end() || neighbors.find(make_pair(j, i)) != neighbors.end());
ASSERT_EQUAL(shouldInclude, isIncluded);
}
}
int main() {
try {
testNeighborList(false);
testNeighborList(true);
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
......@@ -180,8 +180,9 @@ public:
voxelIndex.y = (y+ny)%ny;
voxelIndex.z = (z+nz)%nz;
}
if (voxelMap.find(voxelIndex) == voxelMap.end()) continue; // no such voxel; skip
const Voxel& voxel = voxelMap.find(voxelIndex)->second;
const map<VoxelIndex, Voxel>::const_iterator voxelEntry = voxelMap.find(voxelIndex);
if (voxelEntry == voxelMap.end()) continue; // no such voxel; skip
const Voxel& voxel = voxelEntry->second;
for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter)
{
const AtomIndex atomJ = itemIter->second;
......@@ -234,9 +235,9 @@ void OPENMM_EXPORT computeNeighborListVoxelHash(
if (!usePeriodic)
edgeSizeX = edgeSizeY = edgeSizeZ = maxDistance; // TODO - adjust this as needed
else {
edgeSizeX = periodicBoxSize[0]/floor(periodicBoxSize[0]/maxDistance);
edgeSizeY = periodicBoxSize[1]/floor(periodicBoxSize[1]/maxDistance);
edgeSizeZ = periodicBoxSize[2]/floor(periodicBoxSize[2]/maxDistance);
edgeSizeX = 0.5*periodicBoxSize[0]/floor(periodicBoxSize[0]/maxDistance);
edgeSizeY = 0.5*periodicBoxSize[1]/floor(periodicBoxSize[1]/maxDistance);
edgeSizeZ = 0.5*periodicBoxSize[2]/floor(periodicBoxSize[2]/maxDistance);
}
VoxelHash voxelHash(edgeSizeX, edgeSizeY, edgeSizeZ, periodicBoxSize, usePeriodic);
for (AtomIndex atomJ = 0; atomJ < (AtomIndex) nAtoms; ++atomJ) // use "j", because j > i for pairs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment