Commit 3b3fd578 authored by peastman
Browse files

More optimizations to neighborlist construction, including SSE intrinsics

parent d4d9ce1f
SET_SOURCE_FILES_PROPERTIES(${SOURCE_FILES} PROPERTIES COMPILE_FLAGS "-msse4.1")
ADD_LIBRARY(${SHARED_TARGET} SHARED ${SOURCE_FILES} ${SOURCE_INCLUDE_FILES} ${API_ABS_INCLUDE_FILES}) ADD_LIBRARY(${SHARED_TARGET} SHARED ${SOURCE_FILES} ${SOURCE_INCLUDE_FILES} ${API_ABS_INCLUDE_FILES})
IF (UNIX AND CMAKE_BUILD_TYPE MATCHES Debug) IF (UNIX AND CMAKE_BUILD_TYPE MATCHES Debug)
......
...@@ -2,33 +2,12 @@ ...@@ -2,33 +2,12 @@
#include <set> #include <set>
#include <map> #include <map>
#include <cmath> #include <cmath>
#include <smmintrin.h>
using namespace std; using namespace std;
namespace OpenMM { namespace OpenMM {
/**
 * Difference val1 - val2 wrapped by a whole number of periods.
 *
 * The raw difference is divided by the period, shifted by one half and
 * floored to get the number of whole periods to remove; that many periods
 * are then subtracted from the raw difference.
 *
 * @param val1   first coordinate
 * @param val2   second coordinate (subtracted from val1)
 * @param period length of the periodic interval; used as a divisor
 * @return the raw difference minus the rounded number of whole periods
 */
static float periodicDifference(float val1, float val2, float period) {
    const float rawDelta = val1 - val2;
    // Number of whole periods closest to the raw difference.
    const float wholePeriods = floorf(rawDelta / period + 0.5f);
    const float shift = wholePeriods * period;
    return rawDelta - shift;
}
/**
 * Squared distance between two points given as 3-component float arrays.
 *
 * @param pos1            first point; elements [0..2] are read
 * @param pos2            second point; elements [0..2] are read
 * @param periodicBoxSize per-axis period; only read when usePeriodic is true
 * @param usePeriodic     when true, each per-axis difference is passed
 *                        through periodicDifference() before squaring
 * @return sum of the squared per-axis differences
 */
static float compPairDistanceSquared(const float* pos1, const float* pos2, const float* periodicBoxSize, bool usePeriodic) {
    float delta[3];
    for (int axis = 0; axis < 3; ++axis) {
        if (usePeriodic)
            delta[axis] = periodicDifference(pos2[axis], pos1[axis], periodicBoxSize[axis]);
        else
            delta[axis] = pos2[axis] - pos1[axis];
    }
    return delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2];
}
class VoxelIndex class VoxelIndex
{ {
public: public:
...@@ -98,6 +77,10 @@ public: ...@@ -98,6 +77,10 @@ public:
const int atomI = referencePoint.second; const int atomI = referencePoint.second;
const float* locationI = referencePoint.first; const float* locationI = referencePoint.first;
__m128 posI = _mm_loadu_ps(locationI);
__m128 boxSize = _mm_set_ps(0, periodicBoxSize[2], periodicBoxSize[1], periodicBoxSize[0]);
__m128 invBoxSize = _mm_set_ps(0, (1/periodicBoxSize[2]), (1/periodicBoxSize[1]), (1/periodicBoxSize[0]));
__m128 half = _mm_set1_ps(0.5);
float maxDistanceSquared = maxDistance * maxDistance; float maxDistanceSquared = maxDistance * maxDistance;
...@@ -113,31 +96,43 @@ public: ...@@ -113,31 +96,43 @@ public:
lasty = min(lasty, centerVoxelIndex.y-dIndexY+ny-1); lasty = min(lasty, centerVoxelIndex.y-dIndexY+ny-1);
lastz = min(lastz, centerVoxelIndex.z-dIndexZ+nz-1); lastz = min(lastz, centerVoxelIndex.z-dIndexZ+nz-1);
} }
VoxelIndex voxelIndex(0, 0, 0);
for (int x = centerVoxelIndex.x - dIndexX; x <= lastx; ++x) { for (int x = centerVoxelIndex.x - dIndexX; x <= lastx; ++x) {
voxelIndex.x = x;
if (usePeriodic)
voxelIndex.x = (x < 0 ? x+nx : (x >= nx ? x-nx : x));
for (int y = centerVoxelIndex.y - dIndexY; y <= lasty; ++y) { for (int y = centerVoxelIndex.y - dIndexY; y <= lasty; ++y) {
voxelIndex.y = y;
if (usePeriodic)
voxelIndex.y = (y < 0 ? y+ny : (y >= ny ? y-ny : y));
for (int z = centerVoxelIndex.z - dIndexZ; z <= lastz; ++z) { for (int z = centerVoxelIndex.z - dIndexZ; z <= lastz; ++z) {
VoxelIndex voxelIndex(x, y, z); voxelIndex.z = z;
if (usePeriodic) { if (usePeriodic)
voxelIndex.x = (x+nx)%nx; voxelIndex.z = (z < 0 ? z+nz : (z >= nz ? z-nz : z));
voxelIndex.y = (y+ny)%ny;
voxelIndex.z = (z+nz)%nz;
}
const map<VoxelIndex, Voxel>::const_iterator voxelEntry = voxelMap.find(voxelIndex); const map<VoxelIndex, Voxel>::const_iterator voxelEntry = voxelMap.find(voxelIndex);
if (voxelEntry == voxelMap.end()) continue; // no such voxel; skip if (voxelEntry == voxelMap.end())
continue; // no such voxel; skip
const Voxel& voxel = voxelEntry->second; const Voxel& voxel = voxelEntry->second;
for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter) { for (Voxel::const_iterator itemIter = voxel.begin(); itemIter != voxel.end(); ++itemIter) {
const int atomJ = itemIter->second; const int atomJ = itemIter->second;
const float* locationJ = itemIter->first;
// Ignore self hits // Ignore self hits
if (atomI == atomJ) continue; if (atomI == atomJ)
continue;
// Ignore exclusions. // Ignore exclusions.
if (exclusions[atomI].find(atomJ) != exclusions[atomI].end()) continue; if (exclusions[atomI].find(atomJ) != exclusions[atomI].end())
continue;
float dSquared = compPairDistanceSquared(locationI, locationJ, periodicBoxSize, usePeriodic);
if (dSquared > maxDistanceSquared) continue;
__m128 posJ = _mm_loadu_ps(itemIter->first);
__m128 delta = _mm_sub_ps(posJ, posI);
if (usePeriodic) {
__m128 base = _mm_mul_ps(_mm_floor_ps(_mm_add_ps(_mm_mul_ps(delta, invBoxSize), half)), boxSize);
delta = _mm_sub_ps(delta, base);
}
float dSquared = _mm_cvtss_f32(_mm_dp_ps(delta, delta, 0x71));
if (dSquared > maxDistanceSquared)
continue;
neighbors.push_back(make_pair(atomI, atomJ)); neighbors.push_back(make_pair(atomI, atomJ));
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment