"...python/git@developer.sourcefind.cn:change/sglang.git" did not exist on "3980ff1be6fe2ffb8b2ee1d2a9d3f71a48a42135"
Unverified Commit 449cc090 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

remove strEnum from BoundingBoxFormat (#7322)

parent 7fefdea3
...@@ -28,5 +28,5 @@ def test_bbox_instance(data, format): ...@@ -28,5 +28,5 @@ def test_bbox_instance(data, format):
assert isinstance(bboxes, torch.Tensor) assert isinstance(bboxes, torch.Tensor)
assert bboxes.ndim == 2 and bboxes.shape[1] == 4 assert bboxes.ndim == 2 and bboxes.shape[1] == 4
if isinstance(format, str): if isinstance(format, str):
format = datapoints.BoundingBoxFormat.from_str(format.upper()) format = datapoints.BoundingBoxFormat[(format.upper())]
assert bboxes.format == format assert bboxes.format == format
from __future__ import annotations from __future__ import annotations
from enum import Enum
from typing import Any, List, Optional, Sequence, Tuple, Union from typing import Any, List, Optional, Sequence, Tuple, Union
import torch import torch
from torchvision._utils import StrEnum
from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms from torchvision.transforms import InterpolationMode # TODO: this needs to be moved out of transforms
from ._datapoint import _FillTypeJIT, Datapoint from ._datapoint import _FillTypeJIT, Datapoint
class BoundingBoxFormat(StrEnum): class BoundingBoxFormat(Enum):
XYXY = StrEnum.auto() XYXY = "XYXY"
XYWH = StrEnum.auto() XYWH = "XYWH"
CXCYWH = StrEnum.auto() CXCYWH = "CXCYWH"
class BoundingBox(Datapoint): class BoundingBox(Datapoint):
...@@ -39,7 +39,7 @@ class BoundingBox(Datapoint): ...@@ -39,7 +39,7 @@ class BoundingBox(Datapoint):
tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
if isinstance(format, str): if isinstance(format, str):
format = BoundingBoxFormat.from_str(format.upper()) format = BoundingBoxFormat[format.upper()]
return cls._wrap(tensor, format=format, spatial_size=spatial_size) return cls._wrap(tensor, format=format, spatial_size=spatial_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment