Unverified Commit b7ee3022 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Added option for CheckpointReporter to save States (#3252)

* Added option for CheckpointReporter to save States

* Try to fix test failure on Windows
parent 12b6c3c3
......@@ -6,8 +6,8 @@ Simbios, the NIH National Center for Physics-Based Simulation of
Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org.
Portions copyright (c) 2014-2016 Stanford University and the Authors.
Authors: Robert McGibbon
Portions copyright (c) 2014-2021 Stanford University and the Authors.
Authors: Robert McGibbon, Peter Eastman
Contributors:
Permission is hereby granted, free of charge, to any person obtaining a
......@@ -32,6 +32,7 @@ from __future__ import absolute_import
__author__ = "Robert McGibbon"
__version__ = "1.0"
import openmm as mm
import os
import os.path
......@@ -41,14 +42,19 @@ __all__ = ['CheckpointReporter']
class CheckpointReporter(object):
"""CheckpointReporter saves periodic checkpoints of a simulation.
The checkpoints will overwrite one another -- only the last checkpoint
will be saved in the file.
will be saved in the file. Optionally you can saved serialized State
objects instead of checkpoints. This is a more portable but less
thorough way of recording the state of a simulation.
To use it, create a CheckpointReporter, then add it to the Simulation's
list of reporters. To load a checkpoint file and continue a simulation,
use the following recipe:
>>> with open('checkput.chk', 'rb') as f:
>>> simulation.context.loadCheckpoint(f.read())
>>> simulation.loadCheckpoint('checkpoint.chk')
Reloading a saved State can be done like this:
>>> simulation.loadState('state.xml')
Notes:
A checkpoint contains not only publicly visible data such as the particle
......@@ -69,20 +75,30 @@ class CheckpointReporter(object):
incompatible. If a checkpoint cannot be loaded, that is signaled by
throwing an exception.
In contrast, a State contains only the publicly visible data: positions,
velocities, global parameters, box vectors, etc. This makes it much more
portable. Reloading the State will put the Simulation back into approximately
the same state it had before, but you should not expect it to produce an
identical trajectory to the original Simulation.
"""
def __init__(self, file, reportInterval):
def __init__(self, file, reportInterval, writeState=False):
"""Create a CheckpointReporter.
Parameters
----------
file : string or open file object
The file to write to. Any current contents will be overwritten.
The file to write to. Any current contents will be overwritten. If this
is a file object, it should have been opened in binary mode if writeState
is false, or in text mode if writeState is true.
reportInterval : int
The interval (in time steps) at which to write checkpoints.
writeState : bool=False
If true, write serialized State objects. If false, write checkpoints.
"""
self._reportInterval = reportInterval
self._file = file
self._writeState = writeState
def describeNextReport(self, simulation):
"""Get information about the next report this object will generate.
......@@ -118,8 +134,10 @@ class CheckpointReporter(object):
tempFilename1 = self._file+".backup1"
tempFilename2 = self._file+".backup2"
with open(tempFilename1, 'w+b', 0) as out:
out.write(simulation.context.createCheckpoint())
if self._writeState:
simulation.saveState(tempFilename1)
else:
simulation.saveCheckpoint(tempFilename1)
exists = os.path.exists(self._file)
if exists:
os.rename(self._file, tempFilename2)
......@@ -130,7 +148,10 @@ class CheckpointReporter(object):
# Replace the contents of the file.
self._file.seek(0)
chk = simulation.context.createCheckpoint()
self._file.write(chk)
if self._writeState:
state = simulation.context.getState(getPositions=True, getVelocities=True, getParameters=True, getIntegratorParameters=True)
self._file.write(mm.XmlSerializer.serialize(state))
else:
self._file.write(simulation.context.createCheckpoint())
self._file.truncate()
self._file.flush()
......@@ -18,24 +18,28 @@ class TestCheckpointReporter(unittest.TestCase):
self.simulation.context.setPositions(pdb.positions)
def test_1(self):
file = tempfile.NamedTemporaryFile(delete=False)
self.simulation.reporters.append(app.CheckpointReporter(file, 1))
self.simulation.step(1)
# get the current positions
positions = self.simulation.context.getState(getPositions=True).getPositions()
# now set the positions into junk...
self.simulation.context.setPositions([mm.Vec3(0, 0, 0)] * len(positions))
# then reload the right positions from the checkpoint
file.close()
with open(file.name, 'rb') as f:
self.simulation.context.loadCheckpoint(f.read())
os.unlink(file.name)
newPositions = self.simulation.context.getState(getPositions=True).getPositions()
self.assertSequenceEqual(positions, newPositions)
"""Test CheckpointReporter."""
for writeState in [True, False]:
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, 'checkpoint')
self.simulation.reporters.clear()
self.simulation.reporters.append(app.CheckpointReporter(filename, 1, writeState=writeState))
self.simulation.step(1)
# get the current positions
positions = self.simulation.context.getState(getPositions=True).getPositions()
# now set the positions into junk...
self.simulation.context.setPositions([mm.Vec3(0, 0, 0)] * len(positions))
# then reload the right positions from the checkpoint
if writeState:
self.simulation.loadState(filename)
else:
self.simulation.loadCheckpoint(filename)
newPositions = self.simulation.context.getState(getPositions=True).getPositions()
self.assertSequenceEqual(positions, newPositions)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment