Commit 188bf2ef authored by Peter Eastman's avatar Peter Eastman
Browse files

Implemented checkpointing

parent 8d2074f5
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#define SFMT_H #define SFMT_H
#include <stdio.h> #include <stdio.h>
#include <iosfwd>
#include "openmm/internal/windowsExport.h" #include "openmm/internal/windowsExport.h"
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)
...@@ -101,6 +102,8 @@ public: ...@@ -101,6 +102,8 @@ public:
~SFMT() { ~SFMT() {
deleteSFMTData(data); deleteSFMTData(data);
} }
void createCheckpoint(std::ostream& stream);
void loadCheckpoint(std::istream& stream);
SFMTData* data; SFMTData* data;
}; };
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <cstring> #include <cstring>
#include <cassert> #include <cassert>
#include <iostream>
#if defined(__BIG_ENDIAN__) && !defined(__amd64) && !defined(BIG_ENDIAN64) #if defined(__BIG_ENDIAN__) && !defined(__amd64) && !defined(BIG_ENDIAN64)
#define BIG_ENDIAN64 1 #define BIG_ENDIAN64 1
...@@ -117,6 +118,16 @@ public: ...@@ -117,6 +118,16 @@ public:
} }
}; };
void SFMT::createCheckpoint(std::ostream& stream) {
stream.write((char*) &data->sfmt, sizeof(data->sfmt));
stream.write((char*) &data->idx, sizeof(data->idx));
}
void SFMT::loadCheckpoint(std::istream& stream) {
stream.read((char*) &data->sfmt, sizeof(data->sfmt));
stream.read((char*) &data->idx, sizeof(data->idx));
}
/*---------------- /*----------------
STATIC FUNCTIONS STATIC FUNCTIONS
----------------*/ ----------------*/
......
...@@ -59,6 +59,7 @@ ...@@ -59,6 +59,7 @@
#include "openmm/VariableLangevinIntegrator.h" #include "openmm/VariableLangevinIntegrator.h"
#include "openmm/VariableVerletIntegrator.h" #include "openmm/VariableVerletIntegrator.h"
#include "openmm/VerletIntegrator.h" #include "openmm/VerletIntegrator.h"
#include <iosfwd>
#include <set> #include <set>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -184,6 +185,18 @@ public: ...@@ -184,6 +185,18 @@ public:
* @param c the vector defining the third edge of the periodic box * @param c the vector defining the third edge of the periodic box
*/ */
virtual void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const = 0; virtual void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const = 0;
/**
* Create a checkpoint recording the current state of the Context.
*
* @param stream an output stream the checkpoint data should be written to
*/
virtual void createCheckpoint(ContextImpl& context, std::ostream& stream) = 0;
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
virtual void loadCheckpoint(ContextImpl& context, std::istream& stream) = 0;
}; };
/** /**
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008 Stanford University and the Authors. * * Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "Integrator.h" #include "Integrator.h"
#include "State.h" #include "State.h"
#include "System.h" #include "System.h"
#include <iosfwd>
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -196,6 +197,33 @@ public: ...@@ -196,6 +197,33 @@ public:
* This is an expensive operation, so you should try to avoid calling it too frequently. * This is an expensive operation, so you should try to avoid calling it too frequently.
*/ */
void reinitialize(); void reinitialize();
/**
* Create a checkpoint recording the current state of the Context. This should be treated
* as an opaque block of binary data. See loadCheckpoint() for more details.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* A checkpoint contains not only publicly visible data such as the particle positions and
* velocities, but also internal data such as the states of random number generators. Ideally,
* loading a checkpoint should restore the Context to an identical state to when it was written,
* such that continuing the simulation will produce an identical trajectory. This is not strictly
* guaranteed to be true, however, and should not be relied on. For most purposes, however, the
* internal state should be close enough to be reasonably considered equivalent.
*
* A checkpoint contains data that is highly specific to the Context from which it was created.
* It depends on the details of the System, the Platform being used, and the hardware and software
* of the computer it was created on. If you try to load it on a computer with different hardware,
* or for a System that is different in any way, loading is likely to fail. Checkpoints created
* with different versions of OpenMM are also often incompatible. If a checkpoint cannot be loaded,
* that is signaled by throwing an exception.
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(std::istream& stream);
private: private:
friend class Platform; friend class Platform;
ContextImpl* impl; ContextImpl* impl;
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008 Stanford University and the Authors. * * Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "openmm/Kernel.h" #include "openmm/Kernel.h"
#include "openmm/Platform.h" #include "openmm/Platform.h"
#include "openmm/Vec3.h" #include "openmm/Vec3.h"
#include <iosfwd>
#include <map> #include <map>
#include <vector> #include <vector>
...@@ -220,6 +221,18 @@ public: ...@@ -220,6 +221,18 @@ public:
* same molecule if they are connected by constraints or bonds. * same molecule if they are connected by constraints or bonds.
*/ */
const std::vector<std::vector<int> >& getMolecules() const; const std::vector<std::vector<int> >& getMolecules() const;
/**
* Create a checkpoint recording the current state of the Context.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(std::istream& stream);
private: private:
friend class Context; friend class Context;
static void tagParticlesInMolecule(int particle, int molecule, std::vector<int>& particleMolecule, std::vector<std::vector<int> >& particleBonds); static void tagParticlesInMolecule(int particle, int molecule, std::vector<int>& particleMolecule, std::vector<std::vector<int> >& particleBonds);
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2009 Stanford University and the Authors. * * Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -177,3 +177,11 @@ void Context::reinitialize() { ...@@ -177,3 +177,11 @@ void Context::reinitialize() {
delete impl; delete impl;
impl = new ContextImpl(*this, system, integrator, &platform, properties); impl = new ContextImpl(*this, system, integrator, &platform, properties);
} }
void Context::createCheckpoint(ostream& stream) {
impl->createCheckpoint(stream);
}
void Context::loadCheckpoint(istream& stream) {
impl->loadCheckpoint(stream);
}
...@@ -38,15 +38,14 @@ ...@@ -38,15 +38,14 @@
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "openmm/State.h" #include "openmm/State.h"
#include "openmm/VirtualSite.h" #include "openmm/VirtualSite.h"
#include "openmm/Context.h"
#include <iostream>
#include <map> #include <map>
#include <utility> #include <utility>
#include <vector> #include <vector>
using namespace OpenMM; using namespace OpenMM;
using std::map; using namespace std;
using std::pair;
using std::vector;
using std::string;
ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator, Platform* platform, const map<string, string>& properties) : ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator, Platform* platform, const map<string, string>& properties) :
owner(owner), system(system), integrator(integrator), hasInitializedForces(false), lastForceGroups(-1), platform(platform), platformData(NULL) { owner(owner), system(system), integrator(integrator), hasInitializedForces(false), lastForceGroups(-1), platform(platform), platformData(NULL) {
...@@ -288,3 +287,50 @@ void ContextImpl::tagParticlesInMolecule(int particle, int molecule, vector<int> ...@@ -288,3 +287,50 @@ void ContextImpl::tagParticlesInMolecule(int particle, int molecule, vector<int>
if (particleMolecule[particleBonds[particle][i]] == -1) if (particleMolecule[particleBonds[particle][i]] == -1)
tagParticlesInMolecule(particleBonds[particle][i], molecule, particleMolecule, particleBonds); tagParticlesInMolecule(particleBonds[particle][i], molecule, particleMolecule, particleBonds);
} }
static void writeString(ostream& stream, string str) {
int length = str.size();
stream.write((char*) &length, sizeof(int));
stream.write((char*) &str[0], length);
}
static string readString(istream& stream) {
int length;
stream.read((char*) &length, sizeof(int));
string str(length, ' ');
stream.read((char*) &str[0], length);
return str;
}
void ContextImpl::createCheckpoint(ostream& stream) {
writeString(stream, getPlatform().getName());
int numParticles = getSystem().getNumParticles();
stream.write((char*) &numParticles, sizeof(int));
int numParameters = parameters.size();
stream.write((char*) &numParameters, sizeof(int));
for (map<string, double>::const_iterator iter = parameters.begin(); iter != parameters.end(); ++iter) {
writeString(stream, iter->first);
stream.write((char*) &iter->second, sizeof(double));
}
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).createCheckpoint(*this, stream);
stream.flush();
}
void ContextImpl::loadCheckpoint(istream& stream) {
string platformName = readString(stream);
if (platformName != getPlatform().getName())
throw OpenMMException("loadCheckpoint: Checkpoint was created with a different Platform: "+platformName);
int numParticles;
stream.read((char*) &numParticles, sizeof(int));
if (numParticles != getSystem().getNumParticles())
throw OpenMMException("loadCheckpoint: Checkpoint contains the wrong number of particles");
int numParameters;
stream.read((char*) &numParameters, sizeof(int));
for (int i = 0; i < numParameters; i++) {
string name = readString(stream);
double value;
stream.read((char*) &value, sizeof(double));
parameters[name] = value;
}
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).loadCheckpoint(*this, stream);
}
...@@ -184,6 +184,14 @@ void CudaUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, cons ...@@ -184,6 +184,14 @@ void CudaUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, cons
gpuSetConstants(gpu); gpuSetConstants(gpu);
} }
void CudaUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
throw OpenMMException("CudaPlatform does not support checkpointing");
}
void CudaUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
throw OpenMMException("CudaPlatform does not support checkpointing");
}
void CudaApplyConstraintsKernel::initialize(const System& system) { void CudaApplyConstraintsKernel::initialize(const System& system) {
} }
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2009 Stanford University and the Authors. * * Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -157,6 +157,18 @@ public: ...@@ -157,6 +157,18 @@ public:
* @param c the vector defining the third edge of the periodic box * @param c the vector defining the third edge of the periodic box
*/ */
void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const; void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const;
/**
* Create a checkpoint recording the current state of the Context.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private: private:
CudaPlatform::PlatformData& data; CudaPlatform::PlatformData& data;
}; };
......
...@@ -816,3 +816,23 @@ int OpenCLIntegrationUtilities::prepareRandomNumbers(int numValues) { ...@@ -816,3 +816,23 @@ int OpenCLIntegrationUtilities::prepareRandomNumbers(int numValues) {
randomPos = numValues; randomPos = numValues;
return 0; return 0;
} }
void OpenCLIntegrationUtilities::createCheckpoint(ostream& stream) {
stream.write((char*) &randomPos, sizeof(int));
vector<mm_float4> randomVec;
random->download(randomVec);
stream.write((char*) &randomVec[0], sizeof(mm_float4)*random->getSize());
vector<mm_int4> randomSeedVec;
randomSeed->download(randomSeedVec);
stream.write((char*) &randomSeedVec[0], sizeof(mm_int4)*randomSeed->getSize());
}
void OpenCLIntegrationUtilities::loadCheckpoint(istream& stream) {
stream.read((char*) &randomPos, sizeof(int));
vector<mm_float4> randomVec(random->getSize());
stream.read((char*) &randomVec[0], sizeof(mm_float4)*random->getSize());
random->upload(randomVec);
vector<mm_int4> randomSeedVec(randomSeed->getSize());
stream.read((char*) &randomSeedVec[0], sizeof(mm_int4)*randomSeed->getSize());
randomSeed->upload(randomSeedVec);
}
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "openmm/System.h" #include "openmm/System.h"
#include "OpenCLContext.h" #include "OpenCLContext.h"
#include "openmm/internal/windowsExport.h" #include "openmm/internal/windowsExport.h"
#include <iosfwd>
namespace OpenMM { namespace OpenMM {
...@@ -92,6 +93,18 @@ public: ...@@ -92,6 +93,18 @@ public:
* Distribute forces from virtual sites to the atoms they are based on. * Distribute forces from virtual sites to the atoms they are based on.
*/ */
void distributeForcesFromVirtualSites(); void distributeForcesFromVirtualSites();
/**
* Create a checkpoint recording the current state of the random number generator.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(std::istream& stream);
private: private:
void applyConstraints(bool constrainVelocities, double tol); void applyConstraints(bool constrainVelocities, double tol);
OpenCLContext& context; OpenCLContext& context;
......
...@@ -222,6 +222,38 @@ void OpenCLUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, co ...@@ -222,6 +222,38 @@ void OpenCLUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, co
contexts[i]->setPeriodicBoxSize(a[0], b[1], c[2]); contexts[i]->setPeriodicBoxSize(a[0], b[1], c[2]);
} }
void OpenCLUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
int version = 1;
stream.write((char*) &version, sizeof(int));
double time = cl.getTime();
stream.write((char*) &time, sizeof(double));
cl.getPosq().download();
stream.write((char*) &cl.getPosq()[0], sizeof(mm_float4)*cl.getPosq().getSize());
cl.getVelm().download();
stream.write((char*) &cl.getVelm()[0], sizeof(mm_float4)*cl.getVelm().getSize());
mm_float4 box = cl.getPeriodicBoxSize();
stream.write((char*) &box, sizeof(mm_float4));
cl.getIntegrationUtilities().createCheckpoint(stream);
}
void OpenCLUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
int version;
stream.read((char*) &version, sizeof(int));
if (version != 1)
throw OpenMMException("Checkpoint was created with a different version of OpenMM");
double time;
stream.read((char*) &time, sizeof(double));
cl.setTime(time);
stream.read((char*) &cl.getPosq()[0], sizeof(mm_float4)*cl.getPosq().getSize());
cl.getPosq().upload();
stream.read((char*) &cl.getVelm()[0], sizeof(mm_float4)*cl.getVelm().getSize());
cl.getVelm().upload();
mm_float4 box;
stream.read((char*) &box, sizeof(mm_float4));
cl.setPeriodicBoxSize(box.x, box.y, box.z);
cl.getIntegrationUtilities().loadCheckpoint(stream);
}
void OpenCLApplyConstraintsKernel::initialize(const System& system) { void OpenCLApplyConstraintsKernel::initialize(const System& system) {
} }
......
...@@ -152,6 +152,18 @@ public: ...@@ -152,6 +152,18 @@ public:
* @param c the vector defining the third edge of the periodic box * @param c the vector defining the third edge of the periodic box
*/ */
void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const; void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const;
/**
* Create a checkpoint recording the current state of the Context.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private: private:
OpenCLContext& cl; OpenCLContext& cl;
}; };
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2012 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
/**
* This tests creating and loading checkpoints with the OpenCL platform.
*/
#include "OpenCLPlatform.h"
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/AndersenThermostat.h"
#include "openmm/Context.h"
#include "openmm/NonbondedForce.h"
#include "openmm/System.h"
#include "openmm/VerletIntegrator.h"
#include "sfmt/SFMT.h"
#include <iostream>
#include <sstream>
#include <vector>
using namespace OpenMM;
using namespace std;
const double TOL = 1e-5;
void compareStates(State& s1, State& s2) {
ASSERT_EQUAL_TOL(s1.getTime(), s2.getTime(), TOL);
int numParticles = s1.getPositions().size();
for (int i = 0; i < numParticles; i++) {
ASSERT_EQUAL_VEC(s1.getPositions()[i], s2.getPositions()[i], TOL);
ASSERT_EQUAL_VEC(s1.getVelocities()[i], s2.getVelocities()[i], TOL);
Vec3 a1, b1, c1, a2, b2, c2;
s1.getPeriodicBoxVectors(a1, b1, c1);
s2.getPeriodicBoxVectors(a2, b2, c2);
ASSERT_EQUAL_VEC(a1, a2, TOL);
ASSERT_EQUAL_VEC(b1, b2, TOL);
ASSERT_EQUAL_VEC(c1, c2, TOL);
for (map<string, double>::const_iterator iter = s1.getParameters().begin(); iter != s1.getParameters().end(); ++iter)
ASSERT_EQUAL(iter->second, s2.getParameters().at(iter->first));
}
}
void testCheckpoint() {
const int numParticles = 10;
const double boxSize = 3.0;
const double temperature = 200.0;
OpenCLPlatform platform;
System system;
system.addForce(new AndersenThermostat(0.0, 100.0));
NonbondedForce* nonbonded = new NonbondedForce();
system.addForce(nonbonded);
nonbonded->setNonbondedMethod(NonbondedForce::CutoffPeriodic);
vector<Vec3> positions(numParticles);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
nonbonded->addParticle(i%2 == 0 ? 0.1 : -0.1, 0.2, 0.1);
positions[i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
}
VerletIntegrator integrator(0.001);
Context context(system, integrator, platform);
context.setPositions(positions);
context.setPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
context.setParameter(AndersenThermostat::Temperature(), temperature);
// Run for a little while.
integrator.step(100);
// Record the current state and make a checkpoint.
State s1 = context.getState(State::Positions | State::Velocities | State::Parameters);
stringstream stream1(ios_base::out | ios_base::in | ios_base::binary);
context.createCheckpoint(stream1);
// Continue the simulation for a few more steps and record the state again.
integrator.step(10);
State s2 = context.getState(State::Positions | State::Velocities | State::Parameters);
// Restore from the checkpoint and see if everything gets restored correctly.
context.setPeriodicBoxVectors(Vec3(2*boxSize, 0, 0), Vec3(0, 2*boxSize, 0), Vec3(0, 0, 2*boxSize));
context.setParameter(AndersenThermostat::Temperature(), temperature+10);
context.loadCheckpoint(stream1);
State s3 = context.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s1, s3);
// Now simulate from there and see if the trajectory is identical.
integrator.step(10);
State s4 = context.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s2, s4);
}
int main() {
try {
testCheckpoint();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
...@@ -75,6 +75,7 @@ ...@@ -75,6 +75,7 @@
#include "lepton/Parser.h" #include "lepton/Parser.h"
#include "lepton/ParsedExpression.h" #include "lepton/ParsedExpression.h"
#include <cmath> #include <cmath>
#include <iostream>
#include <limits> #include <limits>
using namespace OpenMM; using namespace OpenMM;
...@@ -238,6 +239,34 @@ void ReferenceUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, ...@@ -238,6 +239,34 @@ void ReferenceUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context,
box[2] = (RealOpenMM) c[2]; box[2] = (RealOpenMM) c[2];
} }
void ReferenceUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
int version = 1;
stream.write((char*) &version, sizeof(int));
stream.write((char*) &data.time, sizeof(data.time));
vector<RealVec>& posData = extractPositions(context);
stream.write((char*) &posData[0], sizeof(RealVec)*posData.size());
vector<RealVec>& velData = extractVelocities(context);
stream.write((char*) &velData[0], sizeof(RealVec)*velData.size());
RealVec& box = extractBoxSize(context);
stream.write((char*) &box, sizeof(RealVec));
SimTKOpenMMUtilities::createCheckpoint(stream);
}
void ReferenceUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
int version;
stream.read((char*) &version, sizeof(int));
if (version != 1)
throw OpenMMException("Checkpoint was created with a different version of OpenMM");
stream.read((char*) &data.time, sizeof(data.time));
vector<RealVec>& posData = extractPositions(context);
stream.read((char*) &posData[0], sizeof(RealVec)*posData.size());
vector<RealVec>& velData = extractVelocities(context);
stream.read((char*) &velData[0], sizeof(RealVec)*velData.size());
RealVec& box = extractBoxSize(context);
stream.read((char*) &box, sizeof(RealVec));
SimTKOpenMMUtilities::loadCheckpoint(stream);
}
void ReferenceApplyConstraintsKernel::initialize(const System& system) { void ReferenceApplyConstraintsKernel::initialize(const System& system) {
int numParticles = system.getNumParticles(); int numParticles = system.getNumParticles();
masses.resize(numParticles); masses.resize(numParticles);
......
...@@ -168,6 +168,18 @@ public: ...@@ -168,6 +168,18 @@ public:
* @param c the vector defining the third edge of the periodic box * @param c the vector defining the third edge of the periodic box
*/ */
void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const; void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const;
/**
* Create a checkpoint recording the current state of the Context.
*
* @param stream an output stream the checkpoint data should be written to
*/
void createCheckpoint(ContextImpl& context, std::ostream& stream);
/**
* Load a checkpoint that was written by createCheckpoint().
*
* @param stream an input stream the checkpoint data should be read from
*/
void loadCheckpoint(ContextImpl& context, std::istream& stream);
private: private:
ReferencePlatform::PlatformData& data; ReferencePlatform::PlatformData& data;
}; };
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include <cmath> #include <cmath>
#include <cstdio> #include <cstdio>
#include <string.h> #include <string.h>
#include <iostream>
uint32_t SimTKOpenMMUtilities::_randomNumberSeed = 0; uint32_t SimTKOpenMMUtilities::_randomNumberSeed = 0;
bool SimTKOpenMMUtilities::_randomInitialized = false; bool SimTKOpenMMUtilities::_randomInitialized = false;
...@@ -363,3 +364,23 @@ void SimTKOpenMMUtilities::setRandomNumberSeed( uint32_t seed ) { ...@@ -363,3 +364,23 @@ void SimTKOpenMMUtilities::setRandomNumberSeed( uint32_t seed ) {
_randomNumberSeed = seed; _randomNumberSeed = seed;
_randomInitialized = false; _randomInitialized = false;
} }
void SimTKOpenMMUtilities::createCheckpoint(std::ostream& stream) {
stream.write((char*) &_randomNumberSeed, sizeof(uint32_t));
stream.write((char*) &_randomInitialized, sizeof(bool));
if (_randomInitialized) {
stream.write((char*) &nextGaussianIsValid, sizeof(bool));
stream.write((char*) &nextGaussian, sizeof(RealOpenMM));
sfmt.createCheckpoint(stream);
}
}
void SimTKOpenMMUtilities::loadCheckpoint(std::istream& stream) {
stream.read((char*) &_randomNumberSeed, sizeof(uint32_t));
stream.read((char*) &_randomInitialized, sizeof(bool));
if (_randomInitialized) {
stream.read((char*) &nextGaussianIsValid, sizeof(bool));
stream.read((char*) &nextGaussian, sizeof(RealOpenMM));
sfmt.loadCheckpoint(stream);
}
}
...@@ -199,6 +199,25 @@ class OPENMM_EXPORT SimTKOpenMMUtilities { ...@@ -199,6 +199,25 @@ class OPENMM_EXPORT SimTKOpenMMUtilities {
static void setRandomNumberSeed( uint32_t seed ); static void setRandomNumberSeed( uint32_t seed );
/**---------------------------------------------------------------------------------------
Write out the internal state of the random number generator.
@param stream a stream to write the checkpoint to
--------------------------------------------------------------------------------------- */
static void createCheckpoint(std::ostream& stream);
/**---------------------------------------------------------------------------------------
Load a checkpoint created by createCheckpoint().
@param stream a stream to load the checkpoint from
--------------------------------------------------------------------------------------- */
static void loadCheckpoint(std::istream& stream);
}; };
// --------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2012 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
/**
* This tests creating and loading checkpoints with the reference platform.
*/
#include "ReferencePlatform.h"
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/AndersenThermostat.h"
#include "openmm/Context.h"
#include "openmm/NonbondedForce.h"
#include "openmm/System.h"
#include "openmm/VerletIntegrator.h"
#include "sfmt/SFMT.h"
#include <iostream>
#include <sstream>
#include <vector>
using namespace OpenMM;
using namespace std;
const double TOL = 1e-5;
void compareStates(State& s1, State& s2) {
ASSERT_EQUAL_TOL(s1.getTime(), s2.getTime(), TOL);
int numParticles = s1.getPositions().size();
for (int i = 0; i < numParticles; i++) {
ASSERT_EQUAL_VEC(s1.getPositions()[i], s2.getPositions()[i], TOL);
ASSERT_EQUAL_VEC(s1.getVelocities()[i], s2.getVelocities()[i], TOL);
Vec3 a1, b1, c1, a2, b2, c2;
s1.getPeriodicBoxVectors(a1, b1, c1);
s2.getPeriodicBoxVectors(a2, b2, c2);
ASSERT_EQUAL_VEC(a1, a2, TOL);
ASSERT_EQUAL_VEC(b1, b2, TOL);
ASSERT_EQUAL_VEC(c1, c2, TOL);
for (map<string, double>::const_iterator iter = s1.getParameters().begin(); iter != s1.getParameters().end(); ++iter)
ASSERT_EQUAL(iter->second, s2.getParameters().at(iter->first));
}
}
void testCheckpoint() {
const int numParticles = 10;
const double boxSize = 3.0;
const double temperature = 200.0;
ReferencePlatform platform;
System system;
system.addForce(new AndersenThermostat(0.0, 100.0));
NonbondedForce* nonbonded = new NonbondedForce();
system.addForce(nonbonded);
nonbonded->setNonbondedMethod(NonbondedForce::CutoffPeriodic);
vector<Vec3> positions(numParticles);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
nonbonded->addParticle(i%2 == 0 ? 0.1 : -0.1, 0.2, 0.1);
positions[i] = Vec3(boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt), boxSize*genrand_real2(sfmt));
}
VerletIntegrator integrator(0.001);
Context context(system, integrator, platform);
context.setPositions(positions);
context.setPeriodicBoxVectors(Vec3(boxSize, 0, 0), Vec3(0, boxSize, 0), Vec3(0, 0, boxSize));
context.setParameter(AndersenThermostat::Temperature(), temperature);
// Run for a little while.
integrator.step(100);
// Record the current state and make a checkpoint.
State s1 = context.getState(State::Positions | State::Velocities | State::Parameters);
stringstream stream1(ios_base::out | ios_base::in | ios_base::binary);
context.createCheckpoint(stream1);
// Continue the simulation for a few more steps and record the state again.
integrator.step(10);
State s2 = context.getState(State::Positions | State::Velocities | State::Parameters);
// Restore from the checkpoint and see if everything gets restored correctly.
context.setPeriodicBoxVectors(Vec3(2*boxSize, 0, 0), Vec3(0, 2*boxSize, 0), Vec3(0, 0, 2*boxSize));
context.setParameter(AndersenThermostat::Temperature(), temperature+10);
context.loadCheckpoint(stream1);
State s3 = context.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s1, s3);
// Now simulate from there and see if the trajectory is identical.
integrator.step(10);
State s4 = context.getState(State::Positions | State::Velocities | State::Parameters);
compareStates(s2, s4);
}
int main() {
try {
testCheckpoint();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment