Commit 684ba725 authored by peastman's avatar peastman
Browse files

Parallelized charge spreading for CPU based PME.

parent 58e0996f
......@@ -90,7 +90,7 @@ static int getNumProcessors() {
#endif
}
static void spreadCharge(float* posq, float* grid, int gridx, int gridy, int gridz, int numParticles, Vec3 periodicBoxSize) {
static void spreadCharge(int start, int end, float* posq, float* grid, int gridx, int gridy, int gridz, int numParticles, Vec3 periodicBoxSize) {
float temp[4];
__m128 boxSize = _mm_set_ps(0, (float) periodicBoxSize[2], (float) periodicBoxSize[1], (float) periodicBoxSize[0]);
__m128 invBoxSize = _mm_set_ps(0, (float) (1/periodicBoxSize[2]), (float) (1/periodicBoxSize[1]), (float) (1/periodicBoxSize[0]));
......@@ -100,7 +100,7 @@ static void spreadCharge(float* posq, float* grid, int gridx, int gridy, int gri
__m128 scale = _mm_set1_ps(1.0f/(PME_ORDER-1));
const float epsilonFactor = sqrt(ONE_4PI_EPS0);
memset(grid, 0, sizeof(float)*gridx*gridy*gridz);
for (int i = 0; i < numParticles; i++) {
for (int i = start; i < end; i++) {
// Find the position relative to the nearest grid point.
__m128 pos = _mm_load_ps(&posq[4*i]);
......@@ -297,12 +297,15 @@ static void interpolateForces(int start, int end, float* posq, float* force, flo
for (int j = 1; j < (PME_ORDER-1); j++)
data[PME_ORDER-j-1] = _mm_mul_ps(scale, _mm_add_ps(_mm_mul_ps(_mm_add_ps(dr, _mm_set1_ps(j)), data[PME_ORDER-j-2]), _mm_mul_ps(_mm_sub_ps(_mm_set1_ps(PME_ORDER-j), dr), data[PME_ORDER-j-1])));
data[0] = _mm_mul_ps(_mm_mul_ps(scale, _mm_sub_ps(one, dr)), data[0]);
// Compute the force on this atom.
int gridIndexX = _mm_extract_epi32(gridIndex, 0);
int gridIndexY = _mm_extract_epi32(gridIndex, 1);
int gridIndexZ = _mm_extract_epi32(gridIndex, 2);
__m128 zdata[PME_ORDER];
for (int j = 0; j < PME_ORDER; j++)
zdata[j] = _mm_set_ps(0, extractFloat(ddata[j], 2), extractFloat(data[j], 2), extractFloat(data[j], 2));
__m128 f = _mm_set1_ps(0);
for (int ix = 0; ix < PME_ORDER; ix++) {
int xbase = gridIndexX+ix;
......@@ -324,10 +327,7 @@ static void interpolateForces(int start, int end, float* posq, float* force, flo
int zindex = gridIndexZ+iz;
zindex -= (zindex >= gridz ? gridz : 0);
__m128 gridValue = _mm_set1_ps(grid[ybase+zindex]);
float dz = extractFloat(data[iz], 2);
float ddz = extractFloat(ddata[iz], 2);
__m128 zdata = _mm_set_ps(0, ddz, dz, dz);
f = _mm_add_ps(f, _mm_mul_ps(xydata, _mm_mul_ps(zdata, gridValue)));
f = _mm_add_ps(f, _mm_mul_ps(xydata, _mm_mul_ps(zdata[iz], gridValue)));
}
}
}
......@@ -340,7 +340,8 @@ class CpuPme::ThreadData {
public:
CpuPme& owner;
int index;
ThreadData(CpuPme& owner, int index) : owner(owner), index(index) {
float* tempGrid;
ThreadData(CpuPme& owner, int index) : owner(owner), index(index), tempGrid(NULL) {
}
};
......@@ -368,11 +369,12 @@ CpuPme::CpuPme(int gridx, int gridy, int gridz, int numParticles, double alpha)
ThreadData* data = new ThreadData(*this, i);
threadData.push_back(data);
pthread_create(&thread[i], NULL, threadBody, data);
data->tempGrid = (float*) fftwf_malloc(sizeof(float)*(gridx*gridy*gridz+3));
}
// Initialize FFTW.
realGrid = (float*) fftwf_malloc(sizeof(float)*gridx*gridy*gridz);
realGrid = threadData[0]->tempGrid;
complexGrid = (fftwf_complex*) fftwf_malloc(sizeof(fftwf_complex)*gridx*gridy*(gridz/2+1));
fftwf_plan_with_nthreads(numThreads);
forwardFFT = fftwf_plan_dft_r2c_3d(gridx, gridy, gridz, realGrid, complexGrid, FFTW_MEASURE);
......@@ -445,11 +447,16 @@ CpuPme::~CpuPme() {
isFinished = true;
pthread_cond_broadcast(&startCondition);
pthread_mutex_unlock(&lock);
for (int i = 0; i < thread.size(); i++)
for (int i = 0; i < (int) thread.size(); i++)
pthread_join(thread[i], NULL);
pthread_mutex_destroy(&lock);
pthread_cond_destroy(&startCondition);
pthread_cond_destroy(&endCondition);
for (int i = 0; i < (int) threadData.size(); i++) {
if (threadData[i]->tempGrid != NULL)
fftwf_free(threadData[i]->tempGrid);
delete threadData[i];
}
}
#include <sys/time.h>
......@@ -460,20 +467,33 @@ double diff(struct timeval t1, struct timeval t2) {
void CpuPme::runThread(int index) {
int particleStart = (index*numParticles)/numThreads;
int particleEnd = ((index+1)*numParticles)/numThreads;
int gridStart = (index*gridx)/numThreads;
int gridEnd = ((index+1)*gridx)/numThreads;
int gridxStart = (index*gridx)/numThreads;
int gridxEnd = ((index+1)*gridx)/numThreads;
int gridSize = (gridx*gridy*gridz+3)/4;
int gridStart = 4*((index*gridSize)/numThreads);
int gridEnd = 4*(((index+1)*gridSize)/numThreads);
while (!isFinished) {
threadWait();
spreadCharge(particleStart, particleEnd, posq, threadData[index]->tempGrid, gridx, gridy, gridz, numParticles, periodicBoxSize);
threadWait();
int numGrids = threadData.size();
for (int i = gridStart; i < gridEnd; i += 4) {
__m128 sum = _mm_load_ps(&realGrid[i]);
for (int j = 1; j < numGrids; j++)
sum = _mm_add_ps(sum, _mm_load_ps(&threadData[j]->tempGrid[i]));
_mm_store_ps(&realGrid[i], sum);
}
threadWait();
if (isFinished)
break;
if (includeEnergy) {
double threadEnergy = reciprocalEnergy(gridStart, gridEnd, &complexGrid[0], gridx, gridy, gridz, alpha, bsplineModuli, periodicBoxSize);
double threadEnergy = reciprocalEnergy(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, alpha, bsplineModuli, periodicBoxSize);
pthread_mutex_lock(&lock);
energy += threadEnergy;
pthread_mutex_unlock(&lock);
threadWait();
}
reciprocalConvolution(gridStart, gridEnd, complexGrid, gridx, gridy, gridz, alpha, bsplineModuli, periodicBoxSize);
reciprocalConvolution(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, alpha, bsplineModuli, periodicBoxSize);
threadWait();
interpolateForces(particleStart, particleEnd, posq, force, realGrid, gridx, gridy, gridz, numParticles, periodicBoxSize);
}
......@@ -505,7 +525,8 @@ double CpuPme::computeForceAndEnergy(float* posq, float* force, Vec3 periodicBox
energy = 0.0;
struct timeval t1, t2, t3, t4, t5, t6, t7;
gettimeofday(&t1, NULL);
spreadCharge(posq, realGrid, gridx, gridy, gridz, numParticles, periodicBoxSize);
advanceThreads(); // Signal threads to perform charge spreading.
advanceThreads(); // Signal threads to sum the charge grids.
gettimeofday(&t2, NULL);
fftwf_execute_dft_r2c(forwardFFT, realGrid, complexGrid);
gettimeofday(&t3, NULL);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment