Unverified Commit b2c35a8b authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

copy() works better for Python subclasses of C++ classes (#3263)

parent f68adccc
...@@ -270,7 +270,7 @@ class SwigInputBuilder: ...@@ -270,7 +270,7 @@ class SwigInputBuilder:
self.fOut.write(",\n OpenMM::%s" % name) self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n") self.fOut.write(");\n\n")
self.fOut.write("%factory(OpenMM::Force* OpenMM::Force::__copy__") self.fOut.write("%factory(OpenMM::Force* OpenMM_XmlSerializer__cloneForce")
for name in sorted(forceSubclassList): for name in sorted(forceSubclassList):
self.fOut.write(",\n OpenMM::%s" % name) self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n") self.fOut.write(");\n\n")
...@@ -285,7 +285,7 @@ class SwigInputBuilder: ...@@ -285,7 +285,7 @@ class SwigInputBuilder:
self.fOut.write(",\n OpenMM::%s" % name) self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n") self.fOut.write(");\n\n")
self.fOut.write("%factory(OpenMM::Integrator* OpenMM::Integrator::__copy__") self.fOut.write("%factory(OpenMM::Integrator* OpenMM_XmlSerializer__cloneIntegrator")
for name in sorted(integratorSubclassList): for name in sorted(integratorSubclassList):
self.fOut.write(",\n OpenMM::%s" % name) self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n") self.fOut.write(");\n\n")
...@@ -305,7 +305,7 @@ class SwigInputBuilder: ...@@ -305,7 +305,7 @@ class SwigInputBuilder:
self.fOut.write(",\n OpenMM::%s" % name) self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n") self.fOut.write(");\n\n")
self.fOut.write("%factory(OpenMM::TabulatedFunction* OpenMM::TabulatedFunction::__copy__") self.fOut.write("%factory(OpenMM::TabulatedFunction* OpenMM_XmlSerializer__cloneTabulatedFunction")
for name in sorted(tabulatedFunctionSubclassList): for name in sorted(tabulatedFunctionSubclassList):
self.fOut.write(",\n OpenMM::%s" % name) self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n") self.fOut.write(");\n\n")
......
...@@ -274,6 +274,11 @@ Parameters: ...@@ -274,6 +274,11 @@ Parameters:
return OpenMM::XmlSerializer::deserialize<OpenMM::System>(ss); return OpenMM::XmlSerializer::deserialize<OpenMM::System>(ss);
} }
%newobject _cloneSystem;
static OpenMM::System* _cloneSystem(const OpenMM::System* object) {
return OpenMM::XmlSerializer::clone<OpenMM::System>(*object);
}
static std::string _serializeForce(const OpenMM::Force* object) { static std::string _serializeForce(const OpenMM::Force* object) {
std::stringstream ss; std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::Force>(object, "Force", ss); OpenMM::XmlSerializer::serialize<OpenMM::Force>(object, "Force", ss);
...@@ -287,6 +292,11 @@ Parameters: ...@@ -287,6 +292,11 @@ Parameters:
return OpenMM::XmlSerializer::deserialize<OpenMM::Force>(ss); return OpenMM::XmlSerializer::deserialize<OpenMM::Force>(ss);
} }
%newobject _cloneForce;
static OpenMM::Force* _cloneForce(const OpenMM::Force* object) {
return OpenMM::XmlSerializer::clone<OpenMM::Force>(*object);
}
static std::string _serializeIntegrator(const OpenMM::Integrator* object) { static std::string _serializeIntegrator(const OpenMM::Integrator* object) {
std::stringstream ss; std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::Integrator>(object, "Integrator", ss); OpenMM::XmlSerializer::serialize<OpenMM::Integrator>(object, "Integrator", ss);
...@@ -300,6 +310,11 @@ Parameters: ...@@ -300,6 +310,11 @@ Parameters:
return OpenMM::XmlSerializer::deserialize<OpenMM::Integrator>(ss); return OpenMM::XmlSerializer::deserialize<OpenMM::Integrator>(ss);
} }
%newobject _cloneIntegrator;
static OpenMM::Integrator* _cloneIntegrator(const OpenMM::Integrator* object) {
return OpenMM::XmlSerializer::clone<OpenMM::Integrator>(*object);
}
static std::string _serializeTabulatedFunction(const OpenMM::TabulatedFunction* object) { static std::string _serializeTabulatedFunction(const OpenMM::TabulatedFunction* object) {
std::stringstream ss; std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::TabulatedFunction>(object, "TabulatedFunction", ss); OpenMM::XmlSerializer::serialize<OpenMM::TabulatedFunction>(object, "TabulatedFunction", ss);
...@@ -313,6 +328,11 @@ Parameters: ...@@ -313,6 +328,11 @@ Parameters:
return OpenMM::XmlSerializer::deserialize<OpenMM::TabulatedFunction>(ss); return OpenMM::XmlSerializer::deserialize<OpenMM::TabulatedFunction>(ss);
} }
%newobject _cloneTabulatedFunction;
static OpenMM::TabulatedFunction* _cloneTabulatedFunction(const OpenMM::TabulatedFunction* object) {
return OpenMM::XmlSerializer::clone<OpenMM::TabulatedFunction>(*object);
}
static std::string _serializeState(const OpenMM::State* object) { static std::string _serializeState(const OpenMM::State* object) {
std::stringstream ss; std::stringstream ss;
OpenMM::XmlSerializer::serialize<OpenMM::State>(object, "State", ss); OpenMM::XmlSerializer::serialize<OpenMM::State>(object, "State", ss);
...@@ -326,6 +346,11 @@ Parameters: ...@@ -326,6 +346,11 @@ Parameters:
return OpenMM::XmlSerializer::deserialize<OpenMM::State>(ss); return OpenMM::XmlSerializer::deserialize<OpenMM::State>(ss);
} }
%newobject _cloneState;
static OpenMM::State* _cloneState(const OpenMM::State* object) {
return OpenMM::XmlSerializer::clone<OpenMM::State>(*object);
}
%pythoncode %{ %pythoncode %{
@staticmethod @staticmethod
def serialize(object): def serialize(object):
...@@ -361,6 +386,23 @@ Parameters: ...@@ -361,6 +386,23 @@ Parameters:
if type == "TabulatedFunction": if type == "TabulatedFunction":
return XmlSerializer._deserializeTabulatedFunction(inputString) return XmlSerializer._deserializeTabulatedFunction(inputString)
raise ValueError("Unsupported object type") raise ValueError("Unsupported object type")
@staticmethod
def clone(object):
"""Clone an object by first serializing it, then deserializing it again. This method constructs the
new object directly from the SerializationNodes without first converting them to XML. This means
it is faster and uses less memory than making separate calls to serialize() and deserialize()."""
if isinstance(object, System):
return XmlSerializer._cloneSystem(object)
elif isinstance(object, Force):
return XmlSerializer._cloneForce(object)
elif isinstance(object, Integrator):
return XmlSerializer._cloneIntegrator(object)
elif isinstance(object, State):
return XmlSerializer._cloneState(object)
elif isinstance(object, TabulatedFunction):
return XmlSerializer._cloneTabulatedFunction(object)
raise ValueError("Unsupported object type")
%} %}
} }
...@@ -384,11 +426,15 @@ Parameters: ...@@ -384,11 +426,15 @@ Parameters:
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
return self.__copy__() return self.__copy__()
def __copy__(self):
duplicate = XmlSerializer.clone(self)
duplicate.__class__ = self.__class__
attributes = {key: value for key, value in self.__dict__.items() if key != 'this'}
from copy import deepcopy
duplicate.__dict__.update(deepcopy(attributes))
return duplicate
%} %}
%newobject __copy__;
OpenMM::Force* __copy__() {
return OpenMM::XmlSerializer::clone<OpenMM::Force>(*self);
}
} }
%extend OpenMM::Integrator { %extend OpenMM::Integrator {
...@@ -403,11 +449,15 @@ Parameters: ...@@ -403,11 +449,15 @@ Parameters:
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
return self.__copy__() return self.__copy__()
def __copy__(self):
duplicate = XmlSerializer.clone(self)
duplicate.__class__ = self.__class__
attributes = {key: value for key, value in self.__dict__.items() if key != 'this'}
from copy import deepcopy
duplicate.__dict__.update(deepcopy(attributes))
return duplicate
%} %}
%newobject __copy__;
OpenMM::Integrator* __copy__() {
return OpenMM::XmlSerializer::clone<OpenMM::Integrator>(*self);
}
} }
%extend OpenMM::TabulatedFunction { %extend OpenMM::TabulatedFunction {
...@@ -422,11 +472,10 @@ Parameters: ...@@ -422,11 +472,10 @@ Parameters:
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
return self.__copy__() return self.__copy__()
def __copy__(self):
return XmlSerializer.clone(self)
%} %}
%newobject __copy__;
OpenMM::TabulatedFunction* __copy__() {
return OpenMM::XmlSerializer::clone<OpenMM::TabulatedFunction>(*self);
}
} }
%extend OpenMM::State { %extend OpenMM::State {
......
...@@ -68,6 +68,30 @@ class TestPickle(unittest.TestCase): ...@@ -68,6 +68,30 @@ class TestPickle(unittest.TestCase):
force_copy = pickle.loads(pickle.dumps(force)) force_copy = pickle.loads(pickle.dumps(force))
self.check_copy(force, force_copy) self.check_copy(force, force_copy)
def testCopyIntegrator(self):
"""Test copying a Python object whose class extends Integrator."""
integrator1 = MTSIntegrator(4*femtoseconds, [(2,1), (1,2), (0,8)])
integrator1.extraField = 5
integrator2 = copy.deepcopy(integrator1)
self.assertEqual(XmlSerializer.serialize(integrator1), XmlSerializer.serialize(integrator2))
self.assertEqual(MTSIntegrator, type(integrator2))
self.assertEqual(5, integrator2.extraField)
self.assertEqual(1, integrator2.getNumPerDofVariables())
def testCopyForce(self):
"""Test copying a Python object whose class extends Force."""
class ScaledForce(CustomNonbondedForce):
def __init__(self, scale):
super().__init__(f'{scale}*r')
self.scale = scale
f1 = ScaledForce(3)
f2 = copy.deepcopy(f1)
self.assertEqual(XmlSerializer.serialize(f1), XmlSerializer.serialize(f2))
self.assertEqual(ScaledForce, type(f2))
self.assertEqual(3, f2.scale)
self.assertEqual('3*r', f2.getEnergyFunction())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment