"platforms/cuda/src/CudaArray.cpp" did not exist on "d9e73e43cb2685dd079fdfbbb535ebed8f7327c3"
Commit 2049b1fe authored by peastman's avatar peastman
Browse files

Merge pull request #1211 from jchodera/deepcopy

Fix force deepcopy bug
parents d4853aeb 1006797b
...@@ -450,13 +450,22 @@ Parameters: ...@@ -450,13 +450,22 @@ Parameters:
%extend OpenMM::Force { %extend OpenMM::Force {
%pythoncode %{ %pythoncode %{
def __getstate__(self):
serializationString = XmlSerializer.serialize(self)
return serializationString
def __setstate__(self, serializationString):
system = XmlSerializer.deserialize(serializationString)
self.this = system.this
def __copy__(self):
copy = self.__class__.__new__(self.__class__)
copy.__init__(self)
return copy
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
return self.__copy__() return self.__copy__()
%} %}
%newobject __copy__;
OpenMM::Force* __copy__() {
return OpenMM::XmlSerializer::clone<OpenMM::Force>(*self);
}
} }
%extend OpenMM::Integrator { %extend OpenMM::Integrator {
...@@ -469,11 +478,12 @@ Parameters: ...@@ -469,11 +478,12 @@ Parameters:
system = XmlSerializer.deserialize(serializationString) system = XmlSerializer.deserialize(serializationString)
self.this = system.this self.this = system.this
def __copy__(self):
copy = self.__class__.__new__(self.__class__)
copy.__init__(self)
return copy
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
return self.__copy__() return self.__copy__()
%} %}
%newobject __copy__;
OpenMM::Integrator* __copy__() {
return OpenMM::XmlSerializer::clone<OpenMM::Integrator>(*self);
}
} }
...@@ -3,13 +3,14 @@ from validateConstraints import * ...@@ -3,13 +3,14 @@ from validateConstraints import *
from simtk.openmm.app import * from simtk.openmm.app import *
from simtk.openmm import * from simtk.openmm import *
from simtk.unit import * from simtk.unit import *
import simtk.openmm
import simtk.openmm.app.element as elem import simtk.openmm.app.element as elem
import simtk.openmm.app.forcefield as forcefield import simtk.openmm.app.forcefield as forcefield
import copy import copy
import pickle import pickle
class TestPickle(unittest.TestCase): class TestPickle(unittest.TestCase):
"""Pickling / deepcopy of OpenMM state and integrator objects.""" """Pickling / deepcopy of OpenMM objects."""
def setUp(self): def setUp(self):
"""Set up the tests by loading the input pdb files and force field """Set up the tests by loading the input pdb files and force field
...@@ -26,28 +27,46 @@ class TestPickle(unittest.TestCase): ...@@ -26,28 +27,46 @@ class TestPickle(unittest.TestCase):
self.pdb2 = PDBFile('systems/alanine-dipeptide-implicit.pdb') self.pdb2 = PDBFile('systems/alanine-dipeptide-implicit.pdb')
self.forcefield2 = ForceField('amber99sb.xml', 'amber99_obc.xml') self.forcefield2 = ForceField('amber99sb.xml', 'amber99_obc.xml')
def check_copy(self, object, object_copy):
"""Check that an object's copy is an accurate replica."""
# Check class name is same.
self.assertEqual(object.__class__.__name__, object_copy.__class__.__name__)
# Check serialized contents are the same.
self.assertEqual(XmlSerializer.serialize(object), XmlSerializer.serialize(object_copy))
def test_deepcopy(self): def test_deepcopy(self):
"""Test that serialization/deserialization works (via deepcopy).""" """Test that serialization/deserialization works (via deepcopy)."""
# Create system, integrator, and state.
system = self.forcefield1.createSystem(self.pdb1.topology) system = self.forcefield1.createSystem(self.pdb1.topology)
integrator = VerletIntegrator(2*femtosecond) integrator = VerletIntegrator(2*femtosecond)
context = Context(system, integrator) context = Context(system, integrator)
context.setPositions(self.pdb1.positions) context.setPositions(self.pdb1.positions)
state = context.getState(getPositions=True, getForces=True, getEnergy=True) state = context.getState(getPositions=True, getForces=True, getEnergy=True)
system2 = copy.deepcopy(system) #
integrator2 = copy.deepcopy(integrator) # Test deepcopy
state2 = copy.deepcopy(state) #
str_state = pickle.dumps(state)
str_integrator = pickle.dumps(integrator)
state3 = pickle.loads(str_state)
context.setState(state3)
self.check_copy(system, copy.deepcopy(system))
self.check_copy(integrator, copy.deepcopy(integrator))
self.check_copy(state, copy.deepcopy(state))
for force_index in range(system.getNumForces()):
force = system.getForce(force_index)
force_copy = copy.deepcopy(force)
self.check_copy(force, force_copy)
del context, integrator #
# Test pickle
#
self.check_copy(system, pickle.loads(pickle.dumps(system)))
self.check_copy(integrator, pickle.loads(pickle.dumps(integrator)))
self.check_copy(state, pickle.loads(pickle.dumps(state)))
for force_index in range(system.getNumForces()):
force = system.getForce(force_index)
force_copy = pickle.loads(pickle.dumps(force))
self.check_copy(force, force_copy)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment