Commit fcdba25c authored by Peter Eastman's avatar Peter Eastman
Browse files

Fixed bugs in checkpointing

parent 91788719
......@@ -72,7 +72,7 @@ public:
class OPENMM_EXPORT OpenCLPlatform::PlatformData {
public:
PlatformData(int numParticles, const std::string& platformPropValue, const std::string& deviceIndexProperty);
PlatformData(const System& system, const std::string& platformPropValue, const std::string& deviceIndexProperty);
~PlatformData();
void initializeContexts(const System& system);
void syncContexts();
......
......@@ -65,7 +65,7 @@ static void CL_CALLBACK errorCallback(const char* errinfo, const void* private_i
std::cerr << "OpenCL internal error: " << errinfo << std::endl;
}
OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData) :
OpenCLContext::OpenCLContext(const System& system, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData) :
time(0.0), platformData(platformData), stepCount(0), computeForceCount(0), atomsWereReordered(false), posq(NULL),
velm(NULL), forceBuffers(NULL), longForceBuffer(NULL), energyBuffer(NULL), atomIndex(NULL), integration(NULL),
bonded(NULL), nonbonded(NULL), thread(NULL) {
......@@ -207,8 +207,8 @@ OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceInde
cl_context_properties cprops[] = {CL_CONTEXT_PLATFORM, (cl_context_properties) platforms[platformIndex](), 0};
context = cl::Context(contextDevices, cprops, errorCallback);
queue = cl::CommandQueue(context, device);
numAtoms = numParticles;
paddedNumAtoms = TileSize*((numParticles+TileSize-1)/TileSize);
numAtoms = system.getNumParticles();
paddedNumAtoms = TileSize*((numAtoms+TileSize-1)/TileSize);
numAtomBlocks = (paddedNumAtoms+(TileSize-1))/TileSize;
numThreadBlocks = numThreadBlocksPerComputeUnit*device.getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>();
bonded = new OpenCLBondedUtilities(*this);
......@@ -268,6 +268,10 @@ OpenCLContext::OpenCLContext(int numParticles, int platformIndex, int deviceInde
// Create the work thread used for parallelization when running on multiple devices.
thread = new WorkThread();
// Create the integration utilities object.
integration = new OpenCLIntegrationUtilities(*this, system);
}
OpenCLContext::~OpenCLContext() {
......@@ -328,7 +332,6 @@ void OpenCLContext::initialize(const System& system) {
(*atomIndex)[i] = i;
atomIndex->upload();
findMoleculeGroups(system);
integration = new OpenCLIntegrationUtilities(*this, system);
nonbonded->initialize(system);
}
......
......@@ -146,7 +146,7 @@ public:
class ReorderListener;
static const int ThreadBlockSize;
static const int TileSize;
OpenCLContext(int numParticles, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData);
OpenCLContext(const System& system, int platformIndex, int deviceIndex, OpenCLPlatform::PlatformData& platformData);
~OpenCLContext();
/**
* This is called to initialize internal data structures after all Forces in the system
......@@ -473,6 +473,12 @@ public:
* assumes ownership of the object, and deletes it when the context itself is deleted.
*/
void addReorderListener(ReorderListener* listener);
/**
* Get the list of ReorderListeners.
*/
std::vector<ReorderListener*>& getReorderListeners() {
return reorderListeners;
}
private:
struct Molecule;
struct MoleculeGroup;
......
......@@ -592,7 +592,7 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
vsitePositionKernel.setArg<cl::Buffer>(6, vsiteOutOfPlaneWeights->getDeviceBuffer());
vsiteForceKernel = cl::Kernel(ccmaProgram, "distributeForces");
vsiteForceKernel.setArg<cl::Buffer>(0, context.getPosq().getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(1, context.getForce().getDeviceBuffer());
// Skip argument 1: the force array hasn't been created yet.
vsiteForceKernel.setArg<cl::Buffer>(2, vsite2AvgAtoms->getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(3, vsite2AvgWeights->getDeviceBuffer());
vsiteForceKernel.setArg<cl::Buffer>(4, vsite3AvgAtoms->getDeviceBuffer());
......@@ -763,8 +763,10 @@ void OpenCLIntegrationUtilities::computeVirtualSites() {
}
void OpenCLIntegrationUtilities::distributeForcesFromVirtualSites() {
if (numVsites > 0)
if (numVsites > 0) {
vsiteForceKernel.setArg<cl::Buffer>(1, context.getForce().getDeviceBuffer());
context.executeKernel(vsiteForceKernel, numVsites);
}
}
void OpenCLIntegrationUtilities::initRandomNumberGenerator(unsigned int randomNumberSeed) {
......
......@@ -231,6 +231,8 @@ void OpenCLUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream
stream.write((char*) &cl.getPosq()[0], sizeof(mm_float4)*cl.getPosq().getSize());
cl.getVelm().download();
stream.write((char*) &cl.getVelm()[0], sizeof(mm_float4)*cl.getVelm().getSize());
stream.write((char*) &cl.getAtomIndex()[0], sizeof(cl_int)*cl.getAtomIndex().getSize());
stream.write((char*) &cl.getPosCellOffsets()[0], sizeof(mm_int4)*cl.getPosCellOffsets().size());
mm_float4 box = cl.getPeriodicBoxSize();
stream.write((char*) &box, sizeof(mm_float4));
cl.getIntegrationUtilities().createCheckpoint(stream);
......@@ -244,16 +246,24 @@ void OpenCLUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream&
throw OpenMMException("Checkpoint was created with a different version of OpenMM");
double time;
stream.read((char*) &time, sizeof(double));
cl.setTime(time);
vector<OpenCLContext*>& contexts = cl.getPlatformData().contexts;
for (int i = 0; i < (int) contexts.size(); i++)
contexts[i]->setTime(time);
stream.read((char*) &cl.getPosq()[0], sizeof(mm_float4)*cl.getPosq().getSize());
cl.getPosq().upload();
stream.read((char*) &cl.getVelm()[0], sizeof(mm_float4)*cl.getVelm().getSize());
cl.getVelm().upload();
stream.read((char*) &cl.getAtomIndex()[0], sizeof(cl_int)*cl.getAtomIndex().getSize());
cl.getAtomIndex().upload();
stream.read((char*) &cl.getPosCellOffsets()[0], sizeof(mm_int4)*cl.getPosCellOffsets().size());
mm_float4 box;
stream.read((char*) &box, sizeof(mm_float4));
cl.setPeriodicBoxSize(box.x, box.y, box.z);
for (int i = 0; i < (int) contexts.size(); i++)
contexts[i]->setPeriodicBoxSize(box.x, box.y, box.z);
cl.getIntegrationUtilities().loadCheckpoint(stream);
SimTKOpenMMUtilities::loadCheckpoint(stream);
for (int i = 0; i < cl.getReorderListeners().size(); i++)
cl.getReorderListeners()[i]->execute();
}
void OpenCLApplyConstraintsKernel::initialize(const System& system) {
......@@ -4435,6 +4445,7 @@ void OpenCLApplyAndersenThermostatKernel::initialize(const System& system, const
defines["NUM_ATOMS"] = intToString(cl.getNumAtoms());
cl::Program program = cl.createProgram(OpenCLKernelSources::andersenThermostat, defines);
kernel = cl::Kernel(program, "applyAndersenThermostat");
cl.getIntegrationUtilities().initRandomNumberGenerator(randomSeed);
// Create the arrays with the group definitions.
......@@ -4451,7 +4462,6 @@ void OpenCLApplyAndersenThermostatKernel::initialize(const System& system, const
void OpenCLApplyAndersenThermostatKernel::execute(ContextImpl& context) {
if (!hasInitializedKernels) {
hasInitializedKernels = true;
cl.getIntegrationUtilities().initRandomNumberGenerator(randomSeed);
kernel.setArg<cl::Buffer>(2, cl.getVelm().getDeviceBuffer());
kernel.setArg<cl::Buffer>(3, cl.getIntegrationUtilities().getStepSize().getDeviceBuffer());
kernel.setArg<cl::Buffer>(4, cl.getIntegrationUtilities().getRandom().getDeviceBuffer());
......
......@@ -101,8 +101,7 @@ void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, stri
getPropertyDefaultValue(OpenCLPlatformIndex()) : properties.find(OpenCLPlatformIndex())->second);
const string& devicePropValue = (properties.find(OpenCLDeviceIndex()) == properties.end() ?
getPropertyDefaultValue(OpenCLDeviceIndex()) : properties.find(OpenCLDeviceIndex())->second);
int numParticles = context.getSystem().getNumParticles();
context.setPlatformData(new PlatformData(numParticles, platformPropValue, devicePropValue));
context.setPlatformData(new PlatformData(context.getSystem(), platformPropValue, devicePropValue));
}
void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
......@@ -110,7 +109,7 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
delete data;
}
OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& platformPropValue, const string& deviceIndexProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) {
OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& platformPropValue, const string& deviceIndexProperty) : removeCM(false), stepCount(0), computeForceCount(0), time(0.0) {
int platformIndex = 0;
if (platformPropValue.length() > 0)
stringstream(platformPropValue) >> platformIndex;
......@@ -125,11 +124,11 @@ OpenCLPlatform::PlatformData::PlatformData(int numParticles, const string& platf
if (devices[i].length() > 0) {
unsigned int deviceIndex;
stringstream(devices[i]) >> deviceIndex;
contexts.push_back(new OpenCLContext(numParticles, platformIndex, deviceIndex, *this));
contexts.push_back(new OpenCLContext(system, platformIndex, deviceIndex, *this));
}
}
if (contexts.size() == 0)
contexts.push_back(new OpenCLContext(numParticles, platformIndex, -1, *this));
contexts.push_back(new OpenCLContext(system, platformIndex, -1, *this));
stringstream device;
for (int i = 0; i < (int) contexts.size(); i++) {
if (i > 0)
......
......@@ -68,8 +68,8 @@ void compareStates(State& s1, State& s2) {
}
void testCheckpoint() {
const int numParticles = 10;
const double boxSize = 3.0;
const int numParticles = 100;
const double boxSize = 5.0;
const double temperature = 200.0;
OpenCLPlatform platform;
System system;
......@@ -83,7 +83,16 @@ void testCheckpoint() {
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
nonbonded->addParticle(i%2 == 0 ? 0.1 : -0.1, 0.2, 0.1);
positions[i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
bool clash;
do {
clash = false;
positions[i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
for (int j = 0; j < i; j++) {
Vec3 delta = positions[i]-positions[j];
if (sqrt(delta.dot(delta)) < 0.1)
clash = true;
}
} while (clash);
}
VerletIntegrator integrator(0.001);
Context context(system, integrator, platform);
......@@ -119,6 +128,34 @@ void testCheckpoint() {
integrator.step(10);
State s4 = context.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s2, s4);
// Create a new Context that uses multiple devices.
string deviceIndex = platform.getPropertyValue(context, OpenCLPlatform::OpenCLDeviceIndex());
map<string, string> props;
props[OpenCLPlatform::OpenCLDeviceIndex()] = deviceIndex+","+deviceIndex;
VerletIntegrator integrator2(0.001);
Context context2(system, integrator2, platform, props);
context2.setPositions(positions);
context2.setPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
context2.setParameter(AndersenThermostat::Temperature(), temperature);
// Now repeat all of the above tests with it.
integrator2.step(100);
State s5 = context2.getState(State::Positions | State::Velocities | State::Parameters);
stringstream stream2(ios_base::out | ios_base::in | ios_base::binary);
context2.createCheckpoint(stream2);
integrator2.step(10);
State s6 = context2.getState(State::Positions | State::Velocities | State::Parameters);
context2.setPeriodicBoxVectors(Vec3(2*boxSize, 0, 0), Vec3(0, 2*boxSize, 0), Vec3(0, 0, 2*boxSize));
context2.setParameter(AndersenThermostat::Temperature(), temperature+10);
context2.loadCheckpoint(stream2);
State s7 = context2.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s5, s7);
integrator2.step(10);
State s8 = context2.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s6, s8);
}
int main() {
......
......@@ -51,7 +51,7 @@ using namespace std;
void testTransform() {
System system;
system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(1, "", "");
OpenCLPlatform::PlatformData platformData(system, "", "");
OpenCLContext& context = *platformData.contexts[0];
context.initialize(system);
OpenMM_SFMT::SFMT sfmt;
......
......@@ -48,7 +48,7 @@ void testGaussian() {
System system;
for (int i = 0; i < numAtoms; i++)
system.addParticle(1.0);
OpenCLPlatform::PlatformData platformData(numAtoms, "", "");
OpenCLPlatform::PlatformData platformData(system, "", "");
OpenCLContext& context = *platformData.contexts[0];
context.initialize(system);
context.getIntegrationUtilities().initRandomNumberGenerator(0);
......
......@@ -62,7 +62,7 @@ void verifySorting(vector<float> array) {
System system;
system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(1, "", "");
OpenCLPlatform::PlatformData platformData(system, "", "");
OpenCLContext& context = *platformData.contexts[0];
context.initialize(system);
OpenCLArray<float> data(context, array.size(), "sortData");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment