Commit 2fa4baaa authored by peastman's avatar peastman
Browse files

Added checkpointing functions to Simulation

parent 8f0d1f29
...@@ -6,7 +6,7 @@ Simbios, the NIH National Center for Physics-Based Simulation of ...@@ -6,7 +6,7 @@ Simbios, the NIH National Center for Physics-Based Simulation of
Biological Structures at Stanford, funded under the NIH Roadmap for Biological Structures at Stanford, funded under the NIH Roadmap for
Medical Research, grant U54 GM072970. See https://simtk.org. Medical Research, grant U54 GM072970. See https://simtk.org.
Portions copyright (c) 2012 Stanford University and the Authors. Portions copyright (c) 2012-2015 Stanford University and the Authors.
Authors: Peter Eastman Authors: Peter Eastman
Contributors: Contributors:
...@@ -126,3 +126,71 @@ class Simulation(object): ...@@ -126,3 +126,71 @@ class Simulation(object):
if next[0] == nextSteps: if next[0] == nextSteps:
reporter.report(self, state) reporter.report(self, state)
def saveCheckpoint(self, file):
"""Save a checkpoint of the simulation to a file.
The output is a binary file that contains a complete representation of the current state of the Simulation.
It includes both publicly visible data such as the particle positions and velocities, and also internal data
such as the states of random number generators. Reloading the checkpoint will put the Simulation back into
precisely the same state it had before, so it can be exactly continued.
A checkpoint file is highly specific to the Simulation it was created from. It can only be loaded into
another Simulation that has an identical System, uses the same Platform and OpenMM version, and is running on
identical hardware. If you need a more portable way to resume simulations, consider using saveState() instead.
Parameters:
- file (string or file) a File-like object to write the checkpoint to, or alternatively a filename
"""
if isinstance(file, str):
with open(file, 'wb') as f:
f.write(self.context.createCheckpoint())
else:
file.write(self.context.createCheckpoint())
def loadCheckpoint(self, file):
"""Load a checkpoint file that was created with saveCheckpoint().
Parameters:
- file (string or file) a File-like object to load the checkpoint from, or alternatively a filename
"""
if isinstance(file, str):
with open(file, 'rb') as f:
self.context.loadCheckpoint(f.read())
else:
self.context.loadCheckpoint(file.read())
def saveState(self, file):
"""Save the current state of the simulation to a file.
The output is an XML file containing a serialized State object. It includes all publicly visible data,
including positions, velocities, and parameters. Reloading the State will put the Simulation back into
approximately the same state it had before.
Unlike saveCheckpoint(), this does not store internal data such as the states of random number generators.
Therefore, you should not expect the following trajectory to be identical to what would have been produced
with the original Simulation. On the other hand, this means it is portable across different Platforms or
hardware.
Parameters:
- file (string or file) a File-like object to write the state to, or alternatively a filename
"""
state = self.context.getState(getPositions=True, getVelocities=True, getParameters=True)
xml = mm.XmlSerializer.serialize(state)
if isinstance(file, str):
with open(file, 'w') as f:
f.write(xml)
else:
file.write(xml)
def loadState(self, file):
"""Load a State file that was created with saveState().
Parameters:
- file (string or file) a File-like object to load the state from, or alternatively a filename
"""
if isinstance(file, str):
with open(file, 'r') as f:
xml = f.read()
else:
xml = file.read()
self.context.setState(mm.XmlSerializer.deserialize(xml))
import unittest
import tempfile
from simtk.openmm import *
from simtk.openmm.app import *
from simtk.unit import *
class TestSimulation(unittest.TestCase):
"""Test the Simulation class"""
def testCheckpointing(self):
"""Test that checkpointing works correctly."""
pdb = PDBFile('systems/alanine-dipeptide-implicit.pdb')
ff = ForceField('amber99sb.xml', 'tip3p.xml')
system = ff.createSystem(pdb.topology)
integrator = VerletIntegrator(0.001*picoseconds)
# Create a Simulation.
simulation = Simulation(pdb.topology, system, integrator, Platform.getPlatformByName('Reference'))
simulation.context.setPositions(pdb.positions)
simulation.context.setVelocitiesToTemperature(300*kelvin)
initialState = simulation.context.getState(getPositions=True, getVelocities=True)
# Create a checkpoint.
filename = tempfile.mktemp()
simulation.saveCheckpoint(filename)
# Take a few steps so the positions and velocities will be different.
simulation.step(2)
state = simulation.context.getState(getPositions=True, getVelocities=True)
self.assertNotEqual(initialState.getPositions(), state.getPositions())
self.assertNotEqual(initialState.getVelocities(), state.getVelocities())
# Reload the checkpoint and see if it resets them correctly.
simulation.loadCheckpoint(filename)
state = simulation.context.getState(getPositions=True, getVelocities=True)
self.assertEqual(initialState.getPositions(), state.getPositions())
self.assertEqual(initialState.getVelocities(), state.getVelocities())
def testSaveState(self):
"""Test that saving States works correctly."""
pdb = PDBFile('systems/alanine-dipeptide-implicit.pdb')
ff = ForceField('amber99sb.xml', 'tip3p.xml')
system = ff.createSystem(pdb.topology)
integrator = VerletIntegrator(0.001*picoseconds)
# Create a Simulation.
simulation = Simulation(pdb.topology, system, integrator, Platform.getPlatformByName('Reference'))
simulation.context.setPositions(pdb.positions)
simulation.context.setVelocitiesToTemperature(300*kelvin)
initialState = simulation.context.getState(getPositions=True, getVelocities=True)
# Create a state.
filename = tempfile.mktemp()
simulation.saveState(filename)
# Take a few steps so the positions and velocities will be different.
simulation.step(2)
state = simulation.context.getState(getPositions=True, getVelocities=True)
self.assertNotEqual(initialState.getPositions(), state.getPositions())
self.assertNotEqual(initialState.getVelocities(), state.getVelocities())
# Reload the state and see if it resets them correctly.
simulation.loadState(filename)
state = simulation.context.getState(getPositions=True, getVelocities=True)
self.assertEqual(initialState.getPositions(), state.getPositions())
self.assertEqual(initialState.getVelocities(), state.getVelocities())
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment