Unverified Commit 135a0f9e authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Make WeightEnum and Weights public + cleanups (#7100)

parent cb8c4417
...@@ -7,7 +7,7 @@ import test_models as TM ...@@ -7,7 +7,7 @@ import test_models as TM
import torch import torch
from common_extended_utils import get_file_size_mb, get_ops from common_extended_utils import get_file_size_mb, get_ops
from torchvision import models from torchvision import models
from torchvision.models._api import get_model_weights, Weights, WeightsEnum from torchvision.models import get_model_weights, Weights, WeightsEnum
from torchvision.models._utils import handle_legacy_interface from torchvision.models._utils import handle_legacy_interface
run_if_test_with_extended = pytest.mark.skipif( run_if_test_with_extended = pytest.mark.skipif(
......
...@@ -15,4 +15,9 @@ from .vision_transformer import * ...@@ -15,4 +15,9 @@ from .vision_transformer import *
from .swin_transformer import * from .swin_transformer import *
from .maxvit import * from .maxvit import *
from . import detection, optical_flow, quantization, segmentation, video from . import detection, optical_flow, quantization, segmentation, video
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
# The Weights and WeightsEnum are developer-facing utils that we make public for
# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094
# TODO: we could / should document them publicly, but it's not clear where, as
# they're not intended for end users.
from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum
import importlib import importlib
import inspect import inspect
import sys import sys
from dataclasses import dataclass, fields from dataclasses import dataclass
from enum import Enum
from functools import partial from functools import partial
from inspect import signature from inspect import signature
from types import ModuleType from types import ModuleType
...@@ -9,8 +10,6 @@ from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, ...@@ -9,8 +10,6 @@ from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar,
from torch import nn from torch import nn
from torchvision._utils import StrEnum
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
...@@ -65,7 +64,7 @@ class Weights: ...@@ -65,7 +64,7 @@ class Weights:
return self.transforms == other.transforms return self.transforms == other.transforms
class WeightsEnum(StrEnum): class WeightsEnum(Enum):
""" """
This class is the parent class of all model weights. Each model building method receives an optional `weights` This class is the parent class of all model weights. Each model building method receives an optional `weights`
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
...@@ -75,14 +74,11 @@ class WeightsEnum(StrEnum): ...@@ -75,14 +74,11 @@ class WeightsEnum(StrEnum):
value (Weights): The data class entry with the weight information. value (Weights): The data class entry with the weight information.
""" """
def __init__(self, value: Weights):
self._value_ = value
@classmethod @classmethod
def verify(cls, obj: Any) -> Any: def verify(cls, obj: Any) -> Any:
if obj is not None: if obj is not None:
if type(obj) is str: if type(obj) is str:
obj = cls.from_str(obj.replace(cls.__name__ + ".", "")) obj = cls[obj.replace(cls.__name__ + ".", "")]
elif not isinstance(obj, cls): elif not isinstance(obj, cls):
raise TypeError( raise TypeError(
f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}." f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
...@@ -95,12 +91,17 @@ class WeightsEnum(StrEnum): ...@@ -95,12 +91,17 @@ class WeightsEnum(StrEnum):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self._name_}" return f"{self.__class__.__name__}.{self._name_}"
def __getattr__(self, name): @property
# Be able to fetch Weights attributes directly def url(self):
for f in fields(Weights): return self.value.url
if f.name == name:
return object.__getattribute__(self.value, name) @property
return super().__getattr__(name) def transforms(self):
return self.value.transforms
@property
def meta(self):
return self.value.meta
def get_weight(name: str) -> WeightsEnum: def get_weight(name: str) -> WeightsEnum:
...@@ -134,7 +135,7 @@ def get_weight(name: str) -> WeightsEnum: ...@@ -134,7 +135,7 @@ def get_weight(name: str) -> WeightsEnum:
if weights_enum is None: if weights_enum is None:
raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.") raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")
return weights_enum.from_str(value_name) return weights_enum[value_name]
def get_model_weights(name: Union[Callable, str]) -> WeightsEnum: def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment