"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9c8eca702c2fa811fba1ccff82a6aee6a04a2556"
Unverified Commit c06d52b1 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

properly support deepcopying and serialization of model weights (#7107)

parent 93df9a50
import copy import copy
import os import os
import pickle
import pytest import pytest
import test_models as TM import test_models as TM
...@@ -73,10 +74,32 @@ def test_get_model_weights(name, weight): ...@@ -73,10 +74,32 @@ def test_get_model_weights(name, weight):
], ],
) )
def test_weights_copyable(copy_fn, name): def test_weights_copyable(copy_fn, name):
model_weights = models.get_model_weights(name) for weights in list(models.get_model_weights(name)):
for weights in list(model_weights): # It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior
copied_weights = copy_fn(weights) # of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
assert copied_weights is weights # Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert copy_fn(weights) is weights
@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_deserializable(name):
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert pickle.loads(pickle.dumps(weights)) is weights
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -2,6 +2,7 @@ import importlib ...@@ -2,6 +2,7 @@ import importlib
import inspect import inspect
import sys import sys
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from functools import partial
from inspect import signature from inspect import signature
from types import ModuleType from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
...@@ -37,6 +38,32 @@ class Weights: ...@@ -37,6 +38,32 @@ class Weights:
transforms: Callable transforms: Callable
meta: Dict[str, Any] meta: Dict[str, Any]
def __eq__(self, other: Any) -> bool:
# We need this custom implementation for correct deep-copy and deserialization behavior.
# TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
# involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
# defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
# for it, the check against the defined members would fail and effectively prevent the weights from being
# deep-copied or deserialized.
# See https://github.com/pytorch/vision/pull/7107 for details.
if not isinstance(other, Weights):
return NotImplemented
if self.url != other.url:
return False
if self.meta != other.meta:
return False
if isinstance(self.transforms, partial) and isinstance(other.transforms, partial):
return (
self.transforms.func == other.transforms.func
and self.transforms.args == other.transforms.args
and self.transforms.keywords == other.transforms.keywords
)
else:
return self.transforms == other.transforms
class WeightsEnum(StrEnum): class WeightsEnum(StrEnum):
""" """
...@@ -75,9 +102,6 @@ class WeightsEnum(StrEnum): ...@@ -75,9 +102,6 @@ class WeightsEnum(StrEnum):
return object.__getattribute__(self.value, name) return object.__getattribute__(self.value, name)
return super().__getattr__(name) return super().__getattr__(name)
def __deepcopy__(self, memodict=None):
return self
def get_weight(name: str) -> WeightsEnum: def get_weight(name: str) -> WeightsEnum:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment