Unverified Commit c06d52b1 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

properly support deepcopying and serialization of model weights (#7107)

parent 93df9a50
import copy
import os
import pickle
import pytest
import test_models as TM
......@@ -73,10 +74,32 @@ def test_get_model_weights(name, weight):
],
)
def test_weights_copyable(copy_fn, name):
model_weights = models.get_model_weights(name)
for weights in list(model_weights):
copied_weights = copy_fn(weights)
assert copied_weights is weights
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert copy_fn(weights) is weights
@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_deserializable(name):
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert pickle.loads(pickle.dumps(weights)) is weights
@pytest.mark.parametrize(
......
......@@ -2,6 +2,7 @@ import importlib
import inspect
import sys
from dataclasses import dataclass, fields
from functools import partial
from inspect import signature
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
......@@ -37,6 +38,32 @@ class Weights:
transforms: Callable
meta: Dict[str, Any]
def __eq__(self, other: Any) -> bool:
# We need this custom implementation for correct deep-copy and deserialization behavior.
# TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
# involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
# defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
# for it, the check against the defined members would fail and effectively prevent the weights from being
# deep-copied or deserialized.
# See https://github.com/pytorch/vision/pull/7107 for details.
if not isinstance(other, Weights):
return NotImplemented
if self.url != other.url:
return False
if self.meta != other.meta:
return False
if isinstance(self.transforms, partial) and isinstance(other.transforms, partial):
return (
self.transforms.func == other.transforms.func
and self.transforms.args == other.transforms.args
and self.transforms.keywords == other.transforms.keywords
)
else:
return self.transforms == other.transforms
class WeightsEnum(StrEnum):
"""
......@@ -75,9 +102,6 @@ class WeightsEnum(StrEnum):
return object.__getattribute__(self.value, name)
return super().__getattr__(name)
def __deepcopy__(self, memodict=None):
return self
def get_weight(name: str) -> WeightsEnum:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment