Commit 72dc8dfe authored by Peter Eastman's avatar Peter Eastman
Browse files

Optimized communication between GPUs by using pinned memory

parent 68446fd0
...@@ -54,14 +54,14 @@ using namespace std; ...@@ -54,14 +54,14 @@ using namespace std;
class OpenCLParallelCalcForcesAndEnergyKernel::BeginComputationTask : public OpenCLContext::WorkTask { class OpenCLParallelCalcForcesAndEnergyKernel::BeginComputationTask : public OpenCLContext::WorkTask {
public: public:
BeginComputationTask(ContextImpl& context, OpenCLContext& cl, OpenCLCalcForcesAndEnergyKernel& kernel, BeginComputationTask(ContextImpl& context, OpenCLContext& cl, OpenCLCalcForcesAndEnergyKernel& kernel,
bool includeForce, bool includeEnergy) : context(context), cl(cl), kernel(kernel), bool includeForce, bool includeEnergy, mm_float4* pinnedMemory) : context(context), cl(cl), kernel(kernel),
includeForce(includeForce), includeEnergy(includeEnergy) { includeForce(includeForce), includeEnergy(includeEnergy), pinnedMemory(pinnedMemory) {
} }
void execute() { void execute() {
// Copy coordinates over to this device and execute the kernel. // Copy coordinates over to this device and execute the kernel.
if (cl.getContextIndex() > 0) if (cl.getContextIndex() > 0)
cl.getPosq().upload(cl.getPlatformData().contexts[0]->getPosq().getHostBuffer(), false); cl.getQueue().enqueueWriteBuffer(cl.getPosq().getDeviceBuffer(), CL_FALSE, 0, cl.getPaddedNumAtoms()*sizeof(mm_float4), pinnedMemory);
kernel.beginComputation(context, includeForce, includeEnergy); kernel.beginComputation(context, includeForce, includeEnergy);
} }
private: private:
...@@ -69,22 +69,26 @@ private: ...@@ -69,22 +69,26 @@ private:
OpenCLContext& cl; OpenCLContext& cl;
OpenCLCalcForcesAndEnergyKernel& kernel; OpenCLCalcForcesAndEnergyKernel& kernel;
bool includeForce, includeEnergy; bool includeForce, includeEnergy;
mm_float4* pinnedMemory;
}; };
class OpenCLParallelCalcForcesAndEnergyKernel::FinishComputationTask : public OpenCLContext::WorkTask { class OpenCLParallelCalcForcesAndEnergyKernel::FinishComputationTask : public OpenCLContext::WorkTask {
public: public:
FinishComputationTask(ContextImpl& context, OpenCLContext& cl, OpenCLCalcForcesAndEnergyKernel& kernel, FinishComputationTask(ContextImpl& context, OpenCLContext& cl, OpenCLCalcForcesAndEnergyKernel& kernel,
bool includeForce, bool includeEnergy, double& energy, long long& completionTime, OpenCLArray<mm_float4>& contextForces) : bool includeForce, bool includeEnergy, double& energy, long long& completionTime, mm_float4* pinnedMemory) :
context(context), cl(cl), kernel(kernel), includeForce(includeForce), includeEnergy(includeEnergy), energy(energy), context(context), cl(cl), kernel(kernel), includeForce(includeForce), includeEnergy(includeEnergy), energy(energy),
completionTime(completionTime), contextForces(contextForces) { completionTime(completionTime), pinnedMemory(pinnedMemory) {
} }
void execute() { void execute() {
// Execute the kernel, then download forces. // Execute the kernel, then download forces.
energy += kernel.finishComputation(context, includeForce, includeEnergy); energy += kernel.finishComputation(context, includeForce, includeEnergy);
if (includeForce) { if (includeForce) {
if (cl.getContextIndex() > 0) if (cl.getContextIndex() > 0) {
cl.getForce().download(&contextForces[cl.getContextIndex()*cl.getPaddedNumAtoms()]); int numAtoms = cl.getPaddedNumAtoms();
cl.getQueue().enqueueReadBuffer(cl.getForce().getDeviceBuffer(), CL_TRUE, 0,
numAtoms*sizeof(mm_float4), &pinnedMemory[(cl.getContextIndex()-1)*numAtoms]);
}
else else
cl.getQueue().finish(); cl.getQueue().finish();
} }
...@@ -97,11 +101,12 @@ private: ...@@ -97,11 +101,12 @@ private:
bool includeForce, includeEnergy; bool includeForce, includeEnergy;
double& energy; double& energy;
long long& completionTime; long long& completionTime;
OpenCLArray<mm_float4>& contextForces; mm_float4* pinnedMemory;
}; };
OpenCLParallelCalcForcesAndEnergyKernel::OpenCLParallelCalcForcesAndEnergyKernel(string name, const Platform& platform, OpenCLPlatform::PlatformData& data) : OpenCLParallelCalcForcesAndEnergyKernel::OpenCLParallelCalcForcesAndEnergyKernel(string name, const Platform& platform, OpenCLPlatform::PlatformData& data) :
CalcForcesAndEnergyKernel(name, platform), data(data), completionTimes(data.contexts.size()), contextTiles(data.contexts.size()), contextForces(NULL) { CalcForcesAndEnergyKernel(name, platform), data(data), completionTimes(data.contexts.size()), contextTiles(data.contexts.size()), contextForces(NULL),
pinnedBuffer(NULL), pinnedMemory(NULL) {
for (int i = 0; i < (int) data.contexts.size(); i++) for (int i = 0; i < (int) data.contexts.size(); i++)
kernels.push_back(Kernel(new OpenCLCalcForcesAndEnergyKernel(name, platform, *data.contexts[i]))); kernels.push_back(Kernel(new OpenCLCalcForcesAndEnergyKernel(name, platform, *data.contexts[i])));
} }
...@@ -109,6 +114,8 @@ OpenCLParallelCalcForcesAndEnergyKernel::OpenCLParallelCalcForcesAndEnergyKernel ...@@ -109,6 +114,8 @@ OpenCLParallelCalcForcesAndEnergyKernel::OpenCLParallelCalcForcesAndEnergyKernel
OpenCLParallelCalcForcesAndEnergyKernel::~OpenCLParallelCalcForcesAndEnergyKernel() { OpenCLParallelCalcForcesAndEnergyKernel::~OpenCLParallelCalcForcesAndEnergyKernel() {
if (contextForces != NULL) if (contextForces != NULL)
delete contextForces; delete contextForces;
if (pinnedBuffer != NULL)
delete pinnedBuffer;
} }
void OpenCLParallelCalcForcesAndEnergyKernel::initialize(const System& system) { void OpenCLParallelCalcForcesAndEnergyKernel::initialize(const System& system) {
...@@ -117,19 +124,23 @@ void OpenCLParallelCalcForcesAndEnergyKernel::initialize(const System& system) { ...@@ -117,19 +124,23 @@ void OpenCLParallelCalcForcesAndEnergyKernel::initialize(const System& system) {
} }
void OpenCLParallelCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy) { void OpenCLParallelCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy) {
// Copy coordinates over to each device and execute the kernel. OpenCLContext& cl0 = *data.contexts[0];
data.contexts[0]->getPosq().download();
if (contextForces == NULL) { if (contextForces == NULL) {
OpenCLContext& cl = *data.contexts[0]; contextForces = new OpenCLArray<mm_float4>(cl0, &cl0.getForceBuffers().getDeviceBuffer(),
contextForces = new OpenCLArray<mm_float4>(cl, &cl.getForceBuffers().getDeviceBuffer(), data.contexts.size()*cl0.getPaddedNumAtoms(), "contextForces", true);
data.contexts.size()*cl.getPaddedNumAtoms(), "contextForces", true); int bufferBytes = (data.contexts.size()-1)*cl0.getPaddedNumAtoms()*sizeof(mm_float4);
pinnedBuffer = new cl::Buffer(cl0.getContext(), CL_MEM_ALLOC_HOST_PTR, bufferBytes);
pinnedMemory = (mm_float4*) cl0.getQueue().enqueueMapBuffer(*pinnedBuffer, CL_TRUE, CL_MAP_READ | CL_MAP_WRITE, 0, bufferBytes);
} }
// Copy coordinates over to each device and execute the kernel.
cl0.getQueue().enqueueReadBuffer(cl0.getPosq().getDeviceBuffer(), CL_TRUE, 0, cl0.getPaddedNumAtoms()*sizeof(mm_float4), pinnedMemory);
for (int i = 0; i < (int) data.contexts.size(); i++) { for (int i = 0; i < (int) data.contexts.size(); i++) {
data.contextEnergy[i] = 0.0; data.contextEnergy[i] = 0.0;
OpenCLContext& cl = *data.contexts[i]; OpenCLContext& cl = *data.contexts[i];
OpenCLContext::WorkThread& thread = cl.getWorkThread(); OpenCLContext::WorkThread& thread = cl.getWorkThread();
thread.addTask(new BeginComputationTask(context, cl, getKernel(i), includeForce, includeEnergy)); thread.addTask(new BeginComputationTask(context, cl, getKernel(i), includeForce, includeEnergy, pinnedMemory));
} }
} }
...@@ -137,7 +148,7 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c ...@@ -137,7 +148,7 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c
for (int i = 0; i < (int) data.contexts.size(); i++) { for (int i = 0; i < (int) data.contexts.size(); i++) {
OpenCLContext& cl = *data.contexts[i]; OpenCLContext& cl = *data.contexts[i];
OpenCLContext::WorkThread& thread = cl.getWorkThread(); OpenCLContext::WorkThread& thread = cl.getWorkThread();
thread.addTask(new FinishComputationTask(context, cl, getKernel(i), includeForce, includeEnergy, data.contextEnergy[i], completionTimes[i], *contextForces)); thread.addTask(new FinishComputationTask(context, cl, getKernel(i), includeForce, includeEnergy, data.contextEnergy[i], completionTimes[i], pinnedMemory));
} }
data.syncContexts(); data.syncContexts();
double energy = 0.0; double energy = 0.0;
...@@ -149,7 +160,7 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c ...@@ -149,7 +160,7 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c
OpenCLContext& cl = *data.contexts[0]; OpenCLContext& cl = *data.contexts[0];
int numAtoms = cl.getPaddedNumAtoms(); int numAtoms = cl.getPaddedNumAtoms();
cl.getQueue().enqueueWriteBuffer(contextForces->getDeviceBuffer(), CL_FALSE, numAtoms*sizeof(mm_float4), cl.getQueue().enqueueWriteBuffer(contextForces->getDeviceBuffer(), CL_FALSE, numAtoms*sizeof(mm_float4),
numAtoms*(data.contexts.size()-1)*sizeof(mm_float4), &(*contextForces)[numAtoms]); numAtoms*(data.contexts.size()-1)*sizeof(mm_float4), pinnedMemory);
cl.reduceBuffer(*contextForces, data.contexts.size()); cl.reduceBuffer(*contextForces, data.contexts.size());
// Balance work between the contexts by transferring a few nonbonded tiles from the context that // Balance work between the contexts by transferring a few nonbonded tiles from the context that
......
...@@ -80,6 +80,8 @@ private: ...@@ -80,6 +80,8 @@ private:
std::vector<long long> completionTimes; std::vector<long long> completionTimes;
std::vector<int> contextTiles; std::vector<int> contextTiles;
OpenCLArray<mm_float4>* contextForces; OpenCLArray<mm_float4>* contextForces;
cl::Buffer* pinnedBuffer;
mm_float4* pinnedMemory;
}; };
/** /**
......
...@@ -106,7 +106,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { ...@@ -106,7 +106,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& deviceIndexProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) { OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& deviceIndexProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) {
vector<string> devices; vector<string> devices;
size_t searchPos = 0, nextPos; size_t searchPos = 0, nextPos;
while ((nextPos = deviceIndexProperty.find(',', searchPos)) != string::npos) { while ((nextPos = deviceIndexProperty.find_first_of(", ", searchPos)) != string::npos) {
devices.push_back(deviceIndexProperty.substr(searchPos, nextPos-searchPos)); devices.push_back(deviceIndexProperty.substr(searchPos, nextPos-searchPos));
searchPos = nextPos+1; searchPos = nextPos+1;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment