Unverified Commit b28d2e66 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Merged parallel code (#4649)

* Unified lots of parallel computation code between platforms

* Unified test code between platforms

* Eliminated duplicated timing code
parent 78902bed
......@@ -658,6 +658,10 @@ vector<ComputeContext*> OpenCLContext::getAllContexts() {
return result;
}
double& OpenCLContext::getEnergyWorkspace() {
return platformData.contextEnergy[contextIndex];
}
cl::CommandQueue& OpenCLContext::getQueue() {
return currentQueue;
}
......
......@@ -27,6 +27,7 @@
#include "OpenCLKernelFactory.h"
#include "OpenCLParallelKernels.h"
#include "openmm/common/CommonKernels.h"
#include "openmm/common/CommonParallelKernels.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/OpenMMException.h"
......@@ -34,39 +35,39 @@ using namespace OpenMM;
KernelImpl* OpenCLKernelFactory::createKernelImpl(std::string name, const Platform& platform, ContextImpl& context) const {
OpenCLPlatform::PlatformData& data = *static_cast<OpenCLPlatform::PlatformData*>(context.getPlatformData());
OpenCLContext& cl = *data.contexts[0];
if (data.contexts.size() > 1) {
// We are running in parallel on multiple devices, so we may want to create a parallel kernel.
if (name == CalcForcesAndEnergyKernel::Name())
return new OpenCLParallelCalcForcesAndEnergyKernel(name, platform, data);
if (name == CalcHarmonicBondForceKernel::Name())
return new OpenCLParallelCalcHarmonicBondForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcHarmonicBondForceKernel(name, platform, cl, context.getSystem());
if (name == CalcCustomBondForceKernel::Name())
return new OpenCLParallelCalcCustomBondForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcCustomBondForceKernel(name, platform, cl, context.getSystem());
if (name == CalcHarmonicAngleForceKernel::Name())
return new OpenCLParallelCalcHarmonicAngleForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcHarmonicAngleForceKernel(name, platform, cl, context.getSystem());
if (name == CalcCustomAngleForceKernel::Name())
return new OpenCLParallelCalcCustomAngleForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcCustomAngleForceKernel(name, platform, cl, context.getSystem());
if (name == CalcPeriodicTorsionForceKernel::Name())
return new OpenCLParallelCalcPeriodicTorsionForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcPeriodicTorsionForceKernel(name, platform, cl, context.getSystem());
if (name == CalcRBTorsionForceKernel::Name())
return new OpenCLParallelCalcRBTorsionForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcRBTorsionForceKernel(name, platform, cl, context.getSystem());
if (name == CalcCMAPTorsionForceKernel::Name())
return new OpenCLParallelCalcCMAPTorsionForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcCMAPTorsionForceKernel(name, platform, cl, context.getSystem());
if (name == CalcCustomTorsionForceKernel::Name())
return new OpenCLParallelCalcCustomTorsionForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcCustomTorsionForceKernel(name, platform, cl, context.getSystem());
if (name == CalcNonbondedForceKernel::Name())
return new OpenCLParallelCalcNonbondedForceKernel(name, platform, data, context.getSystem());
if (name == CalcCustomNonbondedForceKernel::Name())
return new OpenCLParallelCalcCustomNonbondedForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcCustomNonbondedForceKernel(name, platform, cl, context.getSystem());
if (name == CalcCustomExternalForceKernel::Name())
return new OpenCLParallelCalcCustomExternalForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcCustomExternalForceKernel(name, platform, cl, context.getSystem());
if (name == CalcCustomHbondForceKernel::Name())
return new OpenCLParallelCalcCustomHbondForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcCustomHbondForceKernel(name, platform, cl, context.getSystem());
if (name == CalcCustomCompoundBondForceKernel::Name())
return new OpenCLParallelCalcCustomCompoundBondForceKernel(name, platform, data, context.getSystem());
return new CommonParallelCalcCustomCompoundBondForceKernel(name, platform, cl, context.getSystem());
}
OpenCLContext& cl = *data.contexts[0];
if (name == CalcForcesAndEnergyKernel::Name())
return new OpenCLCalcForcesAndEnergyKernel(name, platform, cl);
if (name == UpdateStateDataKernel::Name())
......
......@@ -32,38 +32,6 @@
#include "OpenCLTests.h"
#include "TestRBTorsionForce.h"
void testParallelComputation() {
System system;
const int numParticles = 200;
for (int i = 0; i < numParticles; i++)
system.addParticle(1.0);
RBTorsionForce* force = new RBTorsionForce();
for (int i = 3; i < numParticles; i++)
force->addTorsion(i-3, i-2, i-1, i, 2, 0.1*i, 0.5*i, i, 1, 1);
system.addForce(force);
vector<Vec3> positions(numParticles);
for (int i = 0; i < numParticles; i++)
positions[i] = Vec3(i, i%2, i%3);
VerletIntegrator integrator1(0.01);
Context context1(system, integrator1, platform);
context1.setPositions(positions);
State state1 = context1.getState(State::Forces | State::Energy);
VerletIntegrator integrator2(0.01);
map<string, string> props;
string deviceIndex = platform.getPropertyValue(context1, OpenCLPlatform::OpenCLDeviceIndex());
props[OpenCLPlatform::OpenCLDeviceIndex()] = deviceIndex+","+deviceIndex;
string platformIndex = platform.getPropertyValue(context1, OpenCLPlatform::OpenCLPlatformIndex());
props[OpenCLPlatform::OpenCLPlatformIndex()] = platformIndex;
Context context2(system, integrator2, platform, props);
context2.setPositions(positions);
State state2 = context2.getState(State::Forces | State::Energy);
ASSERT_EQUAL_TOL(state1.getPotentialEnergy(), state2.getPotentialEnergy(), 1e-5);
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5);
}
void runPlatformTests() {
testParallelComputation();
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment