"docs/vscode:/vscode.git/clone" did not exist on "aed5eb88adeba872cf9859bd5b5bfe10ba77e835"
Unverified Commit 71d2bb0b authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

improve StrEnum (#5512)

* improve StrEnum

* use StrEnum for model weights

* fix test

* migrate StrEnum to main area
parent e6d82f7d
......@@ -126,7 +126,7 @@ class TestSmoke:
(
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
itertools.chain.from_iterable(
fn(color_spaces=["rgb"], dtypes=[torch.float32])
fn(color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32])
for fn in [
make_images,
make_vanilla_tensor_images,
......
......@@ -14,8 +14,6 @@ make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
def make_image(size=None, *, color_space, extra_dims=(), dtype=torch.float32):
size = size or torch.randint(16, 33, (2,)).tolist()
if isinstance(color_space, str):
color_space = features.ColorSpace[color_space]
num_channels = {
features.ColorSpace.GRAYSCALE: 1,
features.ColorSpace.RGB: 3,
......
import enum
class StrEnumMeta(enum.EnumMeta):
auto = enum.auto
def from_str(self, member: str):
try:
return self[member]
except KeyError:
# TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
# soon as it is migrated.
raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
pass
......@@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Tuple, Union, Optional
import torch
from torchvision.prototype.utils._internal import StrEnum
from torchvision._utils import StrEnum
from ._feature import _Feature
......@@ -30,7 +30,7 @@ class BoundingBox(_Feature):
bounding_box = super().__new__(cls, data, dtype=dtype, device=device)
if isinstance(format, str):
format = BoundingBoxFormat[format]
format = BoundingBoxFormat.from_str(format.upper())
bounding_box._metadata.update(dict(format=format, image_size=image_size))
......
......@@ -4,7 +4,7 @@ import warnings
from typing import Any, Optional, Union, Tuple, cast
import torch
from torchvision.prototype.utils._internal import StrEnum
from torchvision._utils import StrEnum
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes
from torchvision.utils import make_grid
......@@ -14,9 +14,9 @@ from ._feature import _Feature
class ColorSpace(StrEnum):
OTHER = 0
GRAYSCALE = 1
RGB = 3
OTHER = StrEnum.auto()
GRAYSCALE = StrEnum.auto()
RGB = StrEnum.auto()
class Image(_Feature):
......@@ -37,7 +37,7 @@ class Image(_Feature):
if color_space == ColorSpace.OTHER:
warnings.warn("Unable to guess a specific color space. Consider passing it explicitly.")
elif isinstance(color_space, str):
color_space = ColorSpace[color_space]
color_space = ColorSpace.from_str(color_space.upper())
image._metadata.update(dict(color_space=color_space))
......
......@@ -3,9 +3,10 @@ import inspect
import sys
from collections import OrderedDict
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Callable, Dict
from torchvision._utils import StrEnum
from ..._internally_replaced_utils import load_state_dict_from_url
......@@ -34,7 +35,7 @@ class Weights:
meta: Dict[str, Any]
class WeightsEnum(Enum):
class WeightsEnum(StrEnum):
"""
This class is the parent class of all model weights. Each model building method receives an optional `weights`
parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
......@@ -58,12 +59,6 @@ class WeightsEnum(Enum):
)
return obj
@classmethod
def from_str(cls, value: str) -> "WeightsEnum":
if value in cls.__members__:
return cls.__members__[value]
raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")
def get_state_dict(self, progress: bool) -> OrderedDict:
return load_state_dict_from_url(self.url, progress=progress)
......
import collections.abc
import difflib
import enum
import functools
import inspect
import io
......@@ -31,7 +30,6 @@ import numpy as np
import torch
__all__ = [
"StrEnum",
"sequence_to_str",
"add_suggestion",
"FrozenMapping",
......@@ -45,17 +43,6 @@ __all__ = [
]
class StrEnumMeta(enum.EnumMeta):
auto = enum.auto
def __getitem__(self, item):
return super().__getitem__(item.upper() if isinstance(item, str) else item)
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
pass
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
if not seq:
return ""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment