Commit 2943d628 authored by peastman's avatar peastman
Browse files

Bug fix

parent dd175b23
...@@ -42,8 +42,8 @@ public: ...@@ -42,8 +42,8 @@ public:
ny = (int) floorf(periodicBoxSize[1]/voxelSizeY+0.5f); ny = (int) floorf(periodicBoxSize[1]/voxelSizeY+0.5f);
nz = (int) floorf(periodicBoxSize[2]/voxelSizeZ+0.5f); nz = (int) floorf(periodicBoxSize[2]/voxelSizeZ+0.5f);
voxelSizeX = periodicBoxSize[0]/nx; voxelSizeX = periodicBoxSize[0]/nx;
voxelSizeY = periodicBoxSize[0]/ny; voxelSizeY = periodicBoxSize[1]/ny;
voxelSizeZ = periodicBoxSize[0]/nz; voxelSizeZ = periodicBoxSize[2]/nz;
} }
} }
......
...@@ -46,12 +46,13 @@ using namespace std; ...@@ -46,12 +46,13 @@ using namespace std;
void testNeighborList(bool periodic) { void testNeighborList(bool periodic) {
const int numParticles = 500; const int numParticles = 500;
const float cutoff = 2.0f; const float cutoff = 2.0f;
const float boxSize = 20.0f; const float boxSize[3] = {20.0f, 15.0f, 22.0f};
OpenMM_SFMT::SFMT sfmt; OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt); init_gen_rand(0, sfmt);
vector<float> positions(4*numParticles); vector<float> positions(4*numParticles);
for (int i = 0; i < 4*numParticles; i++) for (int i = 0; i < 4*numParticles; i++)
positions[i] = boxSize*genrand_real2(sfmt); if (i%4 < 3)
positions[i] = boxSize[i%4]*genrand_real2(sfmt);
vector<set<int> > exclusions(numParticles); vector<set<int> > exclusions(numParticles);
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
int num = min(i+1, 10); int num = min(i+1, 10);
...@@ -61,8 +62,7 @@ void testNeighborList(bool periodic) { ...@@ -61,8 +62,7 @@ void testNeighborList(bool periodic) {
} }
} }
CpuNeighborList neighborList; CpuNeighborList neighborList;
float box[3] = {boxSize, boxSize, boxSize}; neighborList.computeNeighborList(numParticles, positions, exclusions, boxSize, periodic, cutoff);
neighborList.computeNeighborList(numParticles, positions, exclusions, box, periodic, cutoff);
// Convert the neighbor list to a set for faster lookup. // Convert the neighbor list to a set for faster lookup.
...@@ -82,9 +82,9 @@ void testNeighborList(bool periodic) { ...@@ -82,9 +82,9 @@ void testNeighborList(bool periodic) {
float dy = positions[4*i+1]-positions[4*j+1]; float dy = positions[4*i+1]-positions[4*j+1];
float dz = positions[4*i+2]-positions[4*j+2]; float dz = positions[4*i+2]-positions[4*j+2];
if (periodic) { if (periodic) {
dx -= floor(dx/boxSize+0.5f)*boxSize; dx -= floor(dx/boxSize[0]+0.5f)*boxSize[0];
dy -= floor(dy/boxSize+0.5f)*boxSize; dy -= floor(dy/boxSize[1]+0.5f)*boxSize[1];
dz -= floor(dz/boxSize+0.5f)*boxSize; dz -= floor(dz/boxSize[2]+0.5f)*boxSize[2];
} }
if (dx*dx + dy*dy + dz*dz > cutoff*cutoff) if (dx*dx + dy*dy + dz*dz > cutoff*cutoff)
shouldInclude = false; shouldInclude = false;
......
...@@ -101,8 +101,8 @@ public: ...@@ -101,8 +101,8 @@ public:
ny = (int) floor(periodicBoxSize[1]/voxelSizeY+0.5); ny = (int) floor(periodicBoxSize[1]/voxelSizeY+0.5);
nz = (int) floor(periodicBoxSize[2]/voxelSizeZ+0.5); nz = (int) floor(periodicBoxSize[2]/voxelSizeZ+0.5);
voxelSizeX = periodicBoxSize[0]/nx; voxelSizeX = periodicBoxSize[0]/nx;
voxelSizeY = periodicBoxSize[0]/ny; voxelSizeY = periodicBoxSize[1]/ny;
voxelSizeZ = periodicBoxSize[0]/nz; voxelSizeZ = periodicBoxSize[2]/nz;
} }
} }
......
...@@ -86,7 +86,7 @@ void verifyNeighborList(NeighborList& list, int numParticles, vector<RealVec>& p ...@@ -86,7 +86,7 @@ void verifyNeighborList(NeighborList& list, int numParticles, vector<RealVec>& p
for (int j = i+1; j < numParticles; j++) for (int j = i+1; j < numParticles; j++)
if (distance2(positions[i], positions[j], periodicBoxSize) <= cutoff*cutoff) if (distance2(positions[i], positions[j], periodicBoxSize) <= cutoff*cutoff)
count++; count++;
ASSERT(count == list.size()); ASSERT_EQUAL(count, list.size());
} }
void testPeriodic() { void testPeriodic() {
...@@ -112,16 +112,15 @@ void testPeriodic() { ...@@ -112,16 +112,15 @@ void testPeriodic() {
int main() int main()
{ {
try { try {
testNeighborList(); testNeighborList();
testPeriodic(); testPeriodic();
}
cout << "Test Passed" << endl; catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0; return 0;
} }
catch (...) {
cerr << "*** ERROR: Test Failed ***" << endl;
return 1;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment