Unverified commit 71d2bb0b authored by Philip Meier, committed by GitHub

improve StrEnum (#5512)

* improve StrEnum

* use StrEnum for model weights

* fix test

* migrate StrEnum to main area
parent e6d82f7d
@@ -126,7 +126,7 @@ class TestSmoke:
         (
             transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
             itertools.chain.from_iterable(
-                fn(color_spaces=["rgb"], dtypes=[torch.float32])
+                fn(color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32])
                 for fn in [
                     make_images,
                     make_vanilla_tensor_images,
...
@@ -14,8 +14,6 @@ make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
 def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
     size = size or torch.randint(16, 33, (2,)).tolist()
-    if isinstance(color_space, str):
-        color_space = features.ColorSpace[color_space]
     num_channels = {
         features.ColorSpace.GRAYSCALE: 1,
         features.ColorSpace.RGB: 3,
...
import enum


class StrEnumMeta(enum.EnumMeta):
    auto = enum.auto

    def from_str(self, member: str):
        try:
            return self[member]
        except KeyError:
            # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
            # soon as it is migrated.
            raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None


class StrEnum(enum.Enum, metaclass=StrEnumMeta):
    pass
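Compared to the prototype version removed in the last hunk below, lookup is now an explicit `from_str` on the metaclass instead of a case-insensitive `__getitem__`, and unknown names raise a `ValueError`. A minimal usage sketch, assuming the helper is importable as `torchvision._utils.StrEnum` (the updated imports below point there); the `Pet` enum is purely illustrative:

from torchvision._utils import StrEnum  # assumed import path, matching the hunks below


class Pet(StrEnum):  # hypothetical enum, for illustration only
    CAT = StrEnum.auto()
    DOG = StrEnum.auto()


Pet.from_str("CAT")          # -> Pet.CAT
Pet.from_str("cat".upper())  # callers normalize the case themselves now
# Pet.from_str("FISH")       # would raise ValueError: Unknown value 'FISH' for Pet.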
@@ -3,7 +3,7 @@ from __future__ import annotations
 from typing import Any, Tuple, Union, Optional

 import torch
-from torchvision.prototype.utils._internal import StrEnum
+from torchvision._utils import StrEnum

 from ._feature import _Feature
@@ -30,7 +30,7 @@ class BoundingBox(_Feature):
         bounding_box = super().__new__(cls, data, dtype=dtype, device=device)

         if isinstance(format, str):
-            format = BoundingBoxFormat[format]
+            format = BoundingBoxFormat.from_str(format.upper())
         bounding_box._metadata.update(dict(format=format, image_size=image_size))
...
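Because `from_str` is a plain, case-sensitive name lookup (unlike the removed metaclass `__getitem__`, which upper-cased string keys), the call site above normalizes the string itself. A before/after sketch; `BoundingBoxFormat.XYXY` is assumed to be an existing member and is not part of this diff:

from torchvision.prototype import features  # assumed import path, as in the test hunks above

# features.BoundingBoxFormat["xyxy"]                 # old behavior: the metaclass __getitem__ upper-cased the key
features.BoundingBoxFormat.from_str("xyxy".upper())  # new: explicit lookup -> BoundingBoxFormat.XYXY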
@@ -4,7 +4,7 @@ import warnings
 from typing import Any, Optional, Union, Tuple, cast

 import torch
-from torchvision.prototype.utils._internal import StrEnum
+from torchvision._utils import StrEnum
 from torchvision.transforms.functional import to_pil_image
 from torchvision.utils import draw_bounding_boxes
 from torchvision.utils import make_grid
@@ -14,9 +14,9 @@ from ._feature import _Feature
 class ColorSpace(StrEnum):
-    OTHER = 0
-    GRAYSCALE = 1
-    RGB = 3
+    OTHER = StrEnum.auto()
+    GRAYSCALE = StrEnum.auto()
+    RGB = StrEnum.auto()


 class Image(_Feature):
@@ -37,7 +37,7 @@ class Image(_Feature):
         if color_space == ColorSpace.OTHER:
             warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
         elif isinstance(color_space, str):
-            color_space = ColorSpace[color_space]
+            color_space = ColorSpace.from_str(color_space.upper())
         image._metadata.update(dict(color_space=color_space))
...
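With `StrEnum.auto()` the member values become arbitrary, so nothing should rely on the previous integers (0, 1, 3); channel counts stay in the explicit `num_channels` mapping used by the test helper above. A small sketch of the string path, assuming `ColorSpace` is reachable via `torchvision.prototype.features` as in the tests:

from torchvision.prototype import features  # assumed import path

features.ColorSpace.from_str("RGB")          # -> ColorSpace.RGB
features.ColorSpace.from_str("rgb".upper())  # how Image.__new__ resolves string input now
# features.ColorSpace.from_str("BGR")        # would raise ValueError (unknown member)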
@@ -3,9 +3,10 @@ import inspect
 import sys
 from collections import OrderedDict
 from dataclasses import dataclass, fields
-from enum import Enum
 from typing import Any, Callable, Dict

+from torchvision._utils import StrEnum
+
 from ..._internally_replaced_utils import load_state_dict_from_url
@@ -34,7 +35,7 @@ class Weights:
     meta: Dict[str, Any]


-class WeightsEnum(Enum):
+class WeightsEnum(StrEnum):
     """
     This class is the parent class of all model weights. Each model building method receives an optional `weights`
     parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
@@ -58,12 +59,6 @@ class WeightsEnum(Enum):
         )
         return obj

-    @classmethod
-    def from_str(cls, value: str) -> "WeightsEnum":
-        if value in cls.__members__:
-            return cls.__members__[value]
-        raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
-
     def get_state_dict(self, progress: bool) -> OrderedDict:
         return load_state_dict_from_url(self.url, progress=progress)
...
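`WeightsEnum` drops its hand-rolled `from_str` and inherits the metaclass version from `StrEnum`, so string-based weight lookup keeps working and still raises `ValueError` for unknown names (only the message changes). A hedged sketch; the import path and the `MyModel_Weights` enum are assumptions, not part of this diff:

from torchvision.prototype.models._api import Weights, WeightsEnum  # assumed location of the file edited above


class MyModel_Weights(WeightsEnum):  # hypothetical enum, for illustration only
    DEFAULT = Weights(url="https://example.com/my_model.pth", transforms=None, meta={})


MyModel_Weights.from_str("DEFAULT")    # -> MyModel_Weights.DEFAULT, via StrEnumMeta.from_str
# MyModel_Weights.from_str("missing")  # would raise ValueError: Unknown value 'missing' for MyModel_Weights.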
 import collections.abc
 import difflib
-import enum
 import functools
 import inspect
 import io
@@ -31,7 +30,6 @@ import numpy as np
 import torch

 __all__ = [
-    "StrEnum",
     "sequence_to_str",
     "add_suggestion",
     "FrozenMapping",
@@ -45,17 +43,6 @@ __all__ = [
 ]


-class StrEnumMeta(enum.EnumMeta):
-    auto = enum.auto
-
-    def __getitem__(self, item):
-        return super().__getitem__(item.upper() if isinstance(item, str) else item)
-
-
-class StrEnum(enum.Enum, metaclass=StrEnumMeta):
-    pass
-
-
 def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
     if not seq:
         return ""
...