Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c06d52b1
Unverified
Commit
c06d52b1
authored
Jan 19, 2023
by
Philip Meier
Committed by
GitHub
Jan 19, 2023
Browse files
properly support deepcopying and serialization of model weights (#7107)
parent
93df9a50
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
7 deletions
+54
-7
test/test_extended_models.py
test/test_extended_models.py
+27
-4
torchvision/models/_api.py
torchvision/models/_api.py
+27
-3
No files found.
test/test_extended_models.py
View file @
c06d52b1
import
copy
import
copy
import
os
import
os
import
pickle
import
pytest
import
pytest
import
test_models
as
TM
import
test_models
as
TM
...
@@ -73,10 +74,32 @@ def test_get_model_weights(name, weight):
...
@@ -73,10 +74,32 @@ def test_get_model_weights(name, weight):
],
],
)
)
def
test_weights_copyable
(
copy_fn
,
name
):
def
test_weights_copyable
(
copy_fn
,
name
):
model_weights
=
models
.
get_model_weights
(
name
)
for
weights
in
list
(
models
.
get_model_weights
(
name
)):
for
weights
in
list
(
model_weights
):
# It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior
copied_weights
=
copy_fn
(
weights
)
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
assert
copied_weights
is
weights
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert
copy_fn
(
weights
)
is
weights
@
pytest
.
mark
.
parametrize
(
"name"
,
[
"resnet50"
,
"retinanet_resnet50_fpn_v2"
,
"raft_large"
,
"quantized_resnet50"
,
"lraspp_mobilenet_v3_large"
,
"mvit_v1_b"
,
],
)
def
test_weights_deserializable
(
name
):
for
weights
in
list
(
models
.
get_model_weights
(
name
)):
# It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert
pickle
.
loads
(
pickle
.
dumps
(
weights
))
is
weights
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
torchvision/models/_api.py
View file @
c06d52b1
...
@@ -2,6 +2,7 @@ import importlib
...
@@ -2,6 +2,7 @@ import importlib
import
inspect
import
inspect
import
sys
import
sys
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
functools
import
partial
from
inspect
import
signature
from
inspect
import
signature
from
types
import
ModuleType
from
types
import
ModuleType
from
typing
import
Any
,
Callable
,
cast
,
Dict
,
List
,
Mapping
,
Optional
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
cast
,
Dict
,
List
,
Mapping
,
Optional
,
TypeVar
,
Union
...
@@ -37,6 +38,32 @@ class Weights:
...
@@ -37,6 +38,32 @@ class Weights:
transforms
:
Callable
transforms
:
Callable
meta
:
Dict
[
str
,
Any
]
meta
:
Dict
[
str
,
Any
]
def
__eq__
(
self
,
other
:
Any
)
->
bool
:
# We need this custom implementation for correct deep-copy and deserialization behavior.
# TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
# involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
# defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
# for it, the check against the defined members would fail and effectively prevent the weights from being
# deep-copied or deserialized.
# See https://github.com/pytorch/vision/pull/7107 for details.
if
not
isinstance
(
other
,
Weights
):
return
NotImplemented
if
self
.
url
!=
other
.
url
:
return
False
if
self
.
meta
!=
other
.
meta
:
return
False
if
isinstance
(
self
.
transforms
,
partial
)
and
isinstance
(
other
.
transforms
,
partial
):
return
(
self
.
transforms
.
func
==
other
.
transforms
.
func
and
self
.
transforms
.
args
==
other
.
transforms
.
args
and
self
.
transforms
.
keywords
==
other
.
transforms
.
keywords
)
else
:
return
self
.
transforms
==
other
.
transforms
class
WeightsEnum
(
StrEnum
):
class
WeightsEnum
(
StrEnum
):
"""
"""
...
@@ -75,9 +102,6 @@ class WeightsEnum(StrEnum):
...
@@ -75,9 +102,6 @@ class WeightsEnum(StrEnum):
return
object
.
__getattribute__
(
self
.
value
,
name
)
return
object
.
__getattribute__
(
self
.
value
,
name
)
return
super
().
__getattr__
(
name
)
return
super
().
__getattr__
(
name
)
def
__deepcopy__
(
self
,
memodict
=
None
):
return
self
def
get_weight
(
name
:
str
)
->
WeightsEnum
:
def
get_weight
(
name
:
str
)
->
WeightsEnum
:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment