"openmmapi/src/ForceImpl.cpp" did not exist on "4e1e1b1162036a617d6ecadd5e218a799e0ecc6b"
Commit f631ecaf authored by Peter Eastman
Browse files

Sum forces from different contexts on the GPU instead of the CPU

parent 80eb063d
...@@ -129,6 +129,14 @@ public: ...@@ -129,6 +129,14 @@ public:
void upload(std::vector<T>& data) { void upload(std::vector<T>& data) {
upload(&data[0]); upload(&data[0]);
} }
/**
* Copy the values in the Buffer to a vector.
*/
void download(std::vector<T>& data) const {
if (data.size() != size)
data.resize(size);
download(&data[0]);
}
/** /**
* Copy the values in an array to the Buffer. * Copy the values in an array to the Buffer.
*/ */
...@@ -143,13 +151,11 @@ public: ...@@ -143,13 +151,11 @@ public:
} }
} }
/** /**
* Copy the values in the Buffer to a vector. * Copy the values in the Buffer to an array.
*/ */
void download(std::vector<T>& data) const { void download(T* data) const {
if (data.size() != size)
data.resize(size);
try { try {
context.getQueue().enqueueReadBuffer(*buffer, CL_TRUE, 0, size*sizeof(T), &data[0]); context.getQueue().enqueueReadBuffer(*buffer, CL_TRUE, 0, size*sizeof(T), data);
} }
catch (cl::Error err) { catch (cl::Error err) {
std::stringstream str; std::stringstream str;
......
...@@ -74,15 +74,16 @@ private: ...@@ -74,15 +74,16 @@ private:
class OpenCLParallelCalcForcesAndEnergyKernel::FinishComputationTask : public OpenCLContext::WorkTask { class OpenCLParallelCalcForcesAndEnergyKernel::FinishComputationTask : public OpenCLContext::WorkTask {
public: public:
FinishComputationTask(ContextImpl& context, OpenCLContext& cl, OpenCLCalcForcesAndEnergyKernel& kernel, FinishComputationTask(ContextImpl& context, OpenCLContext& cl, OpenCLCalcForcesAndEnergyKernel& kernel,
bool includeForce, bool includeEnergy, double& energy, long long& completionTime) : context(context), cl(cl), kernel(kernel), bool includeForce, bool includeEnergy, double& energy, long long& completionTime, OpenCLArray<mm_float4>& contextForces) :
includeForce(includeForce), includeEnergy(includeEnergy), energy(energy), completionTime(completionTime) { context(context), cl(cl), kernel(kernel), includeForce(includeForce), includeEnergy(includeEnergy), energy(energy),
completionTime(completionTime), contextForces(contextForces) {
} }
void execute() { void execute() {
// Execute the kernel, then download forces. // Execute the kernel, then download forces.
energy += kernel.finishComputation(context, includeForce, includeEnergy); energy += kernel.finishComputation(context, includeForce, includeEnergy);
if (includeForce) if (includeForce)
cl.getForce().download(); cl.getForce().download(&contextForces[cl.getContextIndex()*cl.getPaddedNumAtoms()]);
completionTime = getTime(); completionTime = getTime();
} }
private: private:
...@@ -92,14 +93,20 @@ private: ...@@ -92,14 +93,20 @@ private:
bool includeForce, includeEnergy; bool includeForce, includeEnergy;
double& energy; double& energy;
long long& completionTime; long long& completionTime;
OpenCLArray<mm_float4>& contextForces;
}; };
OpenCLParallelCalcForcesAndEnergyKernel::OpenCLParallelCalcForcesAndEnergyKernel(string name, const Platform& platform, OpenCLPlatform::PlatformData& data) : OpenCLParallelCalcForcesAndEnergyKernel::OpenCLParallelCalcForcesAndEnergyKernel(string name, const Platform& platform, OpenCLPlatform::PlatformData& data) :
CalcForcesAndEnergyKernel(name, platform), data(data), completionTimes(data.contexts.size()), contextTiles(data.contexts.size()) { CalcForcesAndEnergyKernel(name, platform), data(data), completionTimes(data.contexts.size()), contextTiles(data.contexts.size()), contextForces(NULL) {
for (int i = 0; i < (int) data.contexts.size(); i++) for (int i = 0; i < (int) data.contexts.size(); i++)
kernels.push_back(Kernel(new OpenCLCalcForcesAndEnergyKernel(name, platform, *data.contexts[i]))); kernels.push_back(Kernel(new OpenCLCalcForcesAndEnergyKernel(name, platform, *data.contexts[i])));
} }
OpenCLParallelCalcForcesAndEnergyKernel::~OpenCLParallelCalcForcesAndEnergyKernel() {
    // Release the array created on first use in beginComputation(). It is still
    // null if that never ran; deleting a null pointer is a no-op.
    delete contextForces;
}
void OpenCLParallelCalcForcesAndEnergyKernel::initialize(const System& system) { void OpenCLParallelCalcForcesAndEnergyKernel::initialize(const System& system) {
for (int i = 0; i < (int) kernels.size(); i++) for (int i = 0; i < (int) kernels.size(); i++)
getKernel(i).initialize(system); getKernel(i).initialize(system);
...@@ -109,6 +116,11 @@ void OpenCLParallelCalcForcesAndEnergyKernel::beginComputation(ContextImpl& cont ...@@ -109,6 +116,11 @@ void OpenCLParallelCalcForcesAndEnergyKernel::beginComputation(ContextImpl& cont
// Copy coordinates over to each device and execute the kernel. // Copy coordinates over to each device and execute the kernel.
data.contexts[0]->getPosq().download(); data.contexts[0]->getPosq().download();
if (contextForces == NULL) {
OpenCLContext& cl = *data.contexts[0];
contextForces = new OpenCLArray<mm_float4>(cl, &cl.getForceBuffers().getDeviceBuffer(),
data.contexts.size()*cl.getPaddedNumAtoms(), "contextForces", true);
}
for (int i = 0; i < (int) data.contexts.size(); i++) { for (int i = 0; i < (int) data.contexts.size(); i++) {
data.contextEnergy[i] = 0.0; data.contextEnergy[i] = 0.0;
OpenCLContext& cl = *data.contexts[i]; OpenCLContext& cl = *data.contexts[i];
...@@ -121,7 +133,7 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c ...@@ -121,7 +133,7 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c
for (int i = 0; i < (int) data.contexts.size(); i++) { for (int i = 0; i < (int) data.contexts.size(); i++) {
OpenCLContext& cl = *data.contexts[i]; OpenCLContext& cl = *data.contexts[i];
OpenCLContext::WorkThread& thread = cl.getWorkThread(); OpenCLContext::WorkThread& thread = cl.getWorkThread();
thread.addTask(new FinishComputationTask(context, cl, getKernel(i), includeForce, includeEnergy, data.contextEnergy[i], completionTimes[i])); thread.addTask(new FinishComputationTask(context, cl, getKernel(i), includeForce, includeEnergy, data.contextEnergy[i], completionTimes[i], *contextForces));
} }
data.syncContexts(); data.syncContexts();
double energy = 0.0; double energy = 0.0;
...@@ -130,18 +142,8 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c ...@@ -130,18 +142,8 @@ double OpenCLParallelCalcForcesAndEnergyKernel::finishComputation(ContextImpl& c
if (includeForce) { if (includeForce) {
// Sum the forces from all devices. // Sum the forces from all devices.
OpenCLArray<mm_float4>& forces = data.contexts[0]->getForce(); contextForces->upload();
for (int i = 1; i < (int) data.contexts.size(); i++) { data.contexts[0]->reduceBuffer(*contextForces, data.contexts.size());
OpenCLArray<mm_float4>& contextForces = data.contexts[i]->getForce();
for (int j = 0; j < forces.getSize(); j++) {
mm_float4& f1 = forces[j];
const mm_float4& f2 = contextForces[j];
f1.x += f2.x;
f1.y += f2.y;
f1.z += f2.z;
}
}
forces.upload();
// Balance work between the contexts by transferring a few nonbonded tiles from the context that // Balance work between the contexts by transferring a few nonbonded tiles from the context that
// finished last to the one that finished first. // finished last to the one that finished first.
......
...@@ -41,6 +41,7 @@ namespace OpenMM { ...@@ -41,6 +41,7 @@ namespace OpenMM {
class OpenCLParallelCalcForcesAndEnergyKernel : public CalcForcesAndEnergyKernel { class OpenCLParallelCalcForcesAndEnergyKernel : public CalcForcesAndEnergyKernel {
public: public:
OpenCLParallelCalcForcesAndEnergyKernel(std::string name, const Platform& platform, OpenCLPlatform::PlatformData& data); OpenCLParallelCalcForcesAndEnergyKernel(std::string name, const Platform& platform, OpenCLPlatform::PlatformData& data);
~OpenCLParallelCalcForcesAndEnergyKernel();
OpenCLCalcForcesAndEnergyKernel& getKernel(int index) { OpenCLCalcForcesAndEnergyKernel& getKernel(int index) {
return dynamic_cast<OpenCLCalcForcesAndEnergyKernel&>(kernels[index].getImpl()); return dynamic_cast<OpenCLCalcForcesAndEnergyKernel&>(kernels[index].getImpl());
} }
...@@ -78,6 +79,7 @@ private: ...@@ -78,6 +79,7 @@ private:
std::vector<Kernel> kernels; std::vector<Kernel> kernels;
std::vector<long long> completionTimes; std::vector<long long> completionTimes;
std::vector<int> contextTiles; std::vector<int> contextTiles;
OpenCLArray<mm_float4>* contextForces;
}; };
/** /**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment