Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
openmm
Commits
b2c35a8b
Unverified
Commit
b2c35a8b
authored
Oct 01, 2021
by
Peter Eastman
Committed by
GitHub
Oct 01, 2021
Browse files
copy() works better for Python subclasses of C++ classes (#3263)
parent
f68adccc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
88 additions
and
15 deletions
+88
-15
wrappers/python/src/swig_doxygen/swigInputBuilder.py
wrappers/python/src/swig_doxygen/swigInputBuilder.py
+3
-3
wrappers/python/src/swig_doxygen/swig_lib/python/extend.i
wrappers/python/src/swig_doxygen/swig_lib/python/extend.i
+61
-12
wrappers/python/tests/TestPickle.py
wrappers/python/tests/TestPickle.py
+24
-0
No files found.
wrappers/python/src/swig_doxygen/swigInputBuilder.py
View file @
b2c35a8b
...
@@ -270,7 +270,7 @@ class SwigInputBuilder:
...
@@ -270,7 +270,7 @@ class SwigInputBuilder:
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
");
\n\n
"
)
self
.
fOut
.
write
(
");
\n\n
"
)
self
.
fOut
.
write
(
"%factory(OpenMM::Force* OpenMM
::Force::__copy__
"
)
self
.
fOut
.
write
(
"%factory(OpenMM::Force* OpenMM
_XmlSerializer__cloneForce
"
)
for
name
in
sorted
(
forceSubclassList
):
for
name
in
sorted
(
forceSubclassList
):
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
");
\n\n
"
)
self
.
fOut
.
write
(
");
\n\n
"
)
...
@@ -285,7 +285,7 @@ class SwigInputBuilder:
...
@@ -285,7 +285,7 @@ class SwigInputBuilder:
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
");
\n\n
"
)
self
.
fOut
.
write
(
");
\n\n
"
)
self
.
fOut
.
write
(
"%factory(OpenMM::Integrator* OpenMM
::Integrator::__copy__
"
)
self
.
fOut
.
write
(
"%factory(OpenMM::Integrator* OpenMM
_XmlSerializer__cloneIntegrator
"
)
for
name
in
sorted
(
integratorSubclassList
):
for
name
in
sorted
(
integratorSubclassList
):
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
");
\n\n
"
)
self
.
fOut
.
write
(
");
\n\n
"
)
...
@@ -305,7 +305,7 @@ class SwigInputBuilder:
...
@@ -305,7 +305,7 @@ class SwigInputBuilder:
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
");
\n\n
"
)
self
.
fOut
.
write
(
");
\n\n
"
)
self
.
fOut
.
write
(
"%factory(OpenMM::TabulatedFunction* OpenMM
::
TabulatedFunction
::__copy__
"
)
self
.
fOut
.
write
(
"%factory(OpenMM::TabulatedFunction* OpenMM
_XmlSerializer__clone
TabulatedFunction"
)
for
name
in
sorted
(
tabulatedFunctionSubclassList
):
for
name
in
sorted
(
tabulatedFunctionSubclassList
):
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
",
\n
OpenMM::%s"
%
name
)
self
.
fOut
.
write
(
");
\n\n
"
)
self
.
fOut
.
write
(
");
\n\n
"
)
...
...
wrappers/python/src/swig_doxygen/swig_lib/python/extend.i
View file @
b2c35a8b
...
@@ -274,6 +274,11 @@ Parameters:
...
@@ -274,6 +274,11 @@ Parameters:
return
OpenMM
::
XmlSerializer
::
deserialize
<
OpenMM
::
System
>
(
ss
)
;
return
OpenMM
::
XmlSerializer
::
deserialize
<
OpenMM
::
System
>
(
ss
)
;
}
}
%
newobject
_cloneSystem
;
static
OpenMM
::
System
*
_cloneSystem
(
const
OpenMM
::
System
*
object
)
{
return
OpenMM
::
XmlSerializer
::
clone
<
OpenMM
::
System
>
(
*
object
)
;
}
static
std
::
string
_serializeForce
(
const
OpenMM
::
Force
*
object
)
{
static
std
::
string
_serializeForce
(
const
OpenMM
::
Force
*
object
)
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
OpenMM
::
XmlSerializer
::
serialize
<
OpenMM
::
Force
>
(
object
,
"Force"
,
ss
)
;
OpenMM
::
XmlSerializer
::
serialize
<
OpenMM
::
Force
>
(
object
,
"Force"
,
ss
)
;
...
@@ -287,6 +292,11 @@ Parameters:
...
@@ -287,6 +292,11 @@ Parameters:
return
OpenMM
::
XmlSerializer
::
deserialize
<
OpenMM
::
Force
>
(
ss
)
;
return
OpenMM
::
XmlSerializer
::
deserialize
<
OpenMM
::
Force
>
(
ss
)
;
}
}
%
newobject
_cloneForce
;
static
OpenMM
::
Force
*
_cloneForce
(
const
OpenMM
::
Force
*
object
)
{
return
OpenMM
::
XmlSerializer
::
clone
<
OpenMM
::
Force
>
(
*
object
)
;
}
static
std
::
string
_serializeIntegrator
(
const
OpenMM
::
Integrator
*
object
)
{
static
std
::
string
_serializeIntegrator
(
const
OpenMM
::
Integrator
*
object
)
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
OpenMM
::
XmlSerializer
::
serialize
<
OpenMM
::
Integrator
>
(
object
,
"Integrator"
,
ss
)
;
OpenMM
::
XmlSerializer
::
serialize
<
OpenMM
::
Integrator
>
(
object
,
"Integrator"
,
ss
)
;
...
@@ -300,6 +310,11 @@ Parameters:
...
@@ -300,6 +310,11 @@ Parameters:
return
OpenMM
::
XmlSerializer
::
deserialize
<
OpenMM
::
Integrator
>
(
ss
)
;
return
OpenMM
::
XmlSerializer
::
deserialize
<
OpenMM
::
Integrator
>
(
ss
)
;
}
}
%
newobject
_cloneIntegrator
;
static
OpenMM
::
Integrator
*
_cloneIntegrator
(
const
OpenMM
::
Integrator
*
object
)
{
return
OpenMM
::
XmlSerializer
::
clone
<
OpenMM
::
Integrator
>
(
*
object
)
;
}
static
std
::
string
_serializeTabulatedFunction
(
const
OpenMM
::
TabulatedFunction
*
object
)
{
static
std
::
string
_serializeTabulatedFunction
(
const
OpenMM
::
TabulatedFunction
*
object
)
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
OpenMM
::
XmlSerializer
::
serialize
<
OpenMM
::
TabulatedFunction
>
(
object
,
"TabulatedFunction"
,
ss
)
;
OpenMM
::
XmlSerializer
::
serialize
<
OpenMM
::
TabulatedFunction
>
(
object
,
"TabulatedFunction"
,
ss
)
;
...
@@ -313,6 +328,11 @@ Parameters:
...
@@ -313,6 +328,11 @@ Parameters:
return
OpenMM
::
XmlSerializer
::
deserialize
<
OpenMM
::
TabulatedFunction
>
(
ss
)
;
return
OpenMM
::
XmlSerializer
::
deserialize
<
OpenMM
::
TabulatedFunction
>
(
ss
)
;
}
}
%
newobject
_cloneTabulatedFunction
;
static
OpenMM
::
TabulatedFunction
*
_cloneTabulatedFunction
(
const
OpenMM
::
TabulatedFunction
*
object
)
{
return
OpenMM
::
XmlSerializer
::
clone
<
OpenMM
::
TabulatedFunction
>
(
*
object
)
;
}
static
std
::
string
_serializeState
(
const
OpenMM
::
State
*
object
)
{
static
std
::
string
_serializeState
(
const
OpenMM
::
State
*
object
)
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
OpenMM
::
XmlSerializer
::
serialize
<
OpenMM
::
State
>
(
object
,
"State"
,
ss
)
;
OpenMM
::
XmlSerializer
::
serialize
<
OpenMM
::
State
>
(
object
,
"State"
,
ss
)
;
...
@@ -326,6 +346,11 @@ Parameters:
...
@@ -326,6 +346,11 @@ Parameters:
return
OpenMM
::
XmlSerializer
::
deserialize
<
OpenMM
::
State
>
(
ss
)
;
return
OpenMM
::
XmlSerializer
::
deserialize
<
OpenMM
::
State
>
(
ss
)
;
}
}
%
newobject
_cloneState
;
static
OpenMM
::
State
*
_cloneState
(
const
OpenMM
::
State
*
object
)
{
return
OpenMM
::
XmlSerializer
::
clone
<
OpenMM
::
State
>
(
*
object
)
;
}
%
pythoncode
%
{
%
pythoncode
%
{
@
staticmethod
@
staticmethod
def
serialize
(
object
)
:
def
serialize
(
object
)
:
...
@@ -361,6 +386,23 @@ Parameters:
...
@@ -361,6 +386,23 @@ Parameters:
if
type
==
"TabulatedFunction"
:
if
type
==
"TabulatedFunction"
:
return
XmlSerializer
.
_deserializeTabulatedFunction
(
inputString
)
return
XmlSerializer
.
_deserializeTabulatedFunction
(
inputString
)
raise
ValueError
(
"Unsupported object type"
)
raise
ValueError
(
"Unsupported object type"
)
@
staticmethod
def
clone
(
object
)
:
"""Clone an object by first serializing it, then deserializing it again. This method constructs the
new object directly from the SerializationNodes without first converting them to XML. This means
it is faster and uses less memory than making separate calls to serialize() and deserialize()."""
if
isinstance
(
object
,
System
)
:
return
XmlSerializer
.
_cloneSystem
(
object
)
elif
isinstance
(
object
,
Force
)
:
return
XmlSerializer
.
_cloneForce
(
object
)
elif
isinstance
(
object
,
Integrator
)
:
return
XmlSerializer
.
_cloneIntegrator
(
object
)
elif
isinstance
(
object
,
State
)
:
return
XmlSerializer
.
_cloneState
(
object
)
elif
isinstance
(
object
,
TabulatedFunction
)
:
return
XmlSerializer
.
_cloneTabulatedFunction
(
object
)
raise
ValueError
(
"Unsupported object type"
)
%
}
%
}
}
}
...
@@ -384,11 +426,15 @@ Parameters:
...
@@ -384,11 +426,15 @@ Parameters:
def
__deepcopy__
(
self
,
memo
)
:
def
__deepcopy__
(
self
,
memo
)
:
return
self
.
__copy__
()
return
self
.
__copy__
()
def
__copy__
(
self
)
:
duplicate
=
XmlSerializer
.
clone
(
self
)
duplicate
.
__class__
=
self
.
__class__
attributes
=
{key
:
value
for
key
,
value
in
self
.
__dict__
.
items
()
if
key
!=
'
this
'
}
from
copy
import
deepcopy
duplicate
.
__dict__
.
update
(
deepcopy
(
attributes
))
return
duplicate
%
}
%
}
%
newobject
__copy__
;
OpenMM
::
Force
*
__copy__
()
{
return
OpenMM
::
XmlSerializer
::
clone
<
OpenMM
::
Force
>
(
*
self
)
;
}
}
}
%
extend
OpenMM
::
Integrator
{
%
extend
OpenMM
::
Integrator
{
...
@@ -403,11 +449,15 @@ Parameters:
...
@@ -403,11 +449,15 @@ Parameters:
def
__deepcopy__
(
self
,
memo
)
:
def
__deepcopy__
(
self
,
memo
)
:
return
self
.
__copy__
()
return
self
.
__copy__
()
def
__copy__
(
self
)
:
duplicate
=
XmlSerializer
.
clone
(
self
)
duplicate
.
__class__
=
self
.
__class__
attributes
=
{key
:
value
for
key
,
value
in
self
.
__dict__
.
items
()
if
key
!=
'
this
'
}
from
copy
import
deepcopy
duplicate
.
__dict__
.
update
(
deepcopy
(
attributes
))
return
duplicate
%
}
%
}
%
newobject
__copy__
;
OpenMM
::
Integrator
*
__copy__
()
{
return
OpenMM
::
XmlSerializer
::
clone
<
OpenMM
::
Integrator
>
(
*
self
)
;
}
}
}
%
extend
OpenMM
::
TabulatedFunction
{
%
extend
OpenMM
::
TabulatedFunction
{
...
@@ -422,11 +472,10 @@ Parameters:
...
@@ -422,11 +472,10 @@ Parameters:
def
__deepcopy__
(
self
,
memo
)
:
def
__deepcopy__
(
self
,
memo
)
:
return
self
.
__copy__
()
return
self
.
__copy__
()
def
__copy__
(
self
)
:
return
XmlSerializer
.
clone
(
self
)
%
}
%
}
%
newobject
__copy__
;
OpenMM
::
TabulatedFunction
*
__copy__
()
{
return
OpenMM
::
XmlSerializer
::
clone
<
OpenMM
::
TabulatedFunction
>
(
*
self
)
;
}
}
}
%
extend
OpenMM
::
State
{
%
extend
OpenMM
::
State
{
...
...
wrappers/python/tests/TestPickle.py
View file @
b2c35a8b
...
@@ -68,6 +68,30 @@ class TestPickle(unittest.TestCase):
...
@@ -68,6 +68,30 @@ class TestPickle(unittest.TestCase):
force_copy
=
pickle
.
loads
(
pickle
.
dumps
(
force
))
force_copy
=
pickle
.
loads
(
pickle
.
dumps
(
force
))
self
.
check_copy
(
force
,
force_copy
)
self
.
check_copy
(
force
,
force_copy
)
def
testCopyIntegrator
(
self
):
"""Test copying a Python object whose class extends Integrator."""
integrator1
=
MTSIntegrator
(
4
*
femtoseconds
,
[(
2
,
1
),
(
1
,
2
),
(
0
,
8
)])
integrator1
.
extraField
=
5
integrator2
=
copy
.
deepcopy
(
integrator1
)
self
.
assertEqual
(
XmlSerializer
.
serialize
(
integrator1
),
XmlSerializer
.
serialize
(
integrator2
))
self
.
assertEqual
(
MTSIntegrator
,
type
(
integrator2
))
self
.
assertEqual
(
5
,
integrator2
.
extraField
)
self
.
assertEqual
(
1
,
integrator2
.
getNumPerDofVariables
())
def
testCopyForce
(
self
):
"""Test copying a Python object whose class extends Force."""
class
ScaledForce
(
CustomNonbondedForce
):
def
__init__
(
self
,
scale
):
super
().
__init__
(
f
'
{
scale
}
*r'
)
self
.
scale
=
scale
f1
=
ScaledForce
(
3
)
f2
=
copy
.
deepcopy
(
f1
)
self
.
assertEqual
(
XmlSerializer
.
serialize
(
f1
),
XmlSerializer
.
serialize
(
f2
))
self
.
assertEqual
(
ScaledForce
,
type
(
f2
))
self
.
assertEqual
(
3
,
f2
.
scale
)
self
.
assertEqual
(
'3*r'
,
f2
.
getEnergyFunction
())
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment