Commit 54a93542 authored by peastman's avatar peastman
Browse files

Optimizations and bug fixes to CPU based PME

parent cf1070b6
...@@ -42,6 +42,9 @@ using namespace std; ...@@ -42,6 +42,9 @@ using namespace std;
static const int PME_ORDER = 5; static const int PME_ORDER = 5;
bool CpuPme::hasInitializedThreads = false;
int CpuPme::numThreads = 0;
static float extractFloat(__m128 v, unsigned int element) { static float extractFloat(__m128 v, unsigned int element) {
float f[4]; float f[4];
_mm_store_ps(f, v); _mm_store_ps(f, v);
...@@ -50,8 +53,14 @@ static float extractFloat(__m128 v, unsigned int element) { ...@@ -50,8 +53,14 @@ static float extractFloat(__m128 v, unsigned int element) {
CpuPme::CpuPme(int gridx, int gridy, int gridz, int numParticles, double alpha) : CpuPme::CpuPme(int gridx, int gridy, int gridz, int numParticles, double alpha) :
gridx(gridx), gridy(gridy), gridz(gridz), numParticles(numParticles), alpha(alpha), hasCreatedPlan(false), realGrid(NULL), complexGrid(NULL) { gridx(gridx), gridy(gridy), gridz(gridz), numParticles(numParticles), alpha(alpha), hasCreatedPlan(false), realGrid(NULL), complexGrid(NULL) {
if (!hasInitializedThreads) {
numThreads = 4;
fftwf_init_threads();
hasInitializedThreads = true;
}
realGrid = (float*) fftwf_malloc(sizeof(float)*gridx*gridy*gridz); realGrid = (float*) fftwf_malloc(sizeof(float)*gridx*gridy*gridz);
complexGrid = (fftwf_complex*) fftwf_malloc(sizeof(fftwf_complex)*gridx*gridy*(gridz/2+1)); complexGrid = (fftwf_complex*) fftwf_malloc(sizeof(fftwf_complex)*gridx*gridy*(gridz/2+1));
fftwf_plan_with_nthreads(numThreads);
forwardFFT = fftwf_plan_dft_r2c_3d(gridx, gridy, gridz, realGrid, complexGrid, FFTW_MEASURE); forwardFFT = fftwf_plan_dft_r2c_3d(gridx, gridy, gridz, realGrid, complexGrid, FFTW_MEASURE);
backwardFFT = fftwf_plan_dft_c2r_3d(gridx, gridy, gridz, complexGrid, realGrid, FFTW_MEASURE); backwardFFT = fftwf_plan_dft_c2r_3d(gridx, gridy, gridz, complexGrid, realGrid, FFTW_MEASURE);
hasCreatedPlan = true; hasCreatedPlan = true;
...@@ -225,14 +234,14 @@ static float reciprocalEnergy(fftwf_complex* grid, int gridx, int gridy, int gri ...@@ -225,14 +234,14 @@ static float reciprocalEnergy(fftwf_complex* grid, int gridx, int gridy, int gri
for (int ky = 0; ky < gridy; ky++) { for (int ky = 0; ky < gridy; ky++) {
int my = (ky < (gridy+1)/2) ? ky : ky-gridy; int my = (ky < (gridy+1)/2) ? ky : ky-gridy;
float mhy = my*invPeriodicBoxSizeY; float mhy = my*invPeriodicBoxSizeY;
float by = bsplineModuli[1][ky]; float mhx2y2 = mhx*mhx + mhy*mhy;
float bxby = bx*bsplineModuli[1][ky];
for (int kz = firstz; kz < gridz; kz++) { for (int kz = firstz; kz < gridz; kz++) {
int index = kx*yzsize + ky*gridz + kz;
int mz = (kz < (gridz+1)/2) ? kz : kz-gridz; int mz = (kz < (gridz+1)/2) ? kz : kz-gridz;
float mhz = mz*invPeriodicBoxSizeZ; float mhz = mz*invPeriodicBoxSizeZ;
float bz = bsplineModuli[2][kz]; float bz = bsplineModuli[2][kz];
float m2 = mhx*mhx+mhy*mhy+mhz*mhz; float m2 = mhx2y2 + mhz*mhz;
float denom = m2*bx*by*bz; float denom = m2*bxby*bz;
float eterm = exp(-recipExpFactor*m2)/denom; float eterm = exp(-recipExpFactor*m2)/denom;
int kx1, ky1, kz1; int kx1, ky1, kz1;
if (kz >= gridz/2+1) { if (kz >= gridz/2+1) {
...@@ -245,7 +254,7 @@ static float reciprocalEnergy(fftwf_complex* grid, int gridx, int gridy, int gri ...@@ -245,7 +254,7 @@ static float reciprocalEnergy(fftwf_complex* grid, int gridx, int gridy, int gri
ky1 = ky; ky1 = ky;
kz1 = kz; kz1 = kz;
} }
index = kx1*yzsizeHalf + ky1*zsizeHalf + kz1; int index = kx1*yzsizeHalf + ky1*zsizeHalf + kz1;
float gridReal = grid[index][0]; float gridReal = grid[index][0];
float gridImag = grid[index][1]; float gridImag = grid[index][1];
energy += eterm*(gridReal*gridReal+gridImag*gridImag); energy += eterm*(gridReal*gridReal+gridImag*gridImag);
...@@ -273,17 +282,18 @@ static void reciprocalConvolution(fftwf_complex* grid, int gridx, int gridy, int ...@@ -273,17 +282,18 @@ static void reciprocalConvolution(fftwf_complex* grid, int gridx, int gridy, int
for (int ky = 0; ky < gridy; ky++) { for (int ky = 0; ky < gridy; ky++) {
int my = (ky < (gridy+1)/2) ? ky : ky-gridy; int my = (ky < (gridy+1)/2) ? ky : ky-gridy;
float mhy = my*invPeriodicBoxSizeY; float mhy = my*invPeriodicBoxSizeY;
float by = bsplineModuli[1][ky]; float mhx2y2 = mhx*mhx + mhy*mhy;
float bxby = bx*bsplineModuli[1][ky];
for (int kz = firstz; kz < zsize; kz++) { for (int kz = firstz; kz < zsize; kz++) {
int index = kx*yzsize + ky*zsize + kz; int index = kx*yzsize + ky*zsize + kz;
int mz = (kz < (gridz+1)/2) ? kz : kz-gridz; int mz = (kz < (gridz+1)/2) ? kz : kz-gridz;
float mhz = mz*invPeriodicBoxSizeZ; float mhz = mz*invPeriodicBoxSizeZ;
float bz = bsplineModuli[2][kz]; float bz = bsplineModuli[2][kz];
float m2 = mhx*mhx+mhy*mhy+mhz*mhz; float m2 = mhx2y2 + mhz*mhz;
float denom = m2*bx*by*bz; float denom = m2*bxby*bz;
float eterm = exp(-recipExpFactor*m2)/denom; float eterm = exp(-recipExpFactor*m2)/denom;
grid[index][0] *= eterm;; grid[index][0] *= eterm;
grid[index][1] *= eterm;; grid[index][1] *= eterm;
} }
firstz = 0; firstz = 0;
} }
...@@ -363,7 +373,7 @@ static void interpolateForces(float* posq, float* force, float* grid, int gridx, ...@@ -363,7 +373,7 @@ static void interpolateForces(float* posq, float* force, float* grid, int gridx,
} }
} }
} }
f = _mm_mul_ps(f, _mm_set1_ps(-epsilonFactor*posq[4*i+3])); f = _mm_mul_ps(invBoxSize, _mm_mul_ps(gridSize, _mm_mul_ps(f, _mm_set1_ps(-epsilonFactor*posq[4*i+3]))));
_mm_store_ps(&force[4*i], _mm_add_ps(_mm_load_ps(&force[4*i]), f)); _mm_store_ps(&force[4*i], _mm_add_ps(_mm_load_ps(&force[4*i]), f));
} }
} }
......
...@@ -52,6 +52,8 @@ public: ...@@ -52,6 +52,8 @@ public:
~CpuPme(); ~CpuPme();
double computeForceAndEnergy(float* posq, float* force, Vec3 periodicBoxSize, bool includeEnergy); double computeForceAndEnergy(float* posq, float* force, Vec3 periodicBoxSize, bool includeEnergy);
private: private:
static bool hasInitializedThreads;
static int numThreads;
int gridx, gridy, gridz, numParticles; int gridx, gridy, gridz, numParticles;
double alpha; double alpha;
bool hasCreatedPlan; bool hasCreatedPlan;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment