TestCheckpointReporter.py 1.83 KB
Newer Older
1
import os
Robert McGibbon's avatar
Robert McGibbon committed
2
3
import unittest
import tempfile
4
5
6
from openmm import app
import openmm as mm
from openmm import unit
Robert McGibbon's avatar
Robert McGibbon committed
7
8
9
10


class TestCheckpointReporter(unittest.TestCase):
    """Round-trip tests for ``app.CheckpointReporter``.

    Each test writes a checkpoint through the reporter, scrambles the
    context's positions, reloads the file, and checks that the original
    positions come back.
    """

    def setUp(self):
        """Build a fresh alanine-dipeptide ``Simulation`` for each test.

        The system is created from the amber99sbildn force field with a
        non-periodic 1.0 nm cutoff and H-bond constraints, driven by a
        Verlet integrator with a 0.002 ps step.  The result is stored on
        ``self.simulation`` with the PDB positions already loaded.
        """
        # The path is relative, so the tests must be run from the directory
        # that contains the ``systems`` folder.
        with open('systems/alanine-dipeptide-implicit.pdb') as f:
            pdb = app.PDBFile(f)
        forcefield = app.ForceField('amber99sbildn.xml')
        system = forcefield.createSystem(
            pdb.topology,
            nonbondedMethod=app.CutoffNonPeriodic,
            nonbondedCutoff=1.0*unit.nanometers,
            constraints=app.HBonds,
        )
        integrator = mm.VerletIntegrator(0.002*unit.picoseconds)
        self.simulation = app.Simulation(pdb.topology, system, integrator)
        self.simulation.context.setPositions(pdb.positions)

    def test_1(self):
        """Test CheckpointReporter.

        Runs once with ``writeState=True`` (file reloaded through
        ``Simulation.loadState``) and once with ``writeState=False`` (file
        reloaded through ``Simulation.loadCheckpoint``).
        """
        for writeState in [True, False]:
            # subTest labels a failure with the mode that produced it and
            # lets the remaining mode still run.
            with self.subTest(writeState=writeState), \
                    tempfile.TemporaryDirectory() as tempdir:
                filename = os.path.join(tempdir, 'checkpoint')

                # Drop the reporter left over from the previous iteration so
                # only the one under test writes during the step below.
                self.simulation.reporters.clear()
                self.simulation.reporters.append(
                    app.CheckpointReporter(filename, 1, writeState=writeState))
                self.simulation.step(1)

                # get the current positions
                positions = self.simulation.context.getState(getPositions=True).getPositions()

                # now set the positions into junk...
                self.simulation.context.setPositions([mm.Vec3(0, 0, 0)] * len(positions))

                # then reload the right positions from the checkpoint
                if writeState:
                    self.simulation.loadState(filename)
                else:
                    self.simulation.loadCheckpoint(filename)

                newPositions = self.simulation.context.getState(getPositions=True).getPositions()
                self.assertSequenceEqual(positions, newPositions)
43

Robert McGibbon's avatar
Robert McGibbon committed
44
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()