TestPickle.py 3.72 KB
Newer Older
1
2
import unittest
from validateConstraints import *
3
4
5
6
7
8
from openmm.app import *
from openmm import *
from openmm.unit import *
import openmm
import openmm.app.element as elem
import openmm.app.forcefield as forcefield
9
10
11
import copy
import pickle

12
class TestPickle(unittest.TestCase):
    """Pickling / deepcopy of OpenMM objects."""

    def setUp(self):
        """Set up the tests by loading the input pdb files and force field
        xml files.

        """
        # alanine dipeptide with explicit water
        self.pdb1 = PDBFile('systems/alanine-dipeptide-explicit.pdb')
        self.forcefield1 = ForceField('amber99sb.xml', 'tip3p.xml')
        self.topology1 = self.pdb1.topology
        # Give the explicit-solvent topology a periodic box. Vec3 is unitless
        # here; NOTE(review): presumably interpreted as nanometers — confirm.
        self.topology1.setUnitCellDimensions(Vec3(2, 2, 2))

        # alanine dipeptide with implicit water
        # NOTE(review): pdb2/forcefield2 are not referenced by any test in
        # this class; kept in case subclasses or other code rely on them.
        self.pdb2 = PDBFile('systems/alanine-dipeptide-implicit.pdb')
        self.forcefield2 = ForceField('amber99sb.xml', 'amber99_obc.xml')

    def check_copy(self, original, duplicate):
        """Assert that ``duplicate`` is an accurate replica of ``original``.

        Two objects count as equivalent copies when their classes have the
        same name and ``XmlSerializer.serialize`` produces identical XML for
        both.

        Parameters
        ----------
        original : object
            The OpenMM object that was copied.
        duplicate : object
            The result of copying ``original``.
        """
        # Check class name is the same.
        self.assertEqual(original.__class__.__name__, duplicate.__class__.__name__)
        # Check serialized contents are the same.
        self.assertEqual(XmlSerializer.serialize(original), XmlSerializer.serialize(duplicate))

    def test_deepcopy(self):
        """Test that the System, Integrator, State and every Force survive
        both ``copy.deepcopy`` and a ``pickle`` round trip unchanged."""
        # Create system, integrator, and state.
        system = self.forcefield1.createSystem(self.pdb1.topology)
        integrator = VerletIntegrator(2*femtosecond)
        context = Context(system, integrator)
        context.setPositions(self.pdb1.positions)
        state = context.getState(getPositions=True, getForces=True, getEnergy=True)

        # Everything to copy: the top-level objects followed by each Force.
        objects = [system, integrator, state]
        objects.extend(system.getForce(i) for i in range(system.getNumForces()))

        # Each copy mechanism under test, applied to every object in order.
        copiers = (
            ('deepcopy', copy.deepcopy),
            ('pickle', lambda obj: pickle.loads(pickle.dumps(obj))),
        )
        for method, make_copy in copiers:
            for obj in objects:
                # subTest keeps checking after a failure and reports which
                # copy method / object type broke.
                with self.subTest(method=method, obj=type(obj).__name__):
                    self.check_copy(obj, make_copy(obj))

    def testCopyIntegrator(self):
        """Test copying a Python object whose class extends Integrator."""
        integrator1 = MTSIntegrator(4*femtoseconds, [(2,1), (1,2), (0,8)])
        # A plain Python attribute must survive the copy as well.
        integrator1.extraField = 5
        integrator2 = copy.deepcopy(integrator1)
        self.assertEqual(XmlSerializer.serialize(integrator1), XmlSerializer.serialize(integrator2))
        self.assertEqual(MTSIntegrator, type(integrator2))
        self.assertEqual(5, integrator2.extraField)
        self.assertEqual(1, integrator2.getNumPerDofVariables())

    def testCopyForce(self):
        """Test copying a Python object whose class extends Force."""
        class ScaledForce(CustomNonbondedForce):
            def __init__(self, scale):
                super().__init__(f'{scale}*r')
                self.scale = scale

        f1 = ScaledForce(3)
        f2 = copy.deepcopy(f1)
        self.assertEqual(XmlSerializer.serialize(f1), XmlSerializer.serialize(f2))
        # The copy must keep the Python subclass and its extra attribute.
        self.assertEqual(ScaledForce, type(f2))
        self.assertEqual(3, f2.scale)
        self.assertEqual('3*r', f2.getEnergyFunction())

if __name__ == '__main__':
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()