Commit f631ecaf authored by Peter Eastman
Browse files

Sum forces from different contexts on the GPU instead of the CPU

parent 80eb063d
...@@ -129,6 +129,14 @@ public: ...@@ -129,6 +129,14 @@ public:
void upload(std::vector<T>& data) { void upload(std::vector<T>& data) {
upload(&data[0]); upload(&data[0]);
} }
/**
* Copy the values in the Buffer to a vector.
*/
void download(std::vector<T>& data) const {
if (data.size() != size)
data.resize(size);
download(&data[0]);
}
/** /**
* Copy the values in an array to the Buffer. * Copy the values in an array to the Buffer.
*/ */
...@@ -143,13 +151,11 @@ public: ...@@ -143,13 +151,11 @@ public:
} }
} }
/** /**
* Copy the values in the Buffer to a vector. * Copy the values in the Buffer to an array.
*/ */
void download(std::vector<T>& data) const { void download(T* data) const {
if (data.size() != size)
data.resize(size);
try { try {
context.getQueue().enqueueReadBuffer(*buffer, CL_TRUE, 0, size*sizeof(T), &data[0]); context.getQueue().enqueueReadBuffer(*buffer, CL_TRUE, 0, size*sizeof(T), data);
} }
catch (cl::Error err) { catch (cl::Error err) {
std::stringstream str; std::stringstream str;
......
...@@ -74,15 +74,16 @@ private: ...@@ -74,15 +74,16 @@ private:
class OpenCLParallelCalcForcesAndEnergyKernel::FinishComputationTask : public OpenCLContext::WorkTask { class OpenCLParallelCalcForcesAndEnergyKernel::FinishComputationTask : public OpenCLContext::WorkTask {
public: public:
FinishComputationTask(ContextImpl& context, OpenCLContext& cl, OpenCLCalcForcesAndEnergyKernel& kernel, FinishComputationTask(ContextImpl& context, OpenCLContext& cl, OpenCLCalcForcesAndEnergyKernel& kernel,
bool includeForce, bool includeEnergy, double& energy, long long& completionTime) : context(context), cl(cl), kernel(kernel), bool includeForce, bool includeEnergy, double& energy, long long& completionTime, OpenCLArray<mm_float4>& contextForces) :
includeForce(includeForce), includeEnergy(includeEnergy), energy(energy), completionTime(completionTime) { context(context), cl(cl), kernel(kernel), includeForce(includeForce), includeEnergy(includeEnergy), energy(energy),
completionTime(completionTime), contextForces(contextForces) {
} }
void execute() { void execute() {
// Execute the kernel, then download forces. // Execute the kernel, then download forces.
energy += kernel.finishComputation(context, includeForce, includeEnergy); energy += kernel.finishComputation(context, includeForce, includeEnergy);
if (includeForce) if (includeForce)
cl.getForce().download(); cl.getForce().download(&contextForces[cl.getContextIndex()*cl.getPaddedNumAtoms()]);
completionTime = getTime(); completionTime = getTime();
} }
private: private:
...@@ -92,14 +93,20 @@ private: ...@@ -92,14 +93,20 @@ private:
bool includeForce, includeEnergy; bool includeForce, includeEnergy;
double& energy; double& energy;
long long& completionTime; long long& completionTime;
OpenCLArray<mm_float4>& contextForces;
}; };
OpenCLParallelCalcForcesAndEnergyKernel::OpenCLParallelCalcForcesAndEnergyKernel(string name, const Platform& platform, OpenCLPlatform::PlatformData& data) : OpenCLParallelCalcForcesAndEnergyKernel::OpenCLParallelCalcForcesAndEnergyKernel(string name, const Platform& platform, OpenCLPlatform::PlatformData& data) :
CalcForcesAndEnergyKernel(name, platform), data(data), completionTimes(data.contexts.size()), contextTiles(data.contexts.size()) { CalcForcesAndEnergyKernel(name, platform), data(data), completionTimes(data.contexts.size()), contextTiles(data.contexts.size()), contextForces(NULL) {
for (int i = 0; i < (int) data.contexts.size(); i++) for (int i = 0; i < (int) data.contexts.size(); i++)
kernels.push_back(Kernel(new OpenCLCalcForcesAndEnergyKernel(name, platform, *data.contexts[i]))); kernels.push_back(Kernel(new OpenCLCalcForcesAndEnergyKernel(name, platform, *data.contexts[i])));
} }
/**
 * Release the array used to gather the forces from the contexts.
 */
OpenCLParallelCalcForcesAndEnergyKernel::~OpenCLParallelCalcForcesAndEnergyKernel() {
    // Applying delete to a null pointer does nothing, so no explicit check is required.
    delete contextForces;
}
void OpenCLParallelCalcForcesAndEnergyKernel::initialize(const System& system) { void OpenCLParallelCalcForcesAndEnergyKernel::initialize(const System& system) {
for (int i = 0; i < (int) kernels.size(); i++) for (int i = 0; i < (int) kernels.size(); i++)
getKernel(i).initialize(system); getKernel(i).initialize(system);
...@@ -109,6 +116,11 @@ void OpenCLParallelCalcForcesAndEnergyKernel::beginComputation(ContextImpl& cont ...@@ -109,6 +116,11 @@ void OpenCLParallelCalcForcesAndEnergyKernel::beginComputation(ContextImpl& cont
// Copy coordinates over to each device and execute the kernel. // Copy coordinates over to each device and execute the kernel.
data.contexts[0]->getPosq().download(); data.contexts[0]->getPosq().download();
if (contextForces == NULL) {
OpenCLContext& cl = *data.contexts[0];
contextForces = new OpenCLArray<mm_float4>(cl, &cl.getForceBuffers().getDeviceBuffer(),
data.contexts.size()*cl.getPaddedNumAtoms(), "contextForces", true);
}
for (int i = 0; i < (int) data.contexts.size(); i++) { for (int i = 0; i < (int) data.contexts.size(); i++) {
data.contextEnergy[i] = 0.0; data.contextEnergy[i] = 0.0;
OpenCLContext& cl = *data.contexts[i]; OpenCLContext& cl = *data.contexts[i];
...@@ -121,7 +133,7 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c ...@@ -121,7 +133,7 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c
for (int i = 0; i < (int) data.contexts.size(); i++) { for (int i = 0; i < (int) data.contexts.size(); i++) {
OpenCLContext& cl = *data.contexts[i]; OpenCLContext& cl = *data.contexts[i];
OpenCLContext::WorkThread& thread = cl.getWorkThread(); OpenCLContext::WorkThread& thread = cl.getWorkThread();
thread.addTask(new FinishComputationTask(context, cl, getKernel(i), includeForce, includeEnergy, data.contextEnergy[i], completionTimes[i])); thread.addTask(new FinishComputationTask(context, cl, getKernel(i), includeForce, includeEnergy, data.contextEnergy[i], completionTimes[i], *contextForces));
} }
data.syncContexts(); data.syncContexts();
double energy = 0.0; double energy = 0.0;
...@@ -130,18 +142,8 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c ...@@ -130,18 +142,8 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c
if (includeForce) { if (includeForce) {
// Sum the forces from all devices. // Sum the forces from all devices.
OpenCLArray<mm_float4>& forces = data.contexts[0]->getForce(); contextForces->upload();
for (int i = 1; i < (int) data.contexts.size(); i++) { data.contexts[0]->reduceBuffer(*contextForces, data.contexts.size());
OpenCLArray<mm_float4>& contextForces = data.contexts[i]->getForce();
for (int j = 0; j < forces.getSize(); j++) {
mm_float4& f1 = forces[j];
const mm_float4& f2 = contextForces[j];
f1.x += f2.x;
f1.y += f2.y;
f1.z += f2.z;
}
}
forces.upload();
// Balance work between the contexts by transferring a few nonbonded tiles from the context that // Balance work between the contexts by transferring a few nonbonded tiles from the context that
// finished last to the one that finished first. // finished last to the one that finished first.
......
...@@ -41,6 +41,7 @@ namespace OpenMM { ...@@ -41,6 +41,7 @@ namespace OpenMM {
class OpenCLParallelCalcForcesAndEnergyKernel : public CalcForcesAndEnergyKernel { class OpenCLParallelCalcForcesAndEnergyKernel : public CalcForcesAndEnergyKernel {
public: public:
OpenCLParallelCalcForcesAndEnergyKernel(std::string name, const Platform& platform, OpenCLPlatform::PlatformData& data); OpenCLParallelCalcForcesAndEnergyKernel(std::string name, const Platform& platform, OpenCLPlatform::PlatformData& data);
~OpenCLParallelCalcForcesAndEnergyKernel();
OpenCLCalcForcesAndEnergyKernel& getKernel(int index) { OpenCLCalcForcesAndEnergyKernel& getKernel(int index) {
return dynamic_cast<OpenCLCalcForcesAndEnergyKernel&>(kernels[index].getImpl()); return dynamic_cast<OpenCLCalcForcesAndEnergyKernel&>(kernels[index].getImpl());
} }
...@@ -78,6 +79,7 @@ private: ...@@ -78,6 +79,7 @@ private:
std::vector<Kernel> kernels; std::vector<Kernel> kernels;
std::vector<long long> completionTimes; std::vector<long long> completionTimes;
std::vector<int> contextTiles; std::vector<int> contextTiles;
OpenCLArray<mm_float4>* contextForces;
}; };
/** /**
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment