# TestCheckpointReporter.py: tests for openmm.app.CheckpointReporter.
import os
import unittest
import tempfile
from io import BytesIO, StringIO
from openmm import app
import openmm as mm
from openmm import unit


class TestCheckpointReporter(unittest.TestCase):
    """Tests for app.CheckpointReporter writing checkpoints and serialized states."""

    def setUp(self):
        """Build a fresh Simulation for each test.

        The system is built from systems/alanine-dipeptide-implicit.pdb using
        the amber99sbildn force field, a 1.0 nm non-periodic cutoff and
        constrained bonds to hydrogen, and is integrated with a 0.002 ps
        Verlet integrator.
        """
        with open('systems/alanine-dipeptide-implicit.pdb') as f:
            pdb = app.PDBFile(f)
        forcefield = app.ForceField('amber99sbildn.xml')
        system = forcefield.createSystem(pdb.topology,
            nonbondedMethod=app.CutoffNonPeriodic, nonbondedCutoff=1.0*unit.nanometers,
            constraints=app.HBonds)
        self.simulation = app.Simulation(pdb.topology, system, mm.VerletIntegrator(0.002*unit.picoseconds))
        self.simulation.context.setPositions(pdb.positions)

    def test_1(self):
        """Test CheckpointReporter writing to a file path.

        Runs once with writeState=True (reloaded via Simulation.loadState) and
        once with writeState=False (reloaded via Simulation.loadCheckpoint).
        Each run checks that the reporter's output restores the positions after
        the context positions have been overwritten with zeros.
        """
        for writeState in [True, False]:
            # subTest records which mode failed and lets the other mode still run.
            with self.subTest(writeState=writeState), tempfile.TemporaryDirectory() as tempdir:
                filename = os.path.join(tempdir, 'checkpoint')
                self.simulation.reporters.clear()
                self.simulation.reporters.append(app.CheckpointReporter(filename, 1, writeState=writeState))
                self.simulation.step(1)

                # get the current positions
                positions = self.simulation.context.getState(getPositions=True).getPositions()

                # now set the positions into junk...
                self.simulation.context.setPositions([mm.Vec3(0, 0, 0)] * len(positions))

                # then reload the right positions from the checkpoint
                if writeState:
                    self.simulation.loadState(filename)
                else:
                    self.simulation.loadCheckpoint(filename)

                newPositions = self.simulation.context.getState(getPositions=True).getPositions()
                self.assertSequenceEqual(positions, newPositions)

                # Detach the reporter before tempdir (which holds its output
                # path) is deleted, so no reporter outlives its directory.
                self.simulation.reporters.clear()

    def testFileObj(self):
        """Test writing to a file object.  This should truncate so that only the most recent frame is present in the output."""

        # Test checkpoint saving.  The reporter fires every step; after 5
        # steps the buffer must hold exactly one checkpoint, identical to one
        # saved directly at the same point.

        checkpointBuffer = BytesIO()
        self.simulation.reporters.clear()
        self.simulation.reporters.append(app.CheckpointReporter(checkpointBuffer, 1, writeState=False))
        self.simulation.step(5)
        checkpointData = checkpointBuffer.getvalue()

        checkpointBuffer = BytesIO()
        self.simulation.saveCheckpoint(checkpointBuffer)
        self.assertSequenceEqual(checkpointData, checkpointBuffer.getvalue())

        # Test state saving.  Same check for the text output written when
        # writeState=True, compared against Simulation.saveState.

        stateBuffer = StringIO()
        self.simulation.reporters.clear()
        self.simulation.reporters.append(app.CheckpointReporter(stateBuffer, 1, writeState=True))
        self.simulation.step(5)
        stateData = stateBuffer.getvalue()

        stateBuffer = StringIO()
        self.simulation.saveState(stateBuffer)
        self.assertSequenceEqual(stateData, stateBuffer.getvalue())

if __name__ == '__main__':
    # Run the tests when this file is executed directly as a script.
    unittest.main()