Commit ff0cdceb authored by peastman
Browse files

Merge pull request #1082 from peastman/threads

Attempt at fixing a race condition in CPU PME
parents 4c00b312 63e2123d
......@@ -38,6 +38,8 @@ Kernel::Kernel() : impl(0) {
}
// Construct a Kernel that wraps the given implementation object.
// The pointer is stored as-is. When it is non-null, the implementation's
// referenceCount is incremented; a null pointer is stored without touching
// any counter.
Kernel::Kernel(KernelImpl* impl) : impl(impl) {
if (impl)
impl->referenceCount++;
}
Kernel::Kernel(const Kernel& copy) : impl(copy.impl) {
......
......@@ -34,7 +34,7 @@
using namespace OpenMM;
using namespace std;
KernelImpl::KernelImpl(string name, const Platform& platform) : name(name), platform(&platform), referenceCount(1) {
KernelImpl::KernelImpl(string name, const Platform& platform) : name(name), platform(&platform), referenceCount(0) {
}
std::string KernelImpl::getName() const {
......
......@@ -335,21 +335,19 @@ static void interpolateForces(int start, int end, float* posq, float* force, flo
}
}
class CpuCalcPmeReciprocalForceKernel::ThreadData {
class CpuCalcPmeReciprocalForceKernel::ComputeTask : public ThreadPool::Task {
public:
CpuCalcPmeReciprocalForceKernel& owner;
int index;
float* tempGrid;
ThreadData(CpuCalcPmeReciprocalForceKernel& owner, int index) : owner(owner), index(index), tempGrid(NULL) {
ComputeTask(CpuCalcPmeReciprocalForceKernel& owner) : owner(owner) {
}
void execute(ThreadPool& threads, int threadIndex) {
owner.runWorkerThread(threads, threadIndex);
}
CpuCalcPmeReciprocalForceKernel& owner;
};
static void* threadBody(void* args) {
CpuCalcPmeReciprocalForceKernel::ThreadData& data = *reinterpret_cast<CpuCalcPmeReciprocalForceKernel::ThreadData*>(args);
data.owner.runThread(data.index);
if (data.tempGrid != NULL)
fftwf_free(data.tempGrid);
delete &data;
CpuCalcPmeReciprocalForceKernel& owner = *reinterpret_cast<CpuCalcPmeReciprocalForceKernel*>(args);
owner.runMainThread();
return 0;
}
......@@ -362,6 +360,7 @@ void CpuCalcPmeReciprocalForceKernel::initialize(int xsize, int ysize, int zsize
fftwf_init_threads();
hasInitializedThreads = true;
}
threadEnergy.resize(numThreads);
gridx = findFFTDimension(xsize, false);
gridy = findFFTDimension(ysize, false);
gridz = findFFTDimension(zsize, true);
......@@ -372,23 +371,24 @@ void CpuCalcPmeReciprocalForceKernel::initialize(int xsize, int ysize, int zsize
// Initialize threads.
isFinished = false;
pthread_cond_init(&startCondition, NULL);
pthread_cond_init(&endCondition, NULL);
pthread_cond_init(&mainThreadStartCondition, NULL);
pthread_cond_init(&mainThreadEndCondition, NULL);
pthread_mutex_init(&lock, NULL);
thread.resize(numThreads);
for (int i = 0; i < numThreads; i++) {
ThreadData* data = new ThreadData(*this, i);
threadData.push_back(data);
pthread_create(&thread[i], NULL, threadBody, data);
data->tempGrid = (float*) fftwf_malloc(sizeof(float)*(gridx*gridy*gridz+3));
}
pthread_create(&mainThread, NULL, threadBody, new ThreadData(*this, -1));
pthread_create(&mainThread, NULL, threadBody, this);
// Wait until the main thread is up and running.
pthread_mutex_lock(&lock);
while (!isFinished)
pthread_cond_wait(&endCondition, &lock);
pthread_mutex_unlock(&lock);
// Initialize FFTW.
realGrid = threadData[0]->tempGrid;
for (int i = 0; i < numThreads; i++)
tempGrid.push_back((float*) fftwf_malloc(sizeof(float)*(gridx*gridy*gridz+3)));
realGrid = tempGrid[0];
complexGrid = (fftwf_complex*) fftwf_malloc(sizeof(fftwf_complex)*gridx*gridy*(gridz/2+1));
fftwf_plan_with_nthreads(numThreads);
forwardFFT = fftwf_plan_dft_r2c_3d(gridx, gridy, gridz, realGrid, complexGrid, FFTW_MEASURE);
......@@ -455,16 +455,13 @@ CpuCalcPmeReciprocalForceKernel::~CpuCalcPmeReciprocalForceKernel() {
isDeleted = true;
pthread_mutex_lock(&lock);
pthread_cond_broadcast(&startCondition);
pthread_cond_broadcast(&mainThreadStartCondition);
pthread_mutex_unlock(&lock);
for (int i = 0; i < (int) thread.size(); i++)
pthread_join(thread[i], NULL);
pthread_join(mainThread, NULL);
pthread_mutex_destroy(&lock);
pthread_cond_destroy(&startCondition);
pthread_cond_destroy(&endCondition);
pthread_cond_destroy(&mainThreadStartCondition);
pthread_cond_destroy(&mainThreadEndCondition);
for (int i = 0; i < (int) tempGrid.size(); i++)
fftwf_free(tempGrid[i]);
if (complexGrid != NULL)
fftwf_free(complexGrid);
if (hasCreatedPlan) {
......@@ -473,92 +470,79 @@ CpuCalcPmeReciprocalForceKernel::~CpuCalcPmeReciprocalForceKernel() {
}
}
// Thread entry routine shared by the coordinating thread and the workers.
// index == -1 selects the coordinator branch; any other value selects the
// worker branch and is used to pick that worker's slice of the particles and
// of the grid. Both branches loop until isDeleted is observed after a wakeup.
void CpuCalcPmeReciprocalForceKernel::runThread(int index) {
if (index == -1) {
// This is the main thread that coordinates all the other ones.
// The lock is taken once here and released only after the loop exits;
// pthread_cond_wait releases it for the duration of each wait.
pthread_mutex_lock(&lock);
while (true) {
// Wait for the signal to start.
// NOTE(review): this wait is not wrapped in a predicate loop, so a signal
// sent before this thread reaches the wait, or a spurious wakeup, is not
// guarded against here.
pthread_cond_wait(&mainThreadStartCondition, &lock);
if (isDeleted)
break;
posq = io->getPosq();
// Each advanceThreads() call below releases the workers for one phase.
// The two conditional calls use the same conditions as the worker branch,
// so both sides step through the same number of phases.
// The forward and backward FFTs are issued from this thread between phases.
advanceThreads(); // Signal threads to perform charge spreading.
advanceThreads(); // Signal threads to sum the charge grids.
fftwf_execute_dft_r2c(forwardFFT, realGrid, complexGrid);
if (lastBoxVectors[0] != periodicBoxVectors[0] || lastBoxVectors[1] != periodicBoxVectors[1] || lastBoxVectors[2] != periodicBoxVectors[2])
advanceThreads(); // Signal threads to compute the reciprocal scale factors.
if (includeEnergy)
advanceThreads(); // Signal threads to compute energy.
advanceThreads(); // Signal threads to perform reciprocal convolution.
fftwf_execute_dft_c2r(backwardFFT, complexGrid, realGrid);
advanceThreads(); // Signal threads to interpolate forces.
// Mark the computation as done, remember the box vectors that were used so
// the next pass can detect a change, and signal mainThreadEndCondition.
isFinished = true;
lastBoxVectors[0] = periodicBoxVectors[0];
lastBoxVectors[1] = periodicBoxVectors[1];
lastBoxVectors[2] = periodicBoxVectors[2];
pthread_cond_signal(&mainThreadEndCondition);
}
pthread_mutex_unlock(&lock);
}
else {
// This is a worker thread.
// Partition the work by thread index: a contiguous range of particles, a
// contiguous range of x indices of the grid, and a contiguous range of grid
// elements expressed in groups of four floats.
int particleStart = (index*numParticles)/numThreads;
int particleEnd = ((index+1)*numParticles)/numThreads;
int gridxStart = (index*gridx)/numThreads;
int gridxEnd = ((index+1)*gridx)/numThreads;
// Number of 4-float groups needed to cover the grid, rounded up.
int gridSize = (gridx*gridy*gridz+3)/4;
int gridStart = 4*((index*gridSize)/numThreads);
int gridEnd = 4*(((index+1)*gridSize)/numThreads);
while (true) {
// Block until the coordinator starts the next phase (or requests shutdown).
threadWait();
if (isDeleted)
break;
// Phase 1: spread this thread's particles onto its own temporary grid.
spreadCharge(particleStart, particleEnd, posq, threadData[index]->tempGrid, gridx, gridy, gridz, numParticles, periodicBoxVectors, recipBoxVectors);
threadWait();
// Phase 2: add the other threads' grids into realGrid, four floats at a
// time, over this thread's element range. j starts at 1 - presumably
// realGrid is thread 0's own grid; confirm in initialize().
int numGrids = threadData.size();
for (int i = gridStart; i < gridEnd; i += 4) {
fvec4 sum(&realGrid[i]);
for (int j = 1; j < numGrids; j++)
sum += fvec4(&threadData[j]->tempGrid[i]);
sum.store(&realGrid[i]);
}
threadWait();
// Optional phase: recompute recipEterm only when the box vectors differ from
// those recorded at the end of the previous pass.
if (lastBoxVectors[0] != periodicBoxVectors[0] || lastBoxVectors[1] != periodicBoxVectors[1] || lastBoxVectors[2] != periodicBoxVectors[2]) {
computeReciprocalEterm(gridxStart, gridxEnd, gridx, gridy, gridz, recipEterm, alpha, bsplineModuli, periodicBoxVectors, recipBoxVectors);
threadWait();
}
// Optional phase: energy is shared between workers, so each thread's
// contribution is added while holding the lock.
if (includeEnergy) {
double threadEnergy = reciprocalEnergy(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, alpha, bsplineModuli, periodicBoxVectors, recipBoxVectors);
pthread_mutex_lock(&lock);
energy += threadEnergy;
pthread_mutex_unlock(&lock);
threadWait();
}
// Convolution over this thread's x range, then (after the coordinator has
// run the backward FFT) force interpolation for this thread's particles.
reciprocalConvolution(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, recipEterm);
threadWait();
interpolateForces(particleStart, particleEnd, posq, &force[0], realGrid, gridx, gridy, gridz, numParticles, periodicBoxVectors, recipBoxVectors);
}
}
}
void CpuCalcPmeReciprocalForceKernel::runMainThread() {
// This is the main thread that coordinates all the other ones.
void CpuCalcPmeReciprocalForceKernel::threadWait() {
pthread_mutex_lock(&lock);
waitCount++;
isFinished = true;
pthread_cond_signal(&endCondition);
pthread_cond_wait(&startCondition, &lock);
ThreadPool threads(numThreads);
while (true) {
// Wait for the signal to start.
pthread_cond_wait(&startCondition, &lock);
if (isDeleted)
break;
posq = io->getPosq();
ComputeTask task(*this);
threads.execute(task); // Signal threads to perform charge spreading.
threads.waitForThreads();
threads.resumeThreads(); // Signal threads to sum the charge grids.
threads.waitForThreads();
fftwf_execute_dft_r2c(forwardFFT, realGrid, complexGrid);
if (lastBoxVectors[0] != periodicBoxVectors[0] || lastBoxVectors[1] != periodicBoxVectors[1] || lastBoxVectors[2] != periodicBoxVectors[2]) {
threads.resumeThreads(); // Signal threads to compute the reciprocal scale factors.
threads.waitForThreads();
}
if (includeEnergy) {
threads.resumeThreads(); // Signal threads to compute energy.
threads.waitForThreads();
for (int i = 0; i < (int) threadEnergy.size(); i++)
energy += threadEnergy[i];
}
threads.resumeThreads(); // Signal threads to perform reciprocal convolution.
threads.waitForThreads();
fftwf_execute_dft_c2r(backwardFFT, complexGrid, realGrid);
threads.resumeThreads(); // Signal threads to interpolate forces.
threads.waitForThreads();
isFinished = true;
lastBoxVectors[0] = periodicBoxVectors[0];
lastBoxVectors[1] = periodicBoxVectors[1];
lastBoxVectors[2] = periodicBoxVectors[2];
pthread_cond_signal(&endCondition);
}
pthread_mutex_unlock(&lock);
}
void CpuCalcPmeReciprocalForceKernel::advanceThreads() {
waitCount = 0;
pthread_cond_broadcast(&startCondition);
while (waitCount < numThreads) {
pthread_cond_wait(&endCondition, &lock);
// Work performed by one thread-pool worker for a single reciprocal-space
// evaluation.
//   threads - the pool this worker belongs to; syncThreads() is called on it
//             between phases (presumably it blocks until the coordinating
//             thread resumes the pool - confirm against ThreadPool).
//   index   - this worker's index, used to select its slice of the particles
//             and of the grid, and its slot in tempGrid and threadEnergy.
// The two conditional syncThreads() calls use the same conditions as the
// coordinating code in runMainThread(), so both sides step through the same
// number of phases.
void CpuCalcPmeReciprocalForceKernel::runWorkerThread(ThreadPool& threads, int index) {
// Partition the work by thread index: a contiguous range of particles, a
// contiguous range of x indices of the grid, and a contiguous range of grid
// elements expressed in groups of four floats.
int particleStart = (index*numParticles)/numThreads;
int particleEnd = ((index+1)*numParticles)/numThreads;
int gridxStart = (index*gridx)/numThreads;
int gridxEnd = ((index+1)*gridx)/numThreads;
// Number of 4-float groups needed to cover the grid, rounded up.
int gridSize = (gridx*gridy*gridz+3)/4;
int gridStart = 4*((index*gridSize)/numThreads);
int gridEnd = 4*(((index+1)*gridSize)/numThreads);
// Phase 1: spread this worker's particles onto its own temporary grid.
spreadCharge(particleStart, particleEnd, posq, tempGrid[index], gridx, gridy, gridz, numParticles, periodicBoxVectors, recipBoxVectors);
threads.syncThreads();
// Phase 2: add the other workers' grids into realGrid, four floats at a time,
// over this worker's element range. j starts at 1 - presumably realGrid is
// tempGrid[0]; confirm in initialize().
int numGrids = tempGrid.size();
for (int i = gridStart; i < gridEnd; i += 4) {
fvec4 sum(&realGrid[i]);
for (int j = 1; j < numGrids; j++)
sum += fvec4(&tempGrid[j][i]);
sum.store(&realGrid[i]);
}
threads.syncThreads();
// Optional phase: recompute recipEterm only when the box vectors differ from
// lastBoxVectors.
if (lastBoxVectors[0] != periodicBoxVectors[0] || lastBoxVectors[1] != periodicBoxVectors[1] || lastBoxVectors[2] != periodicBoxVectors[2]) {
computeReciprocalEterm(gridxStart, gridxEnd, gridx, gridy, gridz, recipEterm, alpha, bsplineModuli, periodicBoxVectors, recipBoxVectors);
threads.syncThreads();
}
// Optional phase: each worker writes its energy contribution to its own
// threadEnergy slot, so no lock is taken here; the slots are summed elsewhere.
if (includeEnergy) {
threadEnergy[index] = reciprocalEnergy(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, alpha, bsplineModuli, periodicBoxVectors, recipBoxVectors);
threads.syncThreads();
}
// Convolution over this worker's x range, then, after the final sync, force
// interpolation for this worker's particles.
reciprocalConvolution(gridxStart, gridxEnd, complexGrid, gridx, gridy, gridz, recipEterm);
threads.syncThreads();
interpolateForces(particleStart, particleEnd, posq, &force[0], realGrid, gridx, gridy, gridz, numParticles, periodicBoxVectors, recipBoxVectors);
}
void CpuCalcPmeReciprocalForceKernel::beginComputation(IO& io, const Vec3* periodicBoxVectors, bool includeEnergy) {
......@@ -581,14 +565,14 @@ void CpuCalcPmeReciprocalForceKernel::beginComputation(IO& io, const Vec3* perio
pthread_mutex_lock(&lock);
isFinished = false;
pthread_cond_signal(&mainThreadStartCondition);
pthread_cond_signal(&startCondition);
pthread_mutex_unlock(&lock);
}
double CpuCalcPmeReciprocalForceKernel::finishComputation(IO& io) {
pthread_mutex_lock(&lock);
while (!isFinished) {
pthread_cond_wait(&mainThreadEndCondition, &lock);
pthread_cond_wait(&endCondition, &lock);
}
pthread_mutex_unlock(&lock);
io.setForce(&force[0]);
......
......@@ -36,6 +36,7 @@
#include "internal/windowsExportPme.h"
#include "openmm/kernels.h"
#include "openmm/Vec3.h"
#include "openmm/internal/ThreadPool.h"
#include <fftw3.h>
#include <pthread.h>
#include <vector>
......@@ -49,7 +50,6 @@ namespace OpenMM {
class OPENMM_EXPORT_PME CpuCalcPmeReciprocalForceKernel : public CalcPmeReciprocalForceKernel {
public:
class ThreadData;
CpuCalcPmeReciprocalForceKernel(std::string name, const Platform& platform) : CalcPmeReciprocalForceKernel(name, platform),
hasCreatedPlan(false), isDeleted(false), realGrid(NULL), complexGrid(NULL) {
}
......@@ -80,22 +80,19 @@ public:
*/
double finishComputation(IO& io);
/**
* This routine contains the code executed by each thread.
* This routine contains the code executed by the main thread.
*/
void runThread(int index);
void runMainThread();
/**
* This routine contains the code executed by each worker thread.
*/
void runWorkerThread(ThreadPool& threads, int index);
/**
* Get whether the current CPU supports all features needed by this kernel.
*/
static bool isProcessorSupported();
private:
/**
* This is called by the worker threads to wait until the master thread instructs them to advance.
*/
void threadWait();
/**
* This is called by the master thread to instruct all the worker threads to advance.
*/
void advanceThreads();
class ComputeTask;
/**
* Select a size for one grid dimension that FFTW can handle efficiently.
*/
......@@ -109,16 +106,15 @@ private:
std::vector<float> bsplineModuli[3];
std::vector<float> recipEterm;
Vec3 lastBoxVectors[3];
std::vector<float> threadEnergy;
std::vector<float*> tempGrid;
float* realGrid;
fftwf_complex* complexGrid;
fftwf_plan forwardFFT, backwardFFT;
int waitCount;
pthread_cond_t startCondition, endCondition;
pthread_cond_t mainThreadStartCondition, mainThreadEndCondition;
pthread_mutex_t lock;
pthread_t mainThread;
std::vector<pthread_t> thread;
std::vector<ThreadData*> threadData;
// The following variables are used to store information about the calculation currently being performed.
IO* io;
float energy;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment