Commit d4eb5da4 authored by peastman's avatar peastman
Browse files

Implemented runForClockTime()

parent 68664c2e
......@@ -33,6 +33,8 @@ __version__ = "1.0"
import simtk.openmm as mm
import simtk.unit as unit
import sys
from datetime import datetime, timedelta
class Simulation(object):
"""Simulation provides a simplified API for running simulations with OpenMM and reporting results.
......@@ -90,10 +92,51 @@ class Simulation(object):
def step(self, steps):
"""Advance the simulation by integrating a specified number of time steps."""
stepTo = self.currentStep+steps
self._simulate(endStep=self.currentStep+steps)
def runForClockTime(self, time, checkpointFile=None, stateFile=None, checkpointInterval=None):
"""Advance the simulation by integrating time steps until a fixed amount of clock time has elapsed.
This is useful when you have a limited amount of computer time available, and want to run the longest simulation
possible in that time. This method will continue taking time steps until the specified clock time has elapsed,
then return. It also can automatically write out a checkpoint and/or state file before returning, so you can
later resume the simulation. Another option allows it to write checkpoints or states at regular intervals, so
you can resume even if the simulation is interrupted before the time limit is reached.
Parameters:
- time (time) the amount of time to run for. If no units are specified, it is assumed to be a number of hours.
- checkpointFile (string or file=None) if specified, a checkpoint file will be written at the end of the
simulation (and optionally at regular intervals before then) by passing this to saveCheckpoint().
- stateFile (string or file=None) if specified, a state file will be written at the end of the
simulation (and optionally at regular intervals before then) by passing this to saveState().
- checkpointInterval (time=None) if specified, checkpoints and/or states will be written at regular intervals
during the simulation, in addition to writing a final version at the end. If no units are specified, this is
assumed to be in hours.
"""
if unit.is_quantity(time):
time = time.value_in_unit(unit.hours)
if unit.is_quantity(checkpointInterval):
checkpointInterval = checkpointInterval.value_in_unit(unit.hours)
endTime = datetime.now()+timedelta(hours=time)
while (datetime.now() < endTime):
if checkpointInterval is None:
nextTime = endTime
else:
nextTime = datetime.now()+timedelta(hours=checkpointInterval)
if nextTime > endTime:
nextTime = endTime
self._simulate(endTime=nextTime)
if checkpointFile is not None:
self.saveCheckpoint(checkpointFile)
if stateFile is not None:
self.saveState(stateFile)
def _simulate(self, endStep=None, endTime=None):
if endStep is None:
endStep = sys.maxint
nextReport = [None]*len(self.reporters)
while self.currentStep < stepTo:
nextSteps = stepTo-self.currentStep
while self.currentStep < endStep:
nextSteps = endStep-self.currentStep
anyReport = False
for i, reporter in enumerate(self.reporters):
nextReport[i] = reporter.describeNextReport(self)
......@@ -104,6 +147,8 @@ class Simulation(object):
while stepsToGo > 10:
self.integrator.step(10) # Only take 10 steps at a time, to give Python more chances to respond to a control-c.
stepsToGo -= 10
if endTime is not None and datetime.now() >= endTime:
return
self.integrator.step(stepsToGo)
self.currentStep += nextSteps
if anyReport:
......
import unittest
import tempfile
from datetime import datetime, timedelta
from simtk.openmm import *
from simtk.openmm.app import *
from simtk.unit import *
......@@ -74,6 +75,65 @@ class TestSimulation(unittest.TestCase):
self.assertEqual(initialState.getPositions(), state.getPositions())
self.assertEqual(initialState.getVelocities(), state.getVelocities())
def testStep(self):
"""Test the step() method."""
pdb = PDBFile('systems/alanine-dipeptide-implicit.pdb')
ff = ForceField('amber99sb.xml', 'tip3p.xml')
system = ff.createSystem(pdb.topology)
integrator = VerletIntegrator(0.001*picoseconds)
# Create a Simulation.
simulation = Simulation(pdb.topology, system, integrator, Platform.getPlatformByName('Reference'))
simulation.context.setPositions(pdb.positions)
simulation.context.setVelocitiesToTemperature(300*kelvin)
self.assertEqual(0, simulation.currentStep)
self.assertEqual(0*picoseconds, simulation.context.getState().getTime())
# Take some steps and verify the simulation has advanced by the correct amount.
simulation.step(23)
self.assertEqual(23, simulation.currentStep)
self.assertAlmostEqual(0.023, simulation.context.getState().getTime().value_in_unit(picoseconds))
def testRunForClockTime(self):
"""Test the runForClockTime() method."""
pdb = PDBFile('systems/alanine-dipeptide-implicit.pdb')
ff = ForceField('amber99sb.xml', 'tip3p.xml')
system = ff.createSystem(pdb.topology)
integrator = VerletIntegrator(0.001*picoseconds)
# Create a Simulation.
simulation = Simulation(pdb.topology, system, integrator, Platform.getPlatformByName('Reference'))
simulation.context.setPositions(pdb.positions)
simulation.context.setVelocitiesToTemperature(300*kelvin)
self.assertEqual(0, simulation.currentStep)
self.assertEqual(0*picoseconds, simulation.context.getState().getTime())
# Run for five seconds, the save both a checkpoint and a state.
checkpointFile = tempfile.mktemp()
stateFile = tempfile.mktemp()
startTime = datetime.now()
simulation.runForClockTime(5*seconds, checkpointFile=checkpointFile, stateFile=stateFile)
endTime = datetime.now()
# Make sure at least five seconds have elapsed, but no more than ten.
self.assertTrue(endTime >= startTime+timedelta(seconds=5))
self.assertTrue(endTime < startTime+timedelta(seconds=10))
# Load the checkpoint and state and make sure they are both correct.
velocities = simulation.context.getState(getVelocities=True).getVelocities()
simulation.context.setVelocitiesToTemperature(300*kelvin)
simulation.loadCheckpoint(checkpointFile)
self.assertEqual(velocities, simulation.context.getState(getVelocities=True).getVelocities())
simulation.context.setVelocitiesToTemperature(300*kelvin)
simulation.loadState(stateFile)
self.assertEqual(velocities, simulation.context.getState(getVelocities=True).getVelocities())
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment