Commit f631ecaf authored by Peter Eastman
Browse files

Sum forces from different contexts on the GPU instead of the CPU

parent 80eb063d
......@@ -129,6 +129,14 @@ public:
void upload(std::vector<T>& data) {
upload(&data[0]);
}
/**
* Copy the values in the Buffer to a vector.
*/
void download(std::vector<T>& data) const {
if (data.size() != size)
data.resize(size);
download(&data[0]);
}
/**
* Copy the values in an array to the Buffer.
*/
......@@ -143,13 +151,11 @@ public:
}
}
/**
* Copy the values in the Buffer to a vector.
* Copy the values in the Buffer to an array.
*/
void download(std::vector<T>& data) const {
if (data.size() != size)
data.resize(size);
void download(T* data) const {
try {
context.getQueue().enqueueReadBuffer(*buffer, CL_TRUE, 0, size*sizeof(T), &data[0]);
context.getQueue().enqueueReadBuffer(*buffer, CL_TRUE, 0, size*sizeof(T), data);
}
catch (cl::Error err) {
std::stringstream str;
......
......@@ -74,15 +74,16 @@ private:
class OpenCLParallelCalcForcesAndEnergyKernel::FinishComputationTask : public OpenCLContext::WorkTask {
public:
// Stores the references and flags passed in so that execute() can use them later; the body itself does no work.
// NOTE(review): two parameter/initializer lists appear below (one without and one with contextForces);
// this looks like the before/after lines of a diff shown together - confirm which one is current.
FinishComputationTask(ContextImpl& context, OpenCLContext& cl, OpenCLCalcForcesAndEnergyKernel& kernel,
bool includeForce, bool includeEnergy, double& energy, long long& completionTime) : context(context), cl(cl), kernel(kernel),
includeForce(includeForce), includeEnergy(includeEnergy), energy(energy), completionTime(completionTime) {
bool includeForce, bool includeEnergy, double& energy, long long& completionTime, OpenCLArray<mm_float4>& contextForces) :
context(context), cl(cl), kernel(kernel), includeForce(includeForce), includeEnergy(includeEnergy), energy(energy),
completionTime(completionTime), contextForces(contextForces) {
}
// Finishes the computation for this task's context: adds the value returned by
// kernel.finishComputation() to `energy`, downloads the context's forces when
// includeForce is set, and finally stores getTime() in `completionTime`.
// NOTE(review): two download statements follow the `if` below (one with no argument, one
// targeting an offset into contextForces); this looks like the before/after lines of a
// diff shown together - confirm which one is current.
void execute() {
// Execute the kernel, then download forces.
energy += kernel.finishComputation(context, includeForce, includeEnergy);
if (includeForce)
cl.getForce().download();
// The destination offset is this context's index times the padded atom count.
cl.getForce().download(&contextForces[cl.getContextIndex()*cl.getPaddedNumAtoms()]);
completionTime = getTime();
}
private:
......@@ -92,14 +93,20 @@ private:
bool includeForce, includeEnergy;
double& energy;
long long& completionTime;
OpenCLArray<mm_float4>& contextForces;
};
// Sizes completionTimes and contextTiles to the number of contexts, then creates one
// OpenCLCalcForcesAndEnergyKernel per entry in data.contexts and appends it to `kernels`.
// NOTE(review): two initializer-list lines appear below (one without and one with
// contextForces(NULL)); this looks like the before/after lines of a diff shown together -
// confirm which one is current.
OpenCLParallelCalcForcesAndEnergyKernel::OpenCLParallelCalcForcesAndEnergyKernel(string name, const Platform& platform, OpenCLPlatform::PlatformData& data) :
CalcForcesAndEnergyKernel(name, platform), data(data), completionTimes(data.contexts.size()), contextTiles(data.contexts.size()) {
CalcForcesAndEnergyKernel(name, platform), data(data), completionTimes(data.contexts.size()), contextTiles(data.contexts.size()), contextForces(NULL) {
for (int i = 0; i < (int) data.contexts.size(); i++)
kernels.push_back(Kernel(new OpenCLCalcForcesAndEnergyKernel(name, platform, *data.contexts[i])));
}
// Releases the contextForces array held by this kernel, if one was ever allocated.
OpenCLParallelCalcForcesAndEnergyKernel::~OpenCLParallelCalcForcesAndEnergyKernel() {
    // Applying delete to a null pointer is a no-op, so no explicit check is required.
    delete contextForces;
}
void OpenCLParallelCalcForcesAndEnergyKernel::initialize(const System& system) {
for (int i = 0; i < (int) kernels.size(); i++)
getKernel(i).initialize(system);
......@@ -109,6 +116,11 @@ void OpenCLParallelCalcForcesAndEnergyKernel::beginComputation(ContextImpl& cont
// Copy coordinates over to each device and execute the kernel.
data.contexts[0]->getPosq().download();
if (contextForces == NULL) {
OpenCLContext& cl = *data.contexts[0];
contextForces = new OpenCLArray<mm_float4>(cl, &cl.getForceBuffers().getDeviceBuffer(),
data.contexts.size()*cl.getPaddedNumAtoms(), "contextForces", true);
}
for (int i = 0; i < (int) data.contexts.size(); i++) {
data.contextEnergy[i] = 0.0;
OpenCLContext& cl = *data.contexts[i];
......@@ -121,7 +133,7 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c
for (int i = 0; i < (int) data.contexts.size(); i++) {
OpenCLContext& cl = *data.contexts[i];
OpenCLContext::WorkThread& thread = cl.getWorkThread();
thread.addTask(new FinishComputationTask(context, cl, getKernel(i), includeForce, includeEnergy, data.contextEnergy[i], completionTimes[i]));
thread.addTask(new FinishComputationTask(context, cl, getKernel(i), includeForce, includeEnergy, data.contextEnergy[i], completionTimes[i], *contextForces));
}
data.syncContexts();
double energy = 0.0;
......@@ -130,18 +142,8 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c
if (includeForce) {
// Sum the forces from all devices.
OpenCLArray<mm_float4>& forces = data.contexts[0]->getForce();
for (int i = 1; i < (int) data.contexts.size(); i++) {
OpenCLArray<mm_float4>& contextForces = data.contexts[i]->getForce();
for (int j = 0; j < forces.getSize(); j++) {
mm_float4& f1 = forces[j];
const mm_float4& f2 = contextForces[j];
f1.x += f2.x;
f1.y += f2.y;
f1.z += f2.z;
}
}
forces.upload();
contextForces->upload();
data.contexts[0]->reduceBuffer(*contextForces, data.contexts.size());
// Balance work between the contexts by transferring a few nonbonded tiles from the context that
// finished last to the one that finished first.
......
......@@ -41,6 +41,7 @@ namespace OpenMM {
class OpenCLParallelCalcForcesAndEnergyKernel : public CalcForcesAndEnergyKernel {
public:
OpenCLParallelCalcForcesAndEnergyKernel(std::string name, const Platform& platform, OpenCLPlatform::PlatformData& data);
~OpenCLParallelCalcForcesAndEnergyKernel();
OpenCLCalcForcesAndEnergyKernel& getKernel(int index) {
    // Look up the wrapper for this index, then expose its implementation as the OpenCL-specific type.
    // A dynamic_cast to a reference type throws std::bad_cast if the implementation is of another type.
    Kernel& wrapper = kernels[index];
    return dynamic_cast<OpenCLCalcForcesAndEnergyKernel&>(wrapper.getImpl());
}
......@@ -78,6 +79,7 @@ private:
std::vector<Kernel> kernels;
std::vector<long long> completionTimes;
std::vector<int> contextTiles;
OpenCLArray<mm_float4>* contextForces;
};
/**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment