Unverified Commit 0316ed10 authored by Philip Meier, committed by GitHub

make clamp_bounding_box a kernel / dispatcher hybrid (#7227)

parent 2489f370
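A minimal, hypothetical sketch of the pattern this commit applies (the names `Box`, `_clamp_kernel`, and `clamp` are illustrative, not the torchvision API): one public function acts as the low-level kernel when given a plain tensor plus explicit metadata, and as a dispatcher when given a datapoint-like input that already carries the metadata.

from typing import Optional, Tuple

import torch


class Box(torch.Tensor):
    """Toy stand-in for a bounding box datapoint that carries its image size."""

    spatial_size: Tuple[int, int]


def _clamp_kernel(boxes: torch.Tensor, spatial_size: Tuple[int, int]) -> torch.Tensor:
    # XYXY layout: even indices are x coordinates, odd indices are y coordinates
    height, width = spatial_size
    out = boxes.clone()
    out[..., 0::2].clamp_(min=0, max=width)
    out[..., 1::2].clamp_(min=0, max=height)
    return out


def clamp(boxes: torch.Tensor, spatial_size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
    if isinstance(boxes, Box):
        # dispatcher path: the metadata travels with the input
        if spatial_size is not None:
            raise ValueError("`spatial_size` must not be passed for datapoint inputs")
        return _clamp_kernel(boxes, boxes.spatial_size)
    # kernel path: plain tensor, metadata must be passed explicitly
    if spatial_size is None:
        raise ValueError("`spatial_size` must be passed for plain tensor inputs")
    return _clamp_kernel(boxes, spatial_size)


plain = torch.tensor([[-5.0, 10.0, 50.0, 20.0]])
clamp(plain, spatial_size=(32, 32))  # kernel path

box = plain.as_subclass(Box)
box.spatial_size = (32, 32)
clamp(box)  # dispatcher path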
......@@ -640,14 +640,14 @@ class TestMark:
self.condition = condition or (lambda args_kwargs: True)
def mark_framework_limitation(test_id, reason):
def mark_framework_limitation(test_id, reason, condition=None):
# The purpose of this function is to have a single entry point for skip marks that are only there because the test
# framework cannot handle the kernel in general or a specific parameter combination.
# As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
# still justified.
# We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
# we would be wasting CI resources for no reason most of the time.
return TestMark(test_id, pytest.mark.skip(reason=reason))
return TestMark(test_id, pytest.mark.skip(reason=reason), condition=condition)
class InfoBase:
......
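A hedged, self-contained sketch of what the new `condition` parameter buys: the skip mark is only applied to parametrizations for which the condition holds (`ArgsKwargs` below is a minimal stand-in for the test suite's sample container):

import collections

import pytest

ArgsKwargs = collections.namedtuple("ArgsKwargs", ["args", "kwargs"])


class TestMark:
    def __init__(self, test_id, mark, *, condition=None):
        self.test_id = test_id
        self.mark = mark
        # no condition means the mark applies to every sample
        self.condition = condition or (lambda args_kwargs: True)


def mark_framework_limitation(test_id, reason, condition=None):
    return TestMark(test_id, pytest.mark.skip(reason=reason), condition=condition)


mark = mark_framework_limitation(
    ("TestKernels", "test_scripted_vs_eager"),
    reason="JIT cannot handle this parametrization",
    condition=lambda args_kwargs: args_kwargs.kwargs.get("format") is None,
)
assert mark.condition(ArgsKwargs(args=(), kwargs={}))  # skip applies
assert not mark.condition(ArgsKwargs(args=(), kwargs={"format": "XYXY"}))  # runs normally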
......@@ -12,6 +12,7 @@ import torchvision.prototype.transforms.functional as F
from datasets_utils import combinations_grid
from prototype_common_utils import (
ArgsKwargs,
BoundingBoxLoader,
get_num_channels,
ImageLoader,
InfoBase,
......@@ -25,6 +26,7 @@ from prototype_common_utils import (
make_video_loader,
make_video_loaders,
mark_framework_limitation,
TensorLoader,
TestMark,
)
from torch.utils._pytree import tree_map
......@@ -2010,8 +2012,15 @@ KERNEL_INFOS.extend(
def sample_inputs_clamp_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs(bounding_box_loader)
simple_tensor_loader = TensorLoader(
fn=lambda shape, dtype, device: bounding_box_loader.fn(shape, dtype, device).as_subclass(torch.Tensor),
shape=bounding_box_loader.shape,
dtype=bounding_box_loader.dtype,
)
yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
)
......@@ -2020,6 +2029,19 @@ KERNEL_INFOS.append(
F.clamp_bounding_box,
sample_inputs_fn=sample_inputs_clamp_bounding_box,
logs_usage=True,
test_marks=[
mark_framework_limitation(
("TestKernels", "test_scripted_vs_eager"),
reason=(
"The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
"`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
"`spatial_size` was passed"
),
condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
and arg_kwargs.kwargs.get("format") is None
and arg_kwargs.kwargs.get("spatial_size") is None,
)
],
)
)
......
......@@ -155,12 +155,14 @@ class TestKernels:
if batched_tensor.ndim == data_dims:
return batch
return [
self._unbatch(unbatched, data_dims=data_dims)
unbatcheds = []
for unbatched in (
batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
)
]
):
if isinstance(batch, datapoints._datapoint.Datapoint):
unbatched = type(batch).wrap_like(batch, unbatched)
unbatcheds.append(self._unbatch(unbatched, data_dims=data_dims))
return unbatcheds
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
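The loop above re-wraps each slice as a datapoint because splitting a batched datapoint does not carry over the metadata. A hedged sketch of the idea, assuming the prototype `datapoints` namespace and `wrap_like` helper used elsewhere in this diff:

import torch
from torchvision.prototype import datapoints

batch = datapoints.BoundingBox(
    torch.rand(4, 4) * 32,  # a batch of four XYXY boxes
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(32, 32),
)

# unbind(0) splits off the individual boxes; wrap_like restores the datapoint
# type together with its format / spatial_size metadata
unbatched = [type(batch).wrap_like(batch, t) for t in batch.unbind(0)]
assert all(isinstance(b, datapoints.BoundingBox) for b in unbatched)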
......@@ -558,6 +560,36 @@ def test_normalize_image_tensor_stats(device, num_channels):
assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))
class TestClampBoundingBox:
@pytest.mark.parametrize(
"metadata",
[
dict(),
dict(format=datapoints.BoundingBoxFormat.XYXY),
dict(spatial_size=(1, 1)),
],
)
def test_simple_tensor_insufficient_metadata(self, metadata):
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match="simple tensor"):
F.clamp_bounding_box(simple_tensor, **metadata)
@pytest.mark.parametrize(
"metadata",
[
dict(format=datapoints.BoundingBoxFormat.XYXY),
dict(spatial_size=(1, 1)),
dict(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(1, 1)),
],
)
def test_datapoint_explicit_metadata(self, metadata):
datapoint = next(make_bounding_boxes())
with pytest.raises(ValueError, match="bounding box datapoint"):
F.clamp_bounding_box(datapoint, **metadata)
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`
......
......@@ -51,9 +51,4 @@ class ClampBoundingBoxes(Transform):
_transformed_types = (datapoints.BoundingBox,)
def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
# since `clamp_bounding_box` does not have a dispatcher function that would do that for us
output = F.clamp_bounding_box(
inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
)
return datapoints.BoundingBox.wrap_like(inpt, output)
return F.clamp_bounding_box(inpt) # type: ignore[return-value]
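With the hybrid in place, the transform passes the `BoundingBox` datapoint straight through, and `F.clamp_bounding_box` handles the metadata and re-wrapping itself. A hedged end-to-end usage sketch, assuming the prototype `transforms` namespace:

from torchvision.prototype import datapoints, transforms

box = datapoints.BoundingBox(
    [[-5.0, -5.0, 50.0, 50.0]],  # extends past the 32x32 image on all sides
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(32, 32),
)

clamped = transforms.ClampBoundingBoxes()(box)
# the output is still a BoundingBox, with its coordinates clamped to the image
assert isinstance(clamped, datapoints.BoundingBox)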
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union
import PIL.Image
import torch
......@@ -209,12 +209,9 @@ def convert_format_bounding_box(
return bounding_box
def clamp_bounding_box(
def _clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(clamp_bounding_box)
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
# BoundingBoxFormat instead of converting back and forth
xyxy_boxes = convert_format_bounding_box(
......@@ -225,6 +222,29 @@ def clamp_bounding_box(
return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)
def clamp_bounding_box(
inpt: datapoints.InputTypeJIT,
format: Optional[BoundingBoxFormat] = None,
spatial_size: Optional[Tuple[int, int]] = None,
) -> datapoints.InputTypeJIT:
if not torch.jit.is_scripting():
_log_api_usage_once(clamp_bounding_box)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if format is None or spatial_size is None:
raise ValueError("For simple tensor inputs, `format` and `spatial_size` have to be passed.")
return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size)
elif isinstance(inpt, datapoints.BoundingBox):
if format is not None or spatial_size is not None:
raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
output = _clamp_bounding_box(inpt, format=inpt.format, spatial_size=inpt.spatial_size)
return datapoints.BoundingBox.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
)
def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8
......
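Usage sketch of the new hybrid, mirroring the branches above (hedged; error cases shown as comments):

import torch
import torchvision.prototype.transforms.functional as F
from torchvision.prototype import datapoints

box = datapoints.BoundingBox(
    [[-5.0, 10.0, 50.0, 20.0]],
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(32, 32),
)

# datapoint path: metadata is read from the input, the output stays a BoundingBox
out = F.clamp_bounding_box(box)
assert isinstance(out, datapoints.BoundingBox)

# simple tensor path: metadata must be passed explicitly, a plain tensor comes back
plain = box.as_subclass(torch.Tensor)
out = F.clamp_bounding_box(
    plain, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(32, 32)
)
assert not isinstance(out, datapoints.BoundingBox)

# mixing the conventions raises a ValueError:
# F.clamp_bounding_box(plain)                       # metadata missing
# F.clamp_bounding_box(box, spatial_size=(32, 32))  # metadata redundant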