"tests/vscode:/vscode.git/clone" did not exist on "02443df14f3c9d565570b1163b225e85d536c122"
Unverified Commit 8faa1b14 authored by Nicolas Hug, committed by GitHub

Simplify query_bounding_boxes logic (#7786)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 9b82df43
@@ -691,7 +691,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]

-    spatial_size = _parse_size(spatial_size, name="canvas_size")
+    spatial_size = _parse_size(spatial_size, name="spatial_size")

     def fn(shape, dtype, device):
         *batch_dims, num_coordinates = shape
@@ -702,12 +702,12 @@ def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORT
             format=format, canvas_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
         )

-    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
+    return BoundingBoxesLoader(fn, shape=(*extra_dims[-1:], 4), dtype=dtype, format=format, spatial_size=spatial_size)


 def make_bounding_box_loaders(
     *,
-    extra_dims=DEFAULT_EXTRA_DIMS,
+    extra_dims=tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2),
     formats=tuple(datapoints.BoundingBoxFormat),
     spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     dtypes=(torch.float32, torch.float64, torch.int64),
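
The two changes above restrict the test loaders to at most one batch dimension, in line with the new rule that a BoundingBoxes tensor is at most 2D. A minimal sketch of what the new extra_dims default does; the DEFAULT_EXTRA_DIMS value below is hypothetical, for illustration only:

    # Hypothetical value for illustration; the real tuple lives in the test utilities.
    DEFAULT_EXTRA_DIMS = ((), (4,), (2, 3))

    # Keep only extra-dim tuples with fewer than two entries, i.e. at most one
    # leading batch dimension over the 4 box coordinates.
    extra_dims = tuple(d for d in DEFAULT_EXTRA_DIMS if len(d) < 2)
    assert extra_dims == ((), (4,))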
...
@@ -22,7 +22,7 @@ def test_mask_instance(data):
     assert mask.ndim == 3 and mask.shape[0] == 1


-@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
+@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]], [1, 2, 3, 4]])
 @pytest.mark.parametrize(
     "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
 )
@@ -35,6 +35,12 @@ def test_bbox_instance(data, format):
     assert bboxes.format == format


+def test_bbox_dim_error():
+    data_3d = [[[1, 2, 3, 4]]]
+    with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
+        datapoints.BoundingBoxes(data_3d, format="XYXY", canvas_size=(32, 32))
+
+
 @pytest.mark.parametrize(
     ("data", "input_requires_grad", "expected_requires_grad"),
     [
...
@@ -20,7 +20,7 @@ from prototype_common_utils import make_label
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
 from torchvision.transforms.v2._utils import _convert_fill_arg
-from torchvision.transforms.v2.functional import InterpolationMode, pil_to_tensor, to_image_pil
+from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -306,7 +306,9 @@ class TestFixedSizeCrop:
         bounding_boxes = make_bounding_box(
             format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
         )
-        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")
+        mock = mocker.patch(
+            "torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes", wraps=clamp_bounding_boxes
+        )
         transform = transforms.FixedSizeCrop((-1, -1))
         mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
...
@@ -1654,18 +1654,6 @@ def test_sanitize_bounding_boxes_errors():
         different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
         transforms.SanitizeBoundingBoxes()(different_sizes)

-    with pytest.raises(ValueError, match="boxes must be of shape"):
-        bad_bbox = datapoints.BoundingBoxes(  # batch with 2 elements
-            [
-                [[0, 0, 10, 10]],
-                [[0, 0, 10, 10]],
-            ],
-            format=datapoints.BoundingBoxFormat.XYXY,
-            canvas_size=(20, 20),
-        )
-        different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
-        transforms.SanitizeBoundingBoxes()(different_sizes)
-

 @pytest.mark.parametrize(
     "import_statement",
...
@@ -711,21 +711,20 @@ def _parse_padding(padding):
 @pytest.mark.parametrize("device", cpu_and_cuda())
 @pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
 def test_correctness_pad_bounding_boxes(device, padding):
-    def _compute_expected_bbox(bbox, padding_):
+    def _compute_expected_bbox(bbox, format, padding_):
         pad_left, pad_up, _, _ = _parse_padding(padding_)

         dtype = bbox.dtype
-        format = bbox.format
         bbox = (
             bbox.clone()
             if format == datapoints.BoundingBoxFormat.XYXY
-            else convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+            else convert_format_bounding_boxes(bbox, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
         )

         bbox[0::2] += pad_left
         bbox[1::2] += pad_up

-        bbox = convert_format_bounding_boxes(bbox, new_format=format)
+        bbox = convert_format_bounding_boxes(bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format)
         if bbox.dtype != dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int
@@ -737,7 +736,7 @@ def test_correctness_pad_bounding_boxes(device, padding):
         height, width = bbox.canvas_size
         return height + pad_up + pad_down, width + pad_left + pad_right

-    for bboxes in make_bounding_boxes():
+    for bboxes in make_bounding_boxes(extra_dims=((4,),)):
         bboxes = bboxes.to(device)
         bboxes_format = bboxes.format
         bboxes_canvas_size = bboxes.canvas_size
@@ -748,18 +747,10 @@ def test_correctness_pad_bounding_boxes(device, padding):
         torch.testing.assert_close(output_canvas_size, _compute_expected_canvas_size(bboxes, padding))

-        if bboxes.ndim < 2 or bboxes.shape[0] == 0:
-            bboxes = [bboxes]
-
-        expected_bboxes = []
-        for bbox in bboxes:
-            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, padding))
-
-        if len(expected_bboxes) > 1:
-            expected_bboxes = torch.stack(expected_bboxes)
-        else:
-            expected_bboxes = expected_bboxes[0]
+        expected_bboxes = torch.stack(
+            [_compute_expected_bbox(b, bboxes_format, padding) for b in bboxes.reshape(-1, 4).unbind()]
+        ).reshape(bboxes.shape)

         torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
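
This flatten-compute-reshape idiom, repeated in the hunks below, replaces the old ndim special-casing: the boxes are viewed as a flat (N, 4) list, the reference value is computed one box at a time, and the stack is reshaped back to the input shape. A self-contained sketch, with a toy shift standing in for _compute_expected_bbox:

    import torch

    def shift(box):
        # toy stand-in for _compute_expected_bbox: operates on one 1D box of 4 coords
        return box + 1

    bboxes = torch.arange(16.0).reshape(4, 4)  # a batch of 4 boxes
    expected = torch.stack(
        [shift(b) for b in bboxes.reshape(-1, 4).unbind()]
    ).reshape(bboxes.shape)
    assert expected.shape == bboxes.shape
    # The same code handles a single 1D box of shape (4,): reshape(-1, 4) lifts
    # it to (1, 4) and the final reshape restores the original shape.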
@@ -784,7 +775,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
     ],
 )
 def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
-    def _compute_expected_bbox(bbox, pcoeffs_):
+    def _compute_expected_bbox(bbox, format_, canvas_size_, pcoeffs_):
         m1 = np.array(
             [
                 [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
@@ -798,7 +789,9 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
             ]
         )

-        bbox_xyxy = convert_format_bounding_boxes(bbox, new_format=datapoints.BoundingBoxFormat.XYXY)
+        bbox_xyxy = convert_format_bounding_boxes(
+            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY
+        )
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -818,14 +811,11 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
                 np.max(transformed_points[:, 1]),
             ]
         )
-        out_bbox = datapoints.BoundingBoxes(
-            out_bbox,
-            format=datapoints.BoundingBoxFormat.XYXY,
-            canvas_size=bbox.canvas_size,
-            dtype=bbox.dtype,
-            device=bbox.device,
+        out_bbox = torch.from_numpy(out_bbox)
+        out_bbox = convert_format_bounding_boxes(
+            out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_
         )
-        return clamp_bounding_boxes(convert_format_bounding_boxes(out_bbox, new_format=bbox.format))
+        return clamp_bounding_boxes(out_bbox, format=format_, canvas_size=canvas_size_).to(bbox)

     canvas_size = (32, 38)
@@ -844,17 +834,13 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
             coefficients=pcoeffs,
         )

-        if bboxes.ndim < 2:
-            bboxes = [bboxes]
-
-        expected_bboxes = []
-        for bbox in bboxes:
-            bbox = datapoints.BoundingBoxes(bbox, format=bboxes.format, canvas_size=bboxes.canvas_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
-        if len(expected_bboxes) > 1:
-            expected_bboxes = torch.stack(expected_bboxes)
-        else:
-            expected_bboxes = expected_bboxes[0]
+        expected_bboxes = torch.stack(
+            [
+                _compute_expected_bbox(b, bboxes.format, bboxes.canvas_size, inv_pcoeffs)
+                for b in bboxes.reshape(-1, 4).unbind()
+            ]
+        ).reshape(bboxes.shape)

         torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=0, atol=1)
@@ -864,9 +850,7 @@ def test_correctness_perspective_bounding_boxes(device, startpoints, endpoints):
     [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
 )
 def test_correctness_center_crop_bounding_boxes(device, output_size):
-    def _compute_expected_bbox(bbox, output_size_):
-        format_ = bbox.format
-        canvas_size_ = bbox.canvas_size
+    def _compute_expected_bbox(bbox, format_, canvas_size_, output_size_):
         dtype = bbox.dtype
         bbox = convert_format_bounding_boxes(bbox.float(), format_, datapoints.BoundingBoxFormat.XYWH)
@@ -895,18 +879,12 @@ def test_correctness_center_crop_bounding_boxes(device, output_size):
             bboxes, bboxes_format, bboxes_canvas_size, output_size
         )

-        if bboxes.ndim < 2:
-            bboxes = [bboxes]
-
-        expected_bboxes = []
-        for bbox in bboxes:
-            bbox = datapoints.BoundingBoxes(bbox, format=bboxes_format, canvas_size=bboxes_canvas_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
-        if len(expected_bboxes) > 1:
-            expected_bboxes = torch.stack(expected_bboxes)
-        else:
-            expected_bboxes = expected_bboxes[0]
+        expected_bboxes = torch.stack(
+            [
+                _compute_expected_bbox(b, bboxes_format, bboxes_canvas_size, output_size)
+                for b in bboxes.reshape(-1, 4).unbind()
+            ]
+        ).reshape(bboxes.shape)

         torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
         torch.testing.assert_close(output_canvas_size, output_size)
...
@@ -222,16 +222,9 @@ def reference_affine_bounding_boxes_helper(bounding_boxes, *, format, canvas_siz
             out_bbox = out_bbox.to(dtype=in_dtype)
         return out_bbox

-    if bounding_boxes.ndim < 2:
-        bounding_boxes = [bounding_boxes]
-
-    expected_bboxes = [transform(bbox, affine_matrix, format, canvas_size) for bbox in bounding_boxes]
-    if len(expected_bboxes) > 1:
-        expected_bboxes = torch.stack(expected_bboxes)
-    else:
-        expected_bboxes = expected_bboxes[0]
-
-    return expected_bboxes
+    return torch.stack(
+        [transform(b, affine_matrix, format, canvas_size) for b in bounding_boxes.reshape(-1, 4).unbind()]
+    ).reshape(bounding_boxes.shape)


 def sample_inputs_convert_format_bounding_boxes():
...
@@ -26,6 +26,12 @@ class BoundingBoxFormat(Enum):
 class BoundingBoxes(Datapoint):
     """[BETA] :class:`torch.Tensor` subclass for bounding boxes.

+    .. note::
+        There should be only one :class:`~torchvision.datapoints.BoundingBoxes`
+        instance per sample e.g. ``{"img": img, "bbox": BoundingBoxes(...)}``,
+        although one :class:`~torchvision.datapoints.BoundingBoxes` object can
+        contain multiple bounding boxes.
+
     Args:
         data: Any data that can be turned into a tensor with :func:`torch.as_tensor`.
         format (BoundingBoxFormat, str): Format of the bounding box.
@@ -43,6 +49,10 @@ class BoundingBoxes(Datapoint):
     @classmethod
     def _wrap(cls, tensor: torch.Tensor, *, format: Union[BoundingBoxFormat, str], canvas_size: Tuple[int, int]) -> BoundingBoxes:  # type: ignore[override]
+        if tensor.ndim == 1:
+            tensor = tensor.unsqueeze(0)
+        elif tensor.ndim != 2:
+            raise ValueError(f"Expected a 1D or 2D tensor, got {tensor.ndim}D")
         if isinstance(format, str):
             format = BoundingBoxFormat[format.upper()]
         bounding_boxes = tensor.as_subclass(cls)
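
With this guard in _wrap, construction normalizes a single box to a (1, 4) batch and rejects anything above 2D, which is exactly what the new test_bbox_dim_error test above exercises. A short usage sketch, assuming a torchvision build that includes this patch:

    from torchvision import datapoints

    # A single box given as a 1D sequence is promoted to shape (1, 4) ...
    box = datapoints.BoundingBoxes([1, 2, 3, 4], format="XYXY", canvas_size=(32, 32))
    assert box.shape == (1, 4)

    # ... while anything above 2D is rejected up front.
    try:
        datapoints.BoundingBoxes([[[1, 2, 3, 4]]], format="XYXY", canvas_size=(32, 32))
    except ValueError as e:
        print(e)  # Expected a 1D or 2D tensor, got 3D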
...
@@ -7,7 +7,7 @@ from torchvision import datapoints
 from torchvision.prototype.datapoints import Label, OneHotLabel
 from torchvision.transforms.v2 import functional as F, Transform
 from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
-from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
+from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size


 class FixedSizeCrop(Transform):
@@ -61,7 +61,7 @@ class FixedSizeCrop(Transform):
         bounding_boxes: Optional[torch.Tensor]
         try:
-            bounding_boxes = query_bounding_boxes(flat_inputs)
+            bounding_boxes = get_bounding_boxes(flat_inputs)
         except ValueError:
             bounding_boxes = None
...
@@ -23,7 +23,7 @@ from ._utils import (
     _setup_float_or_seq,
     _setup_size,
 )
-from .utils import has_all, has_any, is_simple_tensor, query_bounding_boxes, query_size
+from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size


 class RandomHorizontalFlip(_RandomApplyTransform):
@@ -1137,7 +1137,7 @@ class RandomIoUCrop(Transform):
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         orig_h, orig_w = query_size(flat_inputs)
-        bboxes = query_bounding_boxes(flat_inputs)
+        bboxes = get_bounding_boxes(flat_inputs)

         while True:
             # sample an option
...
@@ -10,7 +10,7 @@ from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform

 from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
-from .utils import has_any, is_simple_tensor, query_bounding_boxes
+from .utils import get_bounding_boxes, has_any, is_simple_tensor


 # TODO: do we want/need to expose this?
@@ -384,13 +384,7 @@ class SanitizeBoundingBoxes(Transform):
             )

         flat_inputs, spec = tree_flatten(inputs)
-        # TODO: this enforces one single BoundingBoxes entry.
-        # Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
-        # should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
-        boxes = query_bounding_boxes(flat_inputs)
-
-        if boxes.ndim != 2:
-            raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
+        boxes = get_bounding_boxes(flat_inputs)

         if labels is not None and boxes.shape[0] != labels.shape[0]:
             raise ValueError(
...
@@ -9,13 +9,12 @@ from torchvision._utils import sequence_to_str
 from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor


-def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
-    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)]
-    if not bounding_boxes:
-        raise TypeError("No bounding boxes were found in the sample")
-    elif len(bounding_boxes) > 1:
-        raise ValueError("Found multiple bounding boxes instances in the sample")
-    return bounding_boxes.pop()
+def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
+    # This assumes there is only one bbox per sample as per the general convention
+    try:
+        return next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
+    except StopIteration:
+        raise ValueError("No bounding boxes were found in the sample")


 def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
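
The behavioral difference: the old query_bounding_boxes raised ValueError when a sample contained several BoundingBoxes instances (and TypeError when it contained none), while get_bounding_boxes returns the first instance and raises ValueError only when none is found, leaning on the documented one-instance-per-sample convention instead of policing it. A sketch of the new lookup, assuming the patched torchvision:

    from torchvision import datapoints

    boxes_a = datapoints.BoundingBoxes([[0, 0, 5, 5]], format="XYXY", canvas_size=(10, 10))
    boxes_b = datapoints.BoundingBoxes([[1, 1, 6, 6]], format="XYXY", canvas_size=(10, 10))

    # Two instances in one sample no longer raise; the first one wins.
    flat_inputs = ["label", boxes_a, boxes_b]
    first = next(inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes))
    assert first is boxes_a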
...