Unverified Commit 449cc090 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

remove strEnum from BoundingBoxFormat (#7322)

parent 7fefdea3
......@@ -28,5 +28,5 @@ def test_bbox_instance(data, format):
assert isinstance(bboxes, torch.Tensor)
assert bboxes.ndim == 2 and bboxes.shape[1] == 4
if isinstance(format, str):
format = datapoints.BoundingBoxFormat.from_str(format.upper())
format = datapoints.BoundingBoxFormat[(format.upper())]
assert bboxes.format == format
from __future__ import annotations
from enum import Enum
from typing import Any, List, Optional, Sequence, Tuple, Union
import torch
from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
from ._datapoint import _FillTypeJIT, Datapoint
class BoundingBoxFormat(StrEnum):
XYXY = StrEnum.auto()
XYWH = StrEnum.auto()
CXCYWH = StrEnum.auto()
class BoundingBoxFormat(Enum):
XYXY = "XYXY"
XYWH = "XYWH"
CXCYWH = "CXCYWH"
class BoundingBox(Datapoint):
......@@ -39,7 +39,7 @@ class BoundingBox(Datapoint):
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if isinstance(format, str):
format = BoundingBoxFormat.from_str(format.upper())
format = BoundingBoxFormat[format.upper()]
return cls._wrap(tensor, format=format, spatial_size=spatial_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment