Commit 146b66bd authored by peastman's avatar peastman
Browse files

Parallelized the initialization of positions and forces in the CPU platform

parent 543d51d7
...@@ -50,6 +50,7 @@ namespace OpenMM { ...@@ -50,6 +50,7 @@ namespace OpenMM {
*/ */
class CpuCalcForcesAndEnergyKernel : public CalcForcesAndEnergyKernel { class CpuCalcForcesAndEnergyKernel : public CalcForcesAndEnergyKernel {
public: public:
class InitForceTask;
class SumForceTask; class SumForceTask;
CpuCalcForcesAndEnergyKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ContextImpl& context); CpuCalcForcesAndEnergyKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ContextImpl& context);
/** /**
......
...@@ -132,21 +132,11 @@ public: ...@@ -132,21 +132,11 @@ public:
CpuPlatform::PlatformData& data; CpuPlatform::PlatformData& data;
}; };
CpuCalcForcesAndEnergyKernel::CpuCalcForcesAndEnergyKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ContextImpl& context) : class CpuCalcForcesAndEnergyKernel::InitForceTask : public ThreadPool::Task {
CalcForcesAndEnergyKernel(name, platform), data(data) { public:
// Create a Reference platform version of this kernel. InitForceTask(int numParticles, ContextImpl& context, CpuPlatform::PlatformData& data) : numParticles(numParticles), context(context), data(data) {
}
ReferenceKernelFactory referenceFactory; void execute(ThreadPool& threads, int threadIndex) {
referenceKernel = Kernel(referenceFactory.createKernelImpl(name, platform, context));
}
void CpuCalcForcesAndEnergyKernel::initialize(const System& system) {
referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().initialize(system);
}
void CpuCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) {
referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().beginComputation(context, includeForce, includeEnergy, groups);
// Convert the positions to single precision and apply periodic boundary conditions // Convert the positions to single precision and apply periodic boundary conditions
AlignedArray<float>& posq = data.posq; AlignedArray<float>& posq = data.posq;
...@@ -154,15 +144,18 @@ void CpuCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool i ...@@ -154,15 +144,18 @@ void CpuCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool i
RealVec boxSize = extractBoxSize(context); RealVec boxSize = extractBoxSize(context);
double invBoxSize[3] = {1/boxSize[0], 1/boxSize[1], 1/boxSize[2]}; double invBoxSize[3] = {1/boxSize[0], 1/boxSize[1], 1/boxSize[2]};
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
int numThreads = threads.getNumThreads();
int start = threadIndex*numParticles/numThreads;
int end = (threadIndex+1)*numParticles/numThreads;
if (data.isPeriodic) if (data.isPeriodic)
for (int i = 0; i < numParticles; i++) for (int i = start; i < end; i++)
for (int j = 0; j < 3; j++) { for (int j = 0; j < 3; j++) {
RealOpenMM x = posData[i][j]; RealOpenMM x = posData[i][j];
double base = floor(x*invBoxSize[j])*boxSize[j]; double base = floor(x*invBoxSize[j])*boxSize[j];
posq[4*i+j] = (float) (x-base); posq[4*i+j] = (float) (x-base);
} }
else else
for (int i = 0; i < numParticles; i++) { for (int i = start; i < end; i++) {
posq[4*i] = (float) posData[i][0]; posq[4*i] = (float) posData[i][0];
posq[4*i+1] = (float) posData[i][1]; posq[4*i+1] = (float) posData[i][1];
posq[4*i+2] = (float) posData[i][2]; posq[4*i+2] = (float) posData[i][2];
...@@ -171,9 +164,34 @@ void CpuCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool i ...@@ -171,9 +164,34 @@ void CpuCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool i
// Clear the forces. // Clear the forces.
fvec4 zero(0.0f); fvec4 zero(0.0f);
for (int i = 0; i < (int) data.threadForce.size(); i++)
for (int j = 0; j < numParticles; j++) for (int j = 0; j < numParticles; j++)
zero.store(&data.threadForce[i][j*4]); zero.store(&data.threadForce[threadIndex][j*4]);
}
int numParticles;
ContextImpl& context;
CpuPlatform::PlatformData& data;
};
CpuCalcForcesAndEnergyKernel::CpuCalcForcesAndEnergyKernel(std::string name, const Platform& platform, CpuPlatform::PlatformData& data, ContextImpl& context) :
CalcForcesAndEnergyKernel(name, platform), data(data) {
// Create a Reference platform version of this kernel.
ReferenceKernelFactory referenceFactory;
referenceKernel = Kernel(referenceFactory.createKernelImpl(name, platform, context));
}
void CpuCalcForcesAndEnergyKernel::initialize(const System& system) {
referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().initialize(system);
}
void CpuCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) {
referenceKernel.getAs<ReferenceCalcForcesAndEnergyKernel>().beginComputation(context, includeForce, includeEnergy, groups);
// Convert positions to single precision and clear the forces.
InitForceTask task(context.getSystem().getNumParticles(), context, data);
data.threads.execute(task);
data.threads.waitForThreads();
} }
double CpuCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) { double CpuCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment