Unverified Commit ea37cd38 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make convert_format_bounding_box a hybrid kernel dispatcher (#7228)

parent 0316ed10
......@@ -237,6 +237,13 @@ class TensorLoader:
def load(self, device):
return self.fn(self.shape, self.dtype, device)
def unwrap(self):
return TensorLoader(
fn=lambda shape, dtype, device: self.fn(shape, dtype, device).as_subclass(torch.Tensor),
shape=self.shape,
dtype=self.dtype,
)
@dataclasses.dataclass
class ImageLoader(TensorLoader):
......
......@@ -26,7 +26,6 @@ from prototype_common_utils import (
make_video_loader,
make_video_loaders,
mark_framework_limitation,
TensorLoader,
TestMark,
)
from torch.utils._pytree import tree_map
......@@ -660,7 +659,8 @@ KERNEL_INFOS.extend(
def sample_inputs_convert_format_bounding_box():
formats = list(datapoints.BoundingBoxFormat)
for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)
yield ArgsKwargs(bounding_box_loader, new_format=new_format)
yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)
def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
......@@ -671,8 +671,14 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
def reference_inputs_convert_format_bounding_box():
for args_kwargs in sample_inputs_convert_format_bounding_box():
if len(args_kwargs.args[0].shape) == 2:
yield args_kwargs
if len(args_kwargs.args[0].shape) != 2:
continue
(loader, *other_args), kwargs = args_kwargs
if isinstance(loader, BoundingBoxLoader):
kwargs["old_format"] = loader.format
loader = loader.unwrap()
yield ArgsKwargs(loader, *other_args, **kwargs)
KERNEL_INFOS.append(
......@@ -682,6 +688,18 @@ KERNEL_INFOS.append(
reference_fn=reference_convert_format_bounding_box,
reference_inputs_fn=reference_inputs_convert_format_bounding_box,
logs_usage=True,
test_marks=[
mark_framework_limitation(
("TestKernels", "test_scripted_vs_eager"),
reason=(
"The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
"`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
"`spatial_size` was passed"
),
condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
and arg_kwargs.kwargs.get("old_format") is None,
)
],
),
)
......@@ -2014,13 +2032,10 @@ def sample_inputs_clamp_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs(bounding_box_loader)
simple_tensor_loader = TensorLoader(
fn=lambda shape, dtype, device: bounding_box_loader.fn(shape, dtype, device).as_subclass(torch.Tensor),
shape=bounding_box_loader.shape,
dtype=bounding_box_loader.dtype,
)
yield ArgsKwargs(
simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
bounding_box_loader.unwrap(),
format=bounding_box_loader.format,
spatial_size=bounding_box_loader.spatial_size,
)
......
......@@ -572,7 +572,7 @@ class TestClampBoundingBox:
def test_simple_tensor_insufficient_metadata(self, metadata):
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match="simple tensor"):
with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` has to be passed")):
F.clamp_bounding_box(simple_tensor, **metadata)
@pytest.mark.parametrize(
......@@ -586,10 +586,37 @@ class TestClampBoundingBox:
def test_datapoint_explicit_metadata(self, metadata):
datapoint = next(make_bounding_boxes())
with pytest.raises(ValueError, match="bounding box datapoint"):
with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` must not be passed")):
F.clamp_bounding_box(datapoint, **metadata)
class TestConvertFormatBoundingBox:
@pytest.mark.parametrize(
("inpt", "old_format"),
[
(next(make_bounding_boxes()), None),
(next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
],
)
def test_missing_new_format(self, inpt, old_format):
with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
F.convert_format_bounding_box(inpt, old_format)
def test_simple_tensor_insufficient_metadata(self):
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
F.convert_format_bounding_box(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
def test_datapoint_explicit_metadata(self):
datapoint = next(make_bounding_boxes())
with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
F.convert_format_bounding_box(
datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
)
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py`
......
......@@ -19,12 +19,7 @@ class ConvertBoundingBoxFormat(Transform):
self.format = format
def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
# We need to unwrap here to avoid unnecessary `__torch_function__` calls,
# since `convert_format_bounding_box` does not have a dispatcher function that would do that for us
output = F.convert_format_bounding_box(
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=params["format"]
)
return datapoints.BoundingBox.wrap_like(inpt, output, format=params["format"])
return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value]
class ConvertDtype(Transform):
......
......@@ -186,11 +186,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
return xyxy
def convert_format_bounding_box(
def _convert_format_bounding_box(
bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(convert_format_bounding_box)
if new_format == old_format:
return bounding_box
......@@ -209,6 +207,37 @@ def convert_format_bounding_box(
return bounding_box
def convert_format_bounding_box(
inpt: datapoints.InputTypeJIT,
old_format: Optional[BoundingBoxFormat] = None,
new_format: Optional[BoundingBoxFormat] = None,
inplace: bool = False,
) -> datapoints.InputTypeJIT:
# This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
# inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
if new_format is None:
raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'")
if not torch.jit.is_scripting():
_log_api_usage_once(convert_format_bounding_box)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if old_format is None:
raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
elif isinstance(inpt, datapoints.BoundingBox):
if old_format is not None:
raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
output = _convert_format_bounding_box(inpt, old_format=inpt.format, new_format=new_format, inplace=inplace)
return datapoints.BoundingBox.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
)
def _clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment