Unverified commit 0316ed10, authored by Philip Meier and committed by GitHub

make clamp_bounding_box a kernel / dispatcher hybrid (#7227)

parent 2489f370
@@ -640,14 +640,14 @@ class TestMark:
         self.condition = condition or (lambda args_kwargs: True)


-def mark_framework_limitation(test_id, reason):
+def mark_framework_limitation(test_id, reason, condition=None):
     # The purpose of this function is to have a single entry point for skip marks that are only there, because the test
     # framework cannot handle the kernel in general or a specific parameter combination.
     # As development progresses, we can change the `mark.skip` to `mark.xfail` from time to time to see if the skip is
     # still justified.
     # We don't want to use `mark.xfail` all the time, because that actually runs the test until an error happens. Thus,
     # we are wasting CI resources for no reason for most of the time
-    return TestMark(test_id, pytest.mark.skip(reason=reason))
+    return TestMark(test_id, pytest.mark.skip(reason=reason), condition=condition)


 class InfoBase:
...
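For orientation, a minimal sketch of how a per-sample `condition` might be consumed when marks are applied. The `applicable_marks` helper is illustrative only and not part of the torchvision test framework; the `TestMark` constructor is paraphrased from the context line above.

class TestMark:
    def __init__(self, test_id, mark, *, condition=None):
        self.test_id = test_id
        self.mark = mark
        # No condition means the mark applies to every sample (mirrors the constructor above).
        self.condition = condition or (lambda args_kwargs: True)


def applicable_marks(test_marks, test_id, args_kwargs):
    # Hypothetical helper: keep only marks registered for this test whose condition matches
    # the concrete sample, so a skip can target specific parameter combinations.
    return [tm.mark for tm in test_marks if tm.test_id == test_id and tm.condition(args_kwargs)]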
@@ -12,6 +12,7 @@ import torchvision.prototype.transforms.functional as F
 from datasets_utils import combinations_grid
 from prototype_common_utils import (
     ArgsKwargs,
+    BoundingBoxLoader,
     get_num_channels,
     ImageLoader,
     InfoBase,
@@ -25,6 +26,7 @@ from prototype_common_utils import (
     make_video_loader,
     make_video_loaders,
     mark_framework_limitation,
+    TensorLoader,
     TestMark,
 )
 from torch.utils._pytree import tree_map
@@ -2010,8 +2012,15 @@ KERNEL_INFOS.extend(

 def sample_inputs_clamp_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
+        yield ArgsKwargs(bounding_box_loader)
+
+        simple_tensor_loader = TensorLoader(
+            fn=lambda shape, dtype, device: bounding_box_loader.fn(shape, dtype, device).as_subclass(torch.Tensor),
+            shape=bounding_box_loader.shape,
+            dtype=bounding_box_loader.dtype,
+        )
         yield ArgsKwargs(
-            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
+            simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
         )

@@ -2020,6 +2029,19 @@ KERNEL_INFOS.append(
         F.clamp_bounding_box,
         sample_inputs_fn=sample_inputs_clamp_bounding_box,
         logs_usage=True,
+        test_marks=[
+            mark_framework_limitation(
+                ("TestKernels", "test_scripted_vs_eager"),
+                reason=(
+                    "The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
+                    "`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
+                    "`spatial_size` was passed"
+                ),
+                condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
+                and arg_kwargs.kwargs.get("format") is None
+                and arg_kwargs.kwargs.get("spatial_size") is None,
+            )
+        ],
     )
 )
...
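The reason and condition above describe the one case that cannot work under scripting. A rough sketch of that failure mode, assuming the prototype `datapoints` namespace of this development cycle; the concrete box values are made up:

import torch
import torchvision.prototype.transforms.functional as F
from torchvision.prototype import datapoints

box = datapoints.BoundingBox(
    torch.tensor([[0.0, 0.0, 5.0, 5.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(4, 4),
)

# Eager: the dispatcher reads `format` and `spatial_size` from the datapoint itself.
F.clamp_bounding_box(box)

# Scripted: JIT erases the BoundingBox subclass, so the call hits the simple-tensor
# branch, which requires explicit metadata and therefore errors out at runtime.
scripted = torch.jit.script(F.clamp_bounding_box)
try:
    scripted(box)
except Exception:
    # the missing-metadata error surfaces from the scripted call
    pass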
@@ -155,12 +155,14 @@ class TestKernels:
         if batched_tensor.ndim == data_dims:
             return batch

-        return [
-            self._unbatch(unbatched, data_dims=data_dims)
-            for unbatched in (
-                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
-            )
-        ]
+        unbatcheds = []
+        for unbatched in (
+            batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+        ):
+            if isinstance(batch, datapoints._datapoint.Datapoint):
+                unbatched = type(batch).wrap_like(batch, unbatched)
+            unbatcheds.append(self._unbatch(unbatched, data_dims=data_dims))
+        return unbatcheds

     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -558,6 +560,36 @@ def test_normalize_image_tensor_stats(device, num_channels):
     assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std))


+class TestClampBoundingBox:
+    @pytest.mark.parametrize(
+        "metadata",
+        [
+            dict(),
+            dict(format=datapoints.BoundingBoxFormat.XYXY),
+            dict(spatial_size=(1, 1)),
+        ],
+    )
+    def test_simple_tensor_insufficient_metadata(self, metadata):
+        simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
+
+        with pytest.raises(ValueError, match="simple tensor"):
+            F.clamp_bounding_box(simple_tensor, **metadata)
+
+    @pytest.mark.parametrize(
+        "metadata",
+        [
+            dict(format=datapoints.BoundingBoxFormat.XYXY),
+            dict(spatial_size=(1, 1)),
+            dict(format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(1, 1)),
+        ],
+    )
+    def test_datapoint_explicit_metadata(self, metadata):
+        datapoint = next(make_bounding_boxes())
+
+        with pytest.raises(ValueError, match="bounding box datapoint"):
+            F.clamp_bounding_box(datapoint, **metadata)
+
+
 # TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
 # `prototype_transforms_kernel_infos.py`
...
@@ -51,9 +51,4 @@ class ClampBoundingBoxes(Transform):
     _transformed_types = (datapoints.BoundingBox,)

     def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
-        # We need to unwrap here to avoid unnecessary `__torch_function__` calls,
-        # since `clamp_bounding_box` does not have a dispatcher function that would do that for us
-        output = F.clamp_bounding_box(
-            inpt.as_subclass(torch.Tensor), format=inpt.format, spatial_size=inpt.spatial_size
-        )
-        return datapoints.BoundingBox.wrap_like(inpt, output)
+        return F.clamp_bounding_box(inpt)  # type: ignore[return-value]
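With the unwrap/re-wrap moved into the dispatcher, the transform reduces to a plain delegation. A usage sketch, assuming `ClampBoundingBoxes` is exposed from the prototype `transforms` namespace in this development cycle; the values are made up:

import torch
from torchvision.prototype import datapoints, transforms

box = datapoints.BoundingBox(
    torch.tensor([[-3.0, 2.0, 12.0, 9.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(8, 8),
)

clamped = transforms.ClampBoundingBoxes()(box)
# The dispatcher re-wraps the result, so the transform still returns a BoundingBox datapoint.
assert isinstance(clamped, datapoints.BoundingBox)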
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import PIL.Image
 import torch
@@ -209,12 +209,9 @@ def convert_format_bounding_box(
     return bounding_box


-def clamp_bounding_box(
+def _clamp_bounding_box(
     bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
-    if not torch.jit.is_scripting():
-        _log_api_usage_once(clamp_bounding_box)
-
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth
     xyxy_boxes = convert_format_bounding_box(
@@ -225,6 +222,29 @@ def clamp_bounding_box(
     return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)


+def clamp_bounding_box(
+    inpt: datapoints.InputTypeJIT,
+    format: Optional[BoundingBoxFormat] = None,
+    spatial_size: Optional[Tuple[int, int]] = None,
+) -> datapoints.InputTypeJIT:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(clamp_bounding_box)
+
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        if format is None or spatial_size is None:
+            raise ValueError("For simple tensor inputs, `format` and `spatial_size` has to be passed.")
+        return _clamp_bounding_box(inpt, format=format, spatial_size=spatial_size)
+    elif isinstance(inpt, datapoints.BoundingBox):
+        if format is not None or spatial_size is not None:
+            raise ValueError("For bounding box datapoint inputs, `format` and `spatial_size` must not be passed.")
+        output = _clamp_bounding_box(inpt, format=inpt.format, spatial_size=inpt.spatial_size)
+        return datapoints.BoundingBox.wrap_like(inpt, output)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
+        )
+
+
 def _num_value_bits(dtype: torch.dtype) -> int:
     if dtype == torch.uint8:
         return 8
...
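To summarize the contract of the new hybrid entry point in eager mode: bounding box datapoints carry their own metadata, plain tensors must bring it explicitly, and anything else is rejected. A behavior sketch under the same prototype namespace assumptions as above, with made-up values:

import torch
import torchvision.prototype.transforms.functional as F
from torchvision.prototype import datapoints

box = datapoints.BoundingBox(
    torch.tensor([[0.0, -2.0, 15.0, 30.0]]),
    format=datapoints.BoundingBoxFormat.XYXY,
    spatial_size=(10, 10),
)

# Datapoint branch: metadata is taken from the datapoint and the result is re-wrapped.
out = F.clamp_bounding_box(box)
assert isinstance(out, datapoints.BoundingBox)

# Simple-tensor branch: both `format` and `spatial_size` are mandatory.
tensor = box.as_subclass(torch.Tensor)
F.clamp_bounding_box(tensor, format=box.format, spatial_size=box.spatial_size)

# Passing metadata alongside a datapoint (or omitting it for a plain tensor) raises ValueError;
# any other input type raises TypeError.
try:
    F.clamp_bounding_box(box, spatial_size=box.spatial_size)
except ValueError:
    pass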