"vscode:/vscode.git/clone" did not exist on "5d4f86a466e32aff8acc18f4f532bcbdbd8056e0"
Unverified Commit b7ee3022 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Added option for CheckpointReporter to save States (#3252)

* Added option for CheckpointReporter to save States

* Try to fix test failure on Windows
parent 12b6c3c3
...@@ -6,8 +6,8 @@ Simbios, the NIH National Center for Physics-Based Simulation of ...@@ -6,8 +6,8 @@ Simbios, the NIH National Center for Physics-Based Simulation of
Biological Structures at Stanford, funded under the NIH Roadmap for Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org. Medical Research, grant U54 GM072970. See https://simtk.org.
Portions copyright (c) 2014-2016 Stanford University and the Authors. Portions copyright (c) 2014-2021 Stanford University and the Authors.
Authors: Robert McGibbon Authors: Robert McGibbon, Peter Eastman
Contributors: Contributors:
Permission is hereby granted, free of charge, to any person obtaining a Permission is hereby granted, free of charge, to any person obtaining a
...@@ -32,6 +32,7 @@ from __future__ import absolute_import ...@@ -32,6 +32,7 @@ from __future__ import absolute_import
__author__ = "Robert McGibbon" __author__ = "Robert McGibbon"
__version__ = "1.0" __version__ = "1.0"
import openmm as mm
import os import os
import os.path import os.path
...@@ -41,14 +42,19 @@ __all__ = ['CheckpointReporter'] ...@@ -41,14 +42,19 @@ __all__ = ['CheckpointReporter']
class CheckpointReporter(object): class CheckpointReporter(object):
"""CheckpointReporter saves periodic checkpoints of a simulation. """CheckpointReporter saves periodic checkpoints of a simulation.
The checkpoints will overwrite one another -- only the last checkpoint The checkpoints will overwrite one another -- only the last checkpoint
will be saved in the file. will be saved in the file. Optionally you can saved serialized State
objects instead of checkpoints. This is a more portable but less
thorough way of recording the state of a simulation.
To use it, create a CheckpointReporter, then add it to the Simulation's To use it, create a CheckpointReporter, then add it to the Simulation's
list of reporters. To load a checkpoint file and continue a simulation, list of reporters. To load a checkpoint file and continue a simulation,
use the following recipe: use the following recipe:
>>> with open('checkput.chk', 'rb') as f: >>> simulation.loadCheckpoint('checkpoint.chk')
>>> simulation.context.loadCheckpoint(f.read())
Reloading a saved State can be done like this:
>>> simulation.loadState('state.xml')
Notes: Notes:
A checkpoint contains not only publicly visible data such as the particle A checkpoint contains not only publicly visible data such as the particle
...@@ -69,20 +75,30 @@ class CheckpointReporter(object): ...@@ -69,20 +75,30 @@ class CheckpointReporter(object):
incompatible. If a checkpoint cannot be loaded, that is signaled by incompatible. If a checkpoint cannot be loaded, that is signaled by
throwing an exception. throwing an exception.
In contrast, a State contains only the publicly visible data: positions,
velocities, global parameters, box vectors, etc. This makes it much more
portable. Reloading the State will put the Simulation back into approximately
the same state it had before, but you should not expect it to produce an
identical trajectory to the original Simulation.
""" """
def __init__(self, file, reportInterval): def __init__(self, file, reportInterval, writeState=False):
"""Create a CheckpointReporter. """Create a CheckpointReporter.
Parameters Parameters
---------- ----------
file : string or open file object file : string or open file object
The file to write to. Any current contents will be overwritten. The file to write to. Any current contents will be overwritten. If this
is a file object, it should have been opened in binary mode if writeState
is false, or in text mode if writeState is true.
reportInterval : int reportInterval : int
The interval (in time steps) at which to write checkpoints. The interval (in time steps) at which to write checkpoints.
writeState : bool=False
If true, write serialized State objects. If false, write checkpoints.
""" """
self._reportInterval = reportInterval self._reportInterval = reportInterval
self._file = file self._file = file
self._writeState = writeState
def describeNextReport(self, simulation): def describeNextReport(self, simulation):
"""Get information about the next report this object will generate. """Get information about the next report this object will generate.
...@@ -118,8 +134,10 @@ class CheckpointReporter(object): ...@@ -118,8 +134,10 @@ class CheckpointReporter(object):
tempFilename1 = self._file+".backup1" tempFilename1 = self._file+".backup1"
tempFilename2 = self._file+".backup2" tempFilename2 = self._file+".backup2"
with open(tempFilename1, 'w+b', 0) as out: if self._writeState:
out.write(simulation.context.createCheckpoint()) simulation.saveState(tempFilename1)
else:
simulation.saveCheckpoint(tempFilename1)
exists = os.path.exists(self._file) exists = os.path.exists(self._file)
if exists: if exists:
os.rename(self._file, tempFilename2) os.rename(self._file, tempFilename2)
...@@ -130,7 +148,10 @@ class CheckpointReporter(object): ...@@ -130,7 +148,10 @@ class CheckpointReporter(object):
# Replace the contents of the file. # Replace the contents of the file.
self._file.seek(0) self._file.seek(0)
chk = simulation.context.createCheckpoint() if self._writeState:
self._file.write(chk) state = simulation.context.getState(getPositions=True, getVelocities=True, getParameters=True, getIntegratorParameters=True)
self._file.write(mm.XmlSerializer.serialize(state))
else:
self._file.write(simulation.context.createCheckpoint())
self._file.truncate() self._file.truncate()
self._file.flush() self._file.flush()
...@@ -18,8 +18,12 @@ class TestCheckpointReporter(unittest.TestCase): ...@@ -18,8 +18,12 @@ class TestCheckpointReporter(unittest.TestCase):
self.simulation.context.setPositions(pdb.positions) self.simulation.context.setPositions(pdb.positions)
def test_1(self): def test_1(self):
file = tempfile.NamedTemporaryFile(delete=False) """Test CheckpointReporter."""
self.simulation.reporters.append(app.CheckpointReporter(file, 1)) for writeState in [True, False]:
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, 'checkpoint')
self.simulation.reporters.clear()
self.simulation.reporters.append(app.CheckpointReporter(filename, 1, writeState=writeState))
self.simulation.step(1) self.simulation.step(1)
# get the current positions # get the current positions
...@@ -29,10 +33,10 @@ class TestCheckpointReporter(unittest.TestCase): ...@@ -29,10 +33,10 @@ class TestCheckpointReporter(unittest.TestCase):
self.simulation.context.setPositions([mm.Vec3(0, 0, 0)] * len(positions)) self.simulation.context.setPositions([mm.Vec3(0, 0, 0)] * len(positions))
# then reload the right positions from the checkpoint # then reload the right positions from the checkpoint
file.close() if writeState:
with open(file.name, 'rb') as f: self.simulation.loadState(filename)
self.simulation.context.loadCheckpoint(f.read()) else:
os.unlink(file.name) self.simulation.loadCheckpoint(filename)
newPositions = self.simulation.context.getState(getPositions=True).getPositions() newPositions = self.simulation.context.getState(getPositions=True).getPositions()
self.assertSequenceEqual(positions, newPositions) self.assertSequenceEqual(positions, newPositions)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment