Commit f623f01a authored by peastman's avatar peastman
Browse files

Created AVX version of charge spreading. Also fixed incorrect energy computation.

parent dcfd5c02
...@@ -71,7 +71,7 @@ FOREACH(subdir ${OPENMM_SOURCE_SUBDIRS}) ...@@ -71,7 +71,7 @@ FOREACH(subdir ${OPENMM_SOURCE_SUBDIRS})
ENDFOREACH(subdir) ENDFOREACH(subdir)
INCLUDE_DIRECTORIES(BEFORE ${CMAKE_CURRENT_SOURCE_DIR}/src) INCLUDE_DIRECTORIES(BEFORE ${CMAKE_CURRENT_SOURCE_DIR}/src)
SET_SOURCE_FILES_PROPERTIES(${SOURCE_FILES} PROPERTIES COMPILE_FLAGS "-msse4.1") SET_SOURCE_FILES_PROPERTIES(${SOURCE_FILES} PROPERTIES COMPILE_FLAGS "-mavx")
# Include FFTW related files. # Include FFTW related files.
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#include "../src/SimTKUtilities/SimTKOpenMMRealType.h" #include "../src/SimTKUtilities/SimTKOpenMMRealType.h"
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <smmintrin.h> #include <immintrin.h>
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
...@@ -110,7 +110,7 @@ static void cpuid(int cpuInfo[4], int infoType){ ...@@ -110,7 +110,7 @@ static void cpuid(int cpuInfo[4], int infoType){
} }
#endif #endif
static void spreadCharge(int start, int end, float* posq, float* grid, int gridx, int gridy, int gridz, int numParticles, Vec3 periodicBoxSize) { static void spreadChargeSSE(int start, int end, float* posq, float* grid, int gridx, int gridy, int gridz, int numParticles, Vec3 periodicBoxSize) {
float temp[4]; float temp[4];
__m128 boxSize = _mm_set_ps(0, (float) periodicBoxSize[2], (float) periodicBoxSize[1], (float) periodicBoxSize[0]); __m128 boxSize = _mm_set_ps(0, (float) periodicBoxSize[2], (float) periodicBoxSize[1], (float) periodicBoxSize[0]);
__m128 invBoxSize = _mm_set_ps(0, (float) (1/periodicBoxSize[2]), (float) (1/periodicBoxSize[1]), (float) (1/periodicBoxSize[0])); __m128 invBoxSize = _mm_set_ps(0, (float) (1/periodicBoxSize[2]), (float) (1/periodicBoxSize[1]), (float) (1/periodicBoxSize[0]));
...@@ -193,6 +193,89 @@ static void spreadCharge(int start, int end, float* posq, float* grid, int gridx ...@@ -193,6 +193,89 @@ static void spreadCharge(int start, int end, float* posq, float* grid, int gridx
} }
} }
static void spreadChargeAVX(int start, int end, float* posq, float* grid, int gridx, int gridy, int gridz, int numParticles, Vec3 periodicBoxSize) {
float temp[8];
__m128 boxSize = _mm_set_ps(0, (float) periodicBoxSize[2], (float) periodicBoxSize[1], (float) periodicBoxSize[0]);
__m128 invBoxSize = _mm_set_ps(0, (float) (1/periodicBoxSize[2]), (float) (1/periodicBoxSize[1]), (float) (1/periodicBoxSize[0]));
__m128 gridSize = _mm_set_ps(0, gridz, gridy, gridx);
__m128i gridSizeInt = _mm_set_epi32(0, gridz, gridy, gridx);
__m128 one = _mm_set1_ps(1);
__m128 scale = _mm_set1_ps(1.0f/(PME_ORDER-1));
__m256i mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1);
const float epsilonFactor = sqrt(ONE_4PI_EPS0);
memset(grid, 0, sizeof(float)*gridx*gridy*gridz);
for (int i = start; i < end; i++) {
// Find the position relative to the nearest grid point.
__m128 pos = _mm_load_ps(&posq[4*i]);
__m128 posInBox = _mm_sub_ps(pos, _mm_mul_ps(boxSize, _mm_floor_ps(_mm_mul_ps(pos, invBoxSize))));
__m128 t = _mm_mul_ps(_mm_mul_ps(posInBox, invBoxSize), gridSize);
__m128i ti = _mm_cvttps_epi32(t);
__m128 dr = _mm_sub_ps(t, _mm_cvtepi32_ps(ti));
__m128i gridIndex = _mm_sub_epi32(ti, _mm_and_si128(gridSizeInt, _mm_cmpeq_epi32(ti, gridSizeInt)));
// Compute the B-spline coefficients.
__m128 data[PME_ORDER];
data[PME_ORDER-1] = _mm_setzero_ps();
data[1] = dr;
data[0] = _mm_sub_ps(one, dr);
for (int j = 3; j < PME_ORDER; j++) {
__m128 div = _mm_set1_ps(1.0f/(j-1));
data[j-1] = _mm_mul_ps(_mm_mul_ps(div, dr), data[j-2]);
for (int k = 1; k < j-1; k++)
data[j-k-1] = _mm_mul_ps(div, _mm_add_ps(_mm_mul_ps(_mm_add_ps(dr, _mm_set1_ps(k)), data[j-k-2]), _mm_mul_ps(_mm_sub_ps(_mm_set1_ps(j-k), dr), data[j-k-1])));
data[0] = _mm_mul_ps(_mm_mul_ps(div, _mm_sub_ps(one, dr)), data[0]);
}
data[PME_ORDER-1] = _mm_mul_ps(_mm_mul_ps(scale, dr), data[PME_ORDER-2]);
for (int j = 1; j < (PME_ORDER-1); j++)
data[PME_ORDER-j-1] = _mm_mul_ps(scale, _mm_add_ps(_mm_mul_ps(_mm_add_ps(dr, _mm_set1_ps(j)), data[PME_ORDER-j-2]), _mm_mul_ps(_mm_sub_ps(_mm_set1_ps(PME_ORDER-j), dr), data[PME_ORDER-j-1])));
data[0] = _mm_mul_ps(_mm_mul_ps(scale, _mm_sub_ps(one, dr)), data[0]);
// Spread the charges.
int gridIndexX = _mm_extract_epi32(gridIndex, 0);
int gridIndexY = _mm_extract_epi32(gridIndex, 1);
int gridIndexZ = _mm_extract_epi32(gridIndex, 2);
float charge = epsilonFactor*posq[4*i+3];
__m256 zdata = _mm256_set_ps(0, 0, 0, extractFloat(data[4], 2), extractFloat(data[3], 2), extractFloat(data[2], 2), extractFloat(data[1], 2), extractFloat(data[0], 2));
for (int ix = 0; ix < PME_ORDER; ix++) {
int xbase = gridIndexX+ix;
xbase -= (xbase >= gridx ? gridx : 0);
xbase = xbase*gridy*gridz;
float xdata = extractFloat(data[ix], 0);
for (int iy = 0; iy < PME_ORDER; iy++) {
int ybase = gridIndexY+iy;
ybase -= (ybase >= gridy ? gridy : 0);
ybase = xbase + ybase*gridz;
float multiplier = charge*xdata*extractFloat(data[iy], 1);
__m256 add = _mm256_mul_ps(zdata, _mm256_set1_ps(multiplier));
if (gridIndexZ+5 < gridz)
_mm256_maskstore_ps(&grid[ybase+gridIndexZ], mask, _mm256_add_ps(_mm256_loadu_ps(&grid[ybase+gridIndexZ]), add));
else {
_mm256_store_ps(temp, add);
int zindex = gridIndexZ;
grid[ybase+zindex] += temp[0];
zindex++;
zindex -= (zindex >= gridz ? gridz : 0);
grid[ybase+zindex] += temp[1];
zindex++;
zindex -= (zindex >= gridz ? gridz : 0);
grid[ybase+zindex] += temp[2];
zindex++;
zindex -= (zindex >= gridz ? gridz : 0);
grid[ybase+zindex] += temp[3];
zindex++;
zindex -= (zindex >= gridz ? gridz : 0);
grid[ybase+zindex] += temp[4];
}
}
}
}
}
static void computeReciprocalEterm(int start, int end, int gridx, int gridy, int gridz, vector<float>& recipEterm, double alpha, vector<float>* bsplineModuli, Vec3 periodicBoxSize) { static void computeReciprocalEterm(int start, int end, int gridx, int gridy, int gridz, vector<float>& recipEterm, double alpha, vector<float>* bsplineModuli, Vec3 periodicBoxSize) {
const unsigned int zsize = gridz/2+1; const unsigned int zsize = gridz/2+1;
const unsigned int yzsize = gridy*zsize; const unsigned int yzsize = gridy*zsize;
...@@ -226,19 +309,33 @@ static void computeReciprocalEterm(int start, int end, int gridx, int gridy, int ...@@ -226,19 +309,33 @@ static void computeReciprocalEterm(int start, int end, int gridx, int gridy, int
} }
} }
static float reciprocalEnergy(int start, int end, fftwf_complex* grid, int gridx, int gridy, int gridz, vector<float>& recipEterm) { static float reciprocalEnergy(int start, int end, fftwf_complex* grid, int gridx, int gridy, int gridz, double alpha, vector<float>* bsplineModuli, Vec3 periodicBoxSize) {
const unsigned int zsize = gridz/2+1;
const unsigned int yzsize = gridy*gridz;
const unsigned int zsizeHalf = gridz/2+1; const unsigned int zsizeHalf = gridz/2+1;
const unsigned int yzsizeHalf = gridy*zsizeHalf; const unsigned int yzsizeHalf = gridy*zsizeHalf;
const float scaleFactor = (float) (M_PI*periodicBoxSize[0]*periodicBoxSize[1]*periodicBoxSize[2]);
const float recipExpFactor = (float) (M_PI*M_PI/(alpha*alpha));
const float invPeriodicBoxSizeX = (float) (1.0/periodicBoxSize[0]);
const float invPeriodicBoxSizeY = (float) (1.0/periodicBoxSize[1]);
const float invPeriodicBoxSizeZ = (float) (1.0/periodicBoxSize[2]);
float energy = 0.0f; float energy = 0.0f;
int firstz = (start == 0 ? 1 : 0); int firstz = (start == 0 ? 1 : 0);
for (int kx = start; kx < end; kx++) { for (int kx = start; kx < end; kx++) {
int mx = (kx < (gridx+1)/2) ? kx : kx-gridx;
float mhx = mx*invPeriodicBoxSizeX;
float bx = scaleFactor*bsplineModuli[0][kx];
for (int ky = 0; ky < gridy; ky++) { for (int ky = 0; ky < gridy; ky++) {
int my = (ky < (gridy+1)/2) ? ky : ky-gridy; int my = (ky < (gridy+1)/2) ? ky : ky-gridy;
float mhy = my*invPeriodicBoxSizeY;
float mhx2y2 = mhx*mhx + mhy*mhy;
float bxby = bx*bsplineModuli[1][ky];
for (int kz = firstz; kz < gridz; kz++) { for (int kz = firstz; kz < gridz; kz++) {
float eterm = recipEterm[kx*yzsize + ky*zsize + kz]; int mz = (kz < (gridz+1)/2) ? kz : kz-gridz;
float mhz = mz*invPeriodicBoxSizeZ;
float bz = bsplineModuli[2][kz];
float m2 = mhx2y2 + mhz*mhz;
float denom = m2*bxby*bz;
float eterm = exp(-recipExpFactor*m2)/denom;
int kx1, ky1, kz1; int kx1, ky1, kz1;
if (kz >= gridz/2+1) { if (kz >= gridz/2+1) {
kx1 = (kx == 0 ? kx : gridx-kx); kx1 = (kx == 0 ? kx : gridx-kx);
...@@ -537,6 +634,7 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) { ...@@ -537,6 +634,7 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) {
else { else {
// This is a worker thread. // This is a worker thread.
bool useAVX = isAVXSupported();
int particleStart = (index*numParticles)/numThreads; int particleStart = (index*numParticles)/numThreads;
int particleEnd = ((index+1)*numParticles)/numThreads; int particleEnd = ((index+1)*numParticles)/numThreads;
int gridxStart = (index*gridx)/numThreads; int gridxStart = (index*gridx)/numThreads;
...@@ -548,7 +646,10 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) { ...@@ -548,7 +646,10 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) {
threadWait(); threadWait();
if (isDeleted) if (isDeleted)
break; break;
spreadCharge(particleStart, particleEnd, posq, threadData[index]->tempGrid, gridx, gridy, gridz, numParticles, periodicBoxSize); if (useAVX)
spreadChargeSSE(particleStart, particleEnd, posq, threadData[index]->tempGrid, gridx, gridy, gridz, numParticles, periodicBoxSize);
else
spreadChargeAVX(particleStart, particleEnd, posq, threadData[index]->tempGrid, gridx, gridy, gridz, numParticles, periodicBoxSize);
threadWait(); threadWait();
int numGrids = threadData.size(); int numGrids = threadData.size();
for (int i = gridStart; i < gridEnd; i += 4) { for (int i = gridStart; i < gridEnd; i += 4) {
...@@ -563,7 +664,7 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) { ...@@ -563,7 +664,7 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) {
threadWait(); threadWait();
} }
if (includeEnergy) { if (includeEnergy) {
double threadEnergy = reciprocalEnergy(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, recipEterm); double threadEnergy = reciprocalEnergy(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, alpha, bsplineModuli, periodicBoxSize);
pthread_mutex_lock(&lock); pthread_mutex_lock(&lock);
energy += threadEnergy; energy += threadEnergy;
pthread_mutex_unlock(&lock); pthread_mutex_unlock(&lock);
...@@ -622,3 +723,13 @@ bool CpuCalcPmeReciprocalForceKernel::isProcessorSupported() { ...@@ -622,3 +723,13 @@ bool CpuCalcPmeReciprocalForceKernel::isProcessorSupported() {
} }
return false; return false;
} }
bool CpuCalcPmeReciprocalForceKernel::isAVXSupported() {
int cpuInfo[4];
cpuid(cpuInfo, 0);
if (cpuInfo[0] >= 1) {
cpuid(cpuInfo, 1);
return ((cpuInfo[2] & ((int) 1 << 28)) != 0);
}
return false;
}
\ No newline at end of file
...@@ -59,6 +59,7 @@ public: ...@@ -59,6 +59,7 @@ public:
private: private:
void threadWait(); void threadWait();
void advanceThreads(); void advanceThreads();
static bool isAVXSupported();
static bool hasInitializedThreads; static bool hasInitializedThreads;
static int numThreads; static int numThreads;
int gridx, gridy, gridz, numParticles; int gridx, gridy, gridz, numParticles;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment