Commit dcfd5c02 authored by peastman's avatar peastman
Browse files

Optimized reciprocal space convolution by caching scale factors

parent 06fa0896
...@@ -193,16 +193,14 @@ static void spreadCharge(int start, int end, float* posq, float* grid, int gridx ...@@ -193,16 +193,14 @@ static void spreadCharge(int start, int end, float* posq, float* grid, int gridx
} }
} }
static float reciprocalEnergy(int start, int end, fftwf_complex* grid, int gridx, int gridy, int gridz, double alpha, vector<float>* bsplineModuli, Vec3 periodicBoxSize) { static void computeReciprocalEterm(int start, int end, int gridx, int gridy, int gridz, vector<float>& recipEterm, double alpha, vector<float>* bsplineModuli, Vec3 periodicBoxSize) {
const unsigned int yzsize = gridy*gridz; const unsigned int zsize = gridz/2+1;
const unsigned int zsizeHalf = gridz/2+1; const unsigned int yzsize = gridy*zsize;
const unsigned int yzsizeHalf = gridy*zsizeHalf;
const float scaleFactor = (float) (M_PI*periodicBoxSize[0]*periodicBoxSize[1]*periodicBoxSize[2]); const float scaleFactor = (float) (M_PI*periodicBoxSize[0]*periodicBoxSize[1]*periodicBoxSize[2]);
const float recipExpFactor = (float) (M_PI*M_PI/(alpha*alpha)); const float recipExpFactor = (float) (M_PI*M_PI/(alpha*alpha));
const float invPeriodicBoxSizeX = (float) (1.0/periodicBoxSize[0]); const float invPeriodicBoxSizeX = (float) (1.0/periodicBoxSize[0]);
const float invPeriodicBoxSizeY = (float) (1.0/periodicBoxSize[1]); const float invPeriodicBoxSizeY = (float) (1.0/periodicBoxSize[1]);
const float invPeriodicBoxSizeZ = (float) (1.0/periodicBoxSize[2]); const float invPeriodicBoxSizeZ = (float) (1.0/periodicBoxSize[2]);
float energy = 0.0f;
int firstz = (start == 0 ? 1 : 0); int firstz = (start == 0 ? 1 : 0);
for (int kx = start; kx < end; kx++) { for (int kx = start; kx < end; kx++) {
...@@ -214,13 +212,33 @@ static float reciprocalEnergy(int start, int end, fftwf_complex* grid, int gridx ...@@ -214,13 +212,33 @@ static float reciprocalEnergy(int start, int end, fftwf_complex* grid, int gridx
float mhy = my*invPeriodicBoxSizeY; float mhy = my*invPeriodicBoxSizeY;
float mhx2y2 = mhx*mhx + mhy*mhy; float mhx2y2 = mhx*mhx + mhy*mhy;
float bxby = bx*bsplineModuli[1][ky]; float bxby = bx*bsplineModuli[1][ky];
for (int kz = firstz; kz < gridz; kz++) { for (int kz = firstz; kz < zsize; kz++) {
int index = kx*yzsize + ky*zsize + kz;
int mz = (kz < (gridz+1)/2) ? kz : kz-gridz; int mz = (kz < (gridz+1)/2) ? kz : kz-gridz;
float mhz = mz*invPeriodicBoxSizeZ; float mhz = mz*invPeriodicBoxSizeZ;
float bz = bsplineModuli[2][kz]; float bz = bsplineModuli[2][kz];
float m2 = mhx2y2 + mhz*mhz; float m2 = mhx2y2 + mhz*mhz;
float denom = m2*bxby*bz; float denom = m2*bxby*bz;
float eterm = exp(-recipExpFactor*m2)/denom; recipEterm[index] = exp(-recipExpFactor*m2)/denom;
}
firstz = 0;
}
}
}
static float reciprocalEnergy(int start, int end, fftwf_complex* grid, int gridx, int gridy, int gridz, vector<float>& recipEterm) {
const unsigned int zsize = gridz/2+1;
const unsigned int yzsize = gridy*gridz;
const unsigned int zsizeHalf = gridz/2+1;
const unsigned int yzsizeHalf = gridy*zsizeHalf;
float energy = 0.0f;
int firstz = (start == 0 ? 1 : 0);
for (int kx = start; kx < end; kx++) {
for (int ky = 0; ky < gridy; ky++) {
int my = (ky < (gridy+1)/2) ? ky : ky-gridy;
for (int kz = firstz; kz < gridz; kz++) {
float eterm = recipEterm[kx*yzsize + ky*zsize + kz];
int kx1, ky1, kz1; int kx1, ky1, kz1;
if (kz >= gridz/2+1) { if (kz >= gridz/2+1) {
kx1 = (kx == 0 ? kx : gridx-kx); kx1 = (kx == 0 ? kx : gridx-kx);
...@@ -243,33 +261,16 @@ static float reciprocalEnergy(int start, int end, fftwf_complex* grid, int gridx ...@@ -243,33 +261,16 @@ static float reciprocalEnergy(int start, int end, fftwf_complex* grid, int gridx
return 0.5f*energy; return 0.5f*energy;
} }
static void reciprocalConvolution(int start, int end, fftwf_complex* grid, int gridx, int gridy, int gridz, double alpha, vector<float>* bsplineModuli, Vec3 periodicBoxSize) { static void reciprocalConvolution(int start, int end, fftwf_complex* grid, int gridx, int gridy, int gridz, vector<float>& recipEterm) {
const unsigned int zsize = gridz/2+1; const unsigned int zsize = gridz/2+1;
const unsigned int yzsize = gridy*zsize; const unsigned int yzsize = gridy*zsize;
const float scaleFactor = (float) (M_PI*periodicBoxSize[0]*periodicBoxSize[1]*periodicBoxSize[2]);
const float recipExpFactor = (float) (M_PI*M_PI/(alpha*alpha));
const float invPeriodicBoxSizeX = (float) (1.0/periodicBoxSize[0]);
const float invPeriodicBoxSizeY = (float) (1.0/periodicBoxSize[1]);
const float invPeriodicBoxSizeZ = (float) (1.0/periodicBoxSize[2]);
int firstz = (start == 0 ? 1 : 0); int firstz = (start == 0 ? 1 : 0);
for (int kx = start; kx < end; kx++) { for (int kx = start; kx < end; kx++) {
int mx = (kx < (gridx+1)/2) ? kx : kx-gridx;
float mhx = mx*invPeriodicBoxSizeX;
float bx = scaleFactor*bsplineModuli[0][kx];
for (int ky = 0; ky < gridy; ky++) { for (int ky = 0; ky < gridy; ky++) {
int my = (ky < (gridy+1)/2) ? ky : ky-gridy;
float mhy = my*invPeriodicBoxSizeY;
float mhx2y2 = mhx*mhx + mhy*mhy;
float bxby = bx*bsplineModuli[1][ky];
for (int kz = firstz; kz < zsize; kz++) { for (int kz = firstz; kz < zsize; kz++) {
int index = kx*yzsize + ky*zsize + kz; int index = kx*yzsize + ky*zsize + kz;
int mz = (kz < (gridz+1)/2) ? kz : kz-gridz; float eterm = recipEterm[index];
float mhz = mz*invPeriodicBoxSizeZ;
float bz = bsplineModuli[2][kz];
float m2 = mhx2y2 + mhz*mhz;
float denom = m2*bxby*bz;
float eterm = exp(-recipExpFactor*m2)/denom;
grid[index][0] *= eterm; grid[index][0] *= eterm;
grid[index][1] *= eterm; grid[index][1] *= eterm;
} }
...@@ -386,6 +387,7 @@ void CpuCalcPmeReciprocalForceKernel::initialize(int gridx, int gridy, int gridz ...@@ -386,6 +387,7 @@ void CpuCalcPmeReciprocalForceKernel::initialize(int gridx, int gridy, int gridz
this->numParticles = numParticles; this->numParticles = numParticles;
this->alpha = alpha; this->alpha = alpha;
force.resize(4*numParticles); force.resize(4*numParticles);
recipEterm.resize(gridx*gridy*gridz);
// Initialize threads. // Initialize threads.
...@@ -514,6 +516,8 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) { ...@@ -514,6 +516,8 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) {
gettimeofday(&t2, NULL); gettimeofday(&t2, NULL);
fftwf_execute_dft_r2c(forwardFFT, realGrid, complexGrid); fftwf_execute_dft_r2c(forwardFFT, realGrid, complexGrid);
gettimeofday(&t3, NULL); gettimeofday(&t3, NULL);
if (lastBoxSize != periodicBoxSize)
advanceThreads(); // Signal threads to compute the reciprocal scale factors.
if (includeEnergy) if (includeEnergy)
advanceThreads(); // Signal threads to compute energy. advanceThreads(); // Signal threads to compute energy.
gettimeofday(&t4, NULL); gettimeofday(&t4, NULL);
...@@ -525,6 +529,7 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) { ...@@ -525,6 +529,7 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) {
isFinished = true; isFinished = true;
gettimeofday(&t7, NULL); gettimeofday(&t7, NULL);
printf("time %g %g %g %g %g %g\n", diff(t1, t2), diff(t2, t3), diff(t3, t4), diff(t4, t5), diff(t5, t6), diff(t6, t7)); printf("time %g %g %g %g %g %g\n", diff(t1, t2), diff(t2, t3), diff(t3, t4), diff(t4, t5), diff(t5, t6), diff(t6, t7));
lastBoxSize = periodicBoxSize;
pthread_cond_signal(&mainThreadEndCondition); pthread_cond_signal(&mainThreadEndCondition);
} }
pthread_mutex_unlock(&lock); pthread_mutex_unlock(&lock);
...@@ -553,14 +558,18 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) { ...@@ -553,14 +558,18 @@ void CpuCalcPmeReciprocalForceKernel::runThread(int index) {
_mm_store_ps(&realGrid[i], sum); _mm_store_ps(&realGrid[i], sum);
} }
threadWait(); threadWait();
if (lastBoxSize != periodicBoxSize) {
computeReciprocalEterm(gridxStart, gridxEnd, gridx, gridy, gridz, recipEterm, alpha, bsplineModuli, periodicBoxSize);
threadWait();
}
if (includeEnergy) { if (includeEnergy) {
double threadEnergy = reciprocalEnergy(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, alpha, bsplineModuli, periodicBoxSize); double threadEnergy = reciprocalEnergy(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, recipEterm);
pthread_mutex_lock(&lock); pthread_mutex_lock(&lock);
energy += threadEnergy; energy += threadEnergy;
pthread_mutex_unlock(&lock); pthread_mutex_unlock(&lock);
threadWait(); threadWait();
} }
reciprocalConvolution(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, alpha, bsplineModuli, periodicBoxSize); reciprocalConvolution(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, recipEterm);
threadWait(); threadWait();
interpolateForces(particleStart, particleEnd, posq, &force[0], realGrid, gridx, gridy, gridz, numParticles, periodicBoxSize); interpolateForces(particleStart, particleEnd, posq, &force[0], realGrid, gridx, gridy, gridz, numParticles, periodicBoxSize);
} }
......
...@@ -66,6 +66,8 @@ private: ...@@ -66,6 +66,8 @@ private:
bool hasCreatedPlan, isFinished, isDeleted; bool hasCreatedPlan, isFinished, isDeleted;
std::vector<float> force; std::vector<float> force;
std::vector<float> bsplineModuli[3]; std::vector<float> bsplineModuli[3];
std::vector<float> recipEterm;
Vec3 lastBoxSize;
float* realGrid; float* realGrid;
fftwf_complex* complexGrid; fftwf_complex* complexGrid;
fftwf_plan forwardFFT, backwardFFT; fftwf_plan forwardFFT, backwardFFT;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment