Commit fcdba25c authored by Peter Eastman
Browse files

Fixed bugs in checkpointing

parent 91788719
...@@ -72,7 +72,7 @@ public: ...@@ -72,7 +72,7 @@ public:
class OPENMM_EXPORT OpenCLPlatform::PlatformData { class OPENMM_EXPORT OpenCLPlatform::PlatformData {
public: public:
PlatformData(int numParticles, const std::string& platformPropValue, const std::string& deviceIndexProperty); PlatformData(const System& system, const std::string& platformPropValue, const std::string& deviceIndexProperty);
~PlatformData(); ~PlatformData();
void initializeContexts(const System& system); void initializeContexts(const System& system);
void syncContexts(); void syncContexts();
......
...@@ -65,7 +65,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i ...@@ -65,7 +65,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i
std::cerr << "OpenCL internal error: " << errinfo << std::endl; std::cerr << "OpenCL internal error: " << errinfo << std::endl;
} }
OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData) : OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData) :
time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), atomsWereReordered(false), posq(NULL), time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), atomsWereReordered(false), posq(NULL),
velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), atomIndex(NULL), integration(NULL), velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), atomIndex(NULL), integration(NULL),
bonded(NULL), nonbonded(NULL), thread(NULL) { bonded(NULL), nonbonded(NULL), thread(NULL) {
...@@ -207,8 +207,8 @@ OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceInde ...@@ -207,8 +207,8 @@ OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceInde
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[platformIndex](), 0}; cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[platformIndex](), 0};
context = cl::Context(contextDevices, cprops, errorCallback); context = cl::Context(contextDevices, cprops, errorCallback);
queue = cl::CommandQueue(context, device); queue = cl::CommandQueue(context, device);
numAtoms = numParticles; numAtoms = system.getNumParticles();
paddedNumAtoms = TileSize*((numParticles+TileSize-1)/TileSize); paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize; numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
numThreadBlocks = numThreadBlocksPerComputeUnit*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>(); numThreadBlocks = numThreadBlocksPerComputeUnit*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
bonded = new OpenCLBondedUtilities(*this); bonded = new OpenCLBondedUtilities(*this);
...@@ -268,6 +268,10 @@ OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceInde ...@@ -268,6 +268,10 @@ OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceInde
// Create the work thread used for parallelization when running on multiple devices. // Create the work thread used for parallelization when running on multiple devices.
thread = new WorkThread(); thread = new WorkThread();
// Create the integration utilities object.
integration = new OpenCLIntegrationUtilities(*this, system);
} }
OpenCLContext::~OpenCLContext() { OpenCLContext::~OpenCLContext() {
...@@ -328,7 +332,6 @@ void OpenCLContext::initialize(const System& system) { ...@@ -328,7 +332,6 @@ void OpenCLContext::initialize(const System& system) {
(*atomIndex)[i] = i; (*atomIndex)[i] = i;
atomIndex->upload(); atomIndex->upload();
findMoleculeGroups(system); findMoleculeGroups(system);
integration = new OpenCLIntegrationUtilities(*this, system);
nonbonded->initialize(system); nonbonded->initialize(system);
} }
......
...@@ -146,7 +146,7 @@ public: ...@@ -146,7 +146,7 @@ public:
class ReorderListener; class ReorderListener;
static const int ThreadBlockSize; static const int ThreadBlockSize;
static const int TileSize; static const int TileSize;
OpenCLContext(int numParticles, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData); OpenCLContext(const System& system, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData);
~OpenCLContext(); ~OpenCLContext();
/** /**
* This is called to initialize internal data structures after all Forces in the system * This is called to initialize internal data structures after all Forces in the system
...@@ -473,6 +473,12 @@ public: ...@@ -473,6 +473,12 @@ public:
* assumes ownership of the object, and deletes it when the context itself is deleted. * assumes ownership of the object, and deletes it when the context itself is deleted.
*/ */
void addReorderListener(ReorderListener* listener); void addReorderListener(ReorderListener* listener);
/**
 * Get the list of ReorderListeners.
 *
 * The return value is a reference to this context's own reorderListeners vector, not a copy,
 * so callers can iterate over the listeners in place. Because the reference is non-const,
 * any change made through it modifies the context's list directly.
 */
std::vector<ReorderListener*>& getReorderListeners() {
return reorderListeners;
}
private: private:
struct Molecule; struct Molecule;
struct MoleculeGroup; struct MoleculeGroup;
......
...@@ -592,7 +592,7 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c ...@@ -592,7 +592,7 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
vsitePositionKernel.setArg<cl::Buffer>(6, vsiteOutOfPlaneWeights->getDeviceBuffer()); vsitePositionKernel.setArg<cl::Buffer>(6, vsiteOutOfPlaneWeights->getDeviceBuffer());
vsiteForceKernel = cl::Kernel(ccmaProgram, "distributeForces"); vsiteForceKernel = cl::Kernel(ccmaProgram, "distributeForces");
vsiteForceKernel.setArg<cl::Buffer>(0, context.getPosq().getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(0, context.getPosq().getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(1, context.getForce().getDeviceBuffer()); // Skip argument 1: the force array hasn't been created yet.
vsiteForceKernel.setArg<cl::Buffer>(2, vsite2AvgAtoms->getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(2, vsite2AvgAtoms->getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(3, vsite2AvgWeights->getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(3, vsite2AvgWeights->getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(4, vsite3AvgAtoms->getDeviceBuffer()); vsiteForceKernel.setArg<cl::Buffer>(4, vsite3AvgAtoms->getDeviceBuffer());
...@@ -763,8 +763,10 @@ void OpenCLIntegrationUtilities::computeVirtualSites() { ...@@ -763,8 +763,10 @@ void OpenCLIntegrationUtilities::computeVirtualSites() {
} }
void OpenCLIntegrationUtilities::distributeForcesFromVirtualSites() { void OpenCLIntegrationUtilities::distributeForcesFromVirtualSites() {
if (numVsites > 0) if (numVsites > 0) {
vsiteForceKernel.setArg<cl::Buffer>(1, context.getForce().getDeviceBuffer());
context.executeKernel(vsiteForceKernel, numVsites); context.executeKernel(vsiteForceKernel, numVsites);
}
} }
void OpenCLIntegrationUtilities::initRandomNumberGenerator(unsigned int randomNumberSeed) { void OpenCLIntegrationUtilities::initRandomNumberGenerator(unsigned int randomNumberSeed) {
......
...@@ -231,6 +231,8 @@ void OpenCLUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream ...@@ -231,6 +231,8 @@ void OpenCLUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream
stream.write((char*) &cl.getPosq()[0], sizeof(mm_float4)*cl.getPosq().getSize()); stream.write((char*) &cl.getPosq()[0], sizeof(mm_float4)*cl.getPosq().getSize());
cl.getVelm().download(); cl.getVelm().download();
stream.write((char*) &cl.getVelm()[0], sizeof(mm_float4)*cl.getVelm().getSize()); stream.write((char*) &cl.getVelm()[0], sizeof(mm_float4)*cl.getVelm().getSize());
stream.write((char*) &cl.getAtomIndex()[0], sizeof(cl_int)*cl.getAtomIndex().getSize());
stream.write((char*) &cl.getPosCellOffsets()[0], sizeof(mm_int4)*cl.getPosCellOffsets().size());
mm_float4 box = cl.getPeriodicBoxSize(); mm_float4 box = cl.getPeriodicBoxSize();
stream.write((char*) &box, sizeof(mm_float4)); stream.write((char*) &box, sizeof(mm_float4));
cl.getIntegrationUtilities().createCheckpoint(stream); cl.getIntegrationUtilities().createCheckpoint(stream);
...@@ -244,16 +246,24 @@ void OpenCLUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& ...@@ -244,16 +246,24 @@ void OpenCLUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream&
throw OpenMMException("Checkpoint was created with a different version of OpenMM"); throw OpenMMException("Checkpoint was created with a different version of OpenMM");
double time; double time;
stream.read((char*) &time, sizeof(double)); stream.read((char*) &time, sizeof(double));
cl.setTime(time); vector<OpenCLContext*>& contexts = cl.getPlatformData().contexts;
for (int i = 0; i < (int) contexts.size(); i++)
contexts[i]->setTime(time);
stream.read((char*) &cl.getPosq()[0], sizeof(mm_float4)*cl.getPosq().getSize()); stream.read((char*) &cl.getPosq()[0], sizeof(mm_float4)*cl.getPosq().getSize());
cl.getPosq().upload(); cl.getPosq().upload();
stream.read((char*) &cl.getVelm()[0], sizeof(mm_float4)*cl.getVelm().getSize()); stream.read((char*) &cl.getVelm()[0], sizeof(mm_float4)*cl.getVelm().getSize());
cl.getVelm().upload(); cl.getVelm().upload();
stream.read((char*) &cl.getAtomIndex()[0], sizeof(cl_int)*cl.getAtomIndex().getSize());
cl.getAtomIndex().upload();
stream.read((char*) &cl.getPosCellOffsets()[0], sizeof(mm_int4)*cl.getPosCellOffsets().size());
mm_float4 box; mm_float4 box;
stream.read((char*) &box, sizeof(mm_float4)); stream.read((char*) &box, sizeof(mm_float4));
cl.setPeriodicBoxSize(box.x, box.y, box.z); for (int i = 0; i < (int) contexts.size(); i++)
contexts[i]->setPeriodicBoxSize(box.x, box.y, box.z);
cl.getIntegrationUtilities().loadCheckpoint(stream); cl.getIntegrationUtilities().loadCheckpoint(stream);
SimTKOpenMMUtilities::loadCheckpoint(stream); SimTKOpenMMUtilities::loadCheckpoint(stream);
for (int i = 0; i < cl.getReorderListeners().size(); i++)
cl.getReorderListeners()[i]->execute();
} }
void OpenCLApplyConstraintsKernel::initialize(const System& system) { void OpenCLApplyConstraintsKernel::initialize(const System& system) {
...@@ -4435,6 +4445,7 @@ void OpenCLApplyAndersenThermostatKernel::initialize(const System& system, const ...@@ -4435,6 +4445,7 @@ void OpenCLApplyAndersenThermostatKernel::initialize(const System& system, const
defines["NUM_ATOMS"] = intToString(cl.getNumAtoms()); defines["NUM_ATOMS"] = intToString(cl.getNumAtoms());
cl::Program program = cl.createProgram(OpenCLKernelSources::andersenThermostat, defines); cl::Program program = cl.createProgram(OpenCLKernelSources::andersenThermostat, defines);
kernel = cl::Kernel(program, "applyAndersenThermostat"); kernel = cl::Kernel(program, "applyAndersenThermostat");
cl.getIntegrationUtilities().initRandomNumberGenerator(randomSeed);
// Create the arrays with the group definitions. // Create the arrays with the group definitions.
...@@ -4451,7 +4462,6 @@ void OpenCLApplyAndersenThermostatKernel::initialize(const System& system, const ...@@ -4451,7 +4462,6 @@ void OpenCLApplyAndersenThermostatKernel::initialize(const System& system, const
void OpenCLApplyAndersenThermostatKernel::execute(ContextImpl& context) { void OpenCLApplyAndersenThermostatKernel::execute(ContextImpl& context) {
if (!hasInitializedKernels) { if (!hasInitializedKernels) {
hasInitializedKernels = true; hasInitializedKernels = true;
cl.getIntegrationUtilities().initRandomNumberGenerator(randomSeed);
kernel.setArg<cl::Buffer>(2, cl.getVelm().getDeviceBuffer()); kernel.setArg<cl::Buffer>(2, cl.getVelm().getDeviceBuffer());
kernel.setArg<cl::Buffer>(3, cl.getIntegrationUtilities().getStepSize().getDeviceBuffer()); kernel.setArg<cl::Buffer>(3, cl.getIntegrationUtilities().getStepSize().getDeviceBuffer());
kernel.setArg<cl::Buffer>(4, cl.getIntegrationUtilities().getRandom().getDeviceBuffer()); kernel.setArg<cl::Buffer>(4, cl.getIntegrationUtilities().getRandom().getDeviceBuffer());
......
...@@ -101,8 +101,7 @@ void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, stri ...@@ -101,8 +101,7 @@ void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, stri
getPropertyDefaultValue(OpenCLPlatformIndex()) : properties.find(OpenCLPlatformIndex())->second); getPropertyDefaultValue(OpenCLPlatformIndex()) : properties.find(OpenCLPlatformIndex())->second);
const string& devicePropValue = (properties.find(OpenCLDeviceIndex()) == properties.end() ? const string& devicePropValue = (properties.find(OpenCLDeviceIndex()) == properties.end() ?
getPropertyDefaultValue(OpenCLDeviceIndex()) : properties.find(OpenCLDeviceIndex())->second); getPropertyDefaultValue(OpenCLDeviceIndex()) : properties.find(OpenCLDeviceIndex())->second);
int numParticles = context.getSystem().getNumParticles(); context.setPlatformData(new PlatformData(context.getSystem(), platformPropValue, devicePropValue));
context.setPlatformData(new PlatformData(numParticles, platformPropValue, devicePropValue));
} }
void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
...@@ -110,7 +109,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { ...@@ -110,7 +109,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
delete data; delete data;
} }
OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& platformPropValue, const string& deviceIndexProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) { OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& platformPropValue, const string& deviceIndexProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) {
int platformIndex = 0; int platformIndex = 0;
if (platformPropValue.length() > 0) if (platformPropValue.length() > 0)
stringstream(platformPropValue) >> platformIndex; stringstream(platformPropValue) >> platformIndex;
...@@ -125,11 +124,11 @@ OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& platf ...@@ -125,11 +124,11 @@ OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& platf
if (devices[i].length() > 0) { if (devices[i].length() > 0) {
unsigned int deviceIndex; unsigned int deviceIndex;
stringstream(devices[i]) >> deviceIndex; stringstream(devices[i]) >> deviceIndex;
contexts.push_back(new OpenCLContext(numParticles, platformIndex, deviceIndex, *this)); contexts.push_back(new OpenCLContext(system, platformIndex, deviceIndex, *this));
} }
} }
if (contexts.size() == 0) if (contexts.size() == 0)
contexts.push_back(new OpenCLContext(numParticles, platformIndex, -1, *this)); contexts.push_back(new OpenCLContext(system, platformIndex, -1, *this));
stringstream device; stringstream device;
for (int i = 0; i < (int) contexts.size(); i++) { for (int i = 0; i < (int) contexts.size(); i++) {
if (i > 0) if (i > 0)
......
...@@ -68,8 +68,8 @@ void compareStates(State& s1, State& s2) { ...@@ -68,8 +68,8 @@ void compareStates(State& s1, State& s2) {
} }
void testCheckpoint() { void testCheckpoint() {
const int numParticles = 10; const int numParticles = 100;
const double boxSize = 3.0; const double boxSize = 5.0;
const double temperature = 200.0; const double temperature = 200.0;
OpenCLPlatform platform; OpenCLPlatform platform;
System system; System system;
...@@ -83,7 +83,16 @@ void testCheckpoint() { ...@@ -83,7 +83,16 @@ void testCheckpoint() {
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0); system.addParticle(1.0);
nonbonded->addParticle(i%2 == 0 ? 0.1 : -0.1, 0.2, 0.1); nonbonded->addParticle(i%2 == 0 ? 0.1 : -0.1, 0.2, 0.1);
positions[i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt)); bool clash;
do {
clash = false;
positions[i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
for (int j = 0; j < i; j++) {
Vec3 delta = positions[i]-positions[j];
if (sqrt(delta.dot(delta)) < 0.1)
clash = true;
}
} while (clash);
} }
VerletIntegrator integrator(0.001); VerletIntegrator integrator(0.001);
Context context(system, integrator, platform); Context context(system, integrator, platform);
...@@ -119,6 +128,34 @@ void testCheckpoint() { ...@@ -119,6 +128,34 @@ void testCheckpoint() {
integrator.step(10); integrator.step(10);
State s4 = context.getState(State::Positions | State::Velocities | State::Parameters); State s4 = context.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s2, s4); compareStates(s2, s4);
// Create a new Context that uses multiple devices.
string deviceIndex = platform.getPropertyValue(context, OpenCLPlatform::OpenCLDeviceIndex());
map<string, string> props;
props[OpenCLPlatform::OpenCLDeviceIndex()] = deviceIndex+","+deviceIndex;
VerletIntegrator integrator2(0.001);
Context context2(system, integrator2, platform, props);
context2.setPositions(positions);
context2.setPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
context2.setParameter(AndersenThermostat::Temperature(), temperature);
// Now repeat all of the above tests with it.
integrator2.step(100);
State s5 = context2.getState(State::Positions | State::Velocities | State::Parameters);
stringstream stream2(ios_base::out | ios_base::in | ios_base::binary);
context2.createCheckpoint(stream2);
integrator2.step(10);
State s6 = context2.getState(State::Positions | State::Velocities | State::Parameters);
context2.setPeriodicBoxVectors(Vec3(2*boxSize, 0, 0), Vec3(0, 2*boxSize, 0), Vec3(0, 0, 2*boxSize));
context2.setParameter(AndersenThermostat::Temperature(), temperature+10);
context2.loadCheckpoint(stream2);
State s7 = context2.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s5, s7);
integrator2.step(10);
State s8 = context2.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s6, s8);
} }
int main() { int main() {
......
...@@ -51,7 +51,7 @@ using namespace std; ...@@ -51,7 +51,7 @@ using namespace std;
void testTransform() { void testTransform() {
System system; System system;
system.addParticle(0.0); system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(1, "", ""); OpenCLPlatform::PlatformData platformData(system, "", "");
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(system); context.initialize(system);
OpenMM_SFMT::SFMT sfmt; OpenMM_SFMT::SFMT sfmt;
......
...@@ -48,7 +48,7 @@ void testGaussian() { ...@@ -48,7 +48,7 @@ void testGaussian() {
System system; System system;
for (int i = 0; i < numAtoms; i++) for (int i = 0; i < numAtoms; i++)
system.addParticle(1.0); system.addParticle(1.0);
OpenCLPlatform::PlatformData platformData(numAtoms, "", ""); OpenCLPlatform::PlatformData platformData(system, "", "");
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(system); context.initialize(system);
context.getIntegrationUtilities().initRandomNumberGenerator(0); context.getIntegrationUtilities().initRandomNumberGenerator(0);
......
...@@ -62,7 +62,7 @@ void verifySorting(vector<float> array) { ...@@ -62,7 +62,7 @@ void verifySorting(vector<float> array) {
System system; System system;
system.addParticle(0.0); system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(1, "", ""); OpenCLPlatform::PlatformData platformData(system, "", "");
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(system); context.initialize(system);
OpenCLArray<float> data(context, array.size(), "sortData"); OpenCLArray<float> data(context, array.size(), "sortData");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment