Unverified Commit ea37cd38 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make convert_format_bounding_box a hybrid kernel dispatcher (#7228)

parent 0316ed10
...@@ -237,6 +237,13 @@ class TensorLoader: ...@@ -237,6 +237,13 @@ class TensorLoader:
def load(self, device): def load(self, device):
return self.fn(self.shape, self.dtype, device) return self.fn(self.shape, self.dtype, device)
def unwrap(self):
return TensorLoader(
fn=lambda shape, dtype, device: self.fn(shape, dtype, device).as_subclass(torch.Tensor),
shape=self.shape,
dtype=self.dtype,
)
@dataclasses.dataclass @dataclasses.dataclass
class ImageLoader(TensorLoader): class ImageLoader(TensorLoader):
......
...@@ -26,7 +26,6 @@ from prototype_common_utils import ( ...@@ -26,7 +26,6 @@ from prototype_common_utils import (
make_video_loader, make_video_loader,
make_video_loaders, make_video_loaders,
mark_framework_limitation, mark_framework_limitation,
TensorLoader,
TestMark, TestMark,
) )
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
...@@ -660,7 +659,8 @@ KERNEL_INFOS.extend( ...@@ -660,7 +659,8 @@ KERNEL_INFOS.extend(
def sample_inputs_convert_format_bounding_box(): def sample_inputs_convert_format_bounding_box():
formats = list(datapoints.BoundingBoxFormat) formats = list(datapoints.BoundingBoxFormat)
for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats): for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format) yield ArgsKwargs(bounding_box_loader, new_format=new_format)
yield ArgsKwargs(bounding_box_loader.unwrap(), old_format=bounding_box_loader.format, new_format=new_format)
def reference_convert_format_bounding_box(bounding_box, old_format, new_format): def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
...@@ -671,8 +671,14 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format): ...@@ -671,8 +671,14 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
def reference_inputs_convert_format_bounding_box(): def reference_inputs_convert_format_bounding_box():
for args_kwargs in sample_inputs_convert_format_bounding_box(): for args_kwargs in sample_inputs_convert_format_bounding_box():
if len(args_kwargs.args[0].shape) == 2: if len(args_kwargs.args[0].shape) != 2:
yield args_kwargs continue
(loader, *other_args), kwargs = args_kwargs
if isinstance(loader, BoundingBoxLoader):
kwargs["old_format"] = loader.format
loader = loader.unwrap()
yield ArgsKwargs(loader, *other_args, **kwargs)
KERNEL_INFOS.append( KERNEL_INFOS.append(
...@@ -682,6 +688,18 @@ KERNEL_INFOS.append( ...@@ -682,6 +688,18 @@ KERNEL_INFOS.append(
reference_fn=reference_convert_format_bounding_box, reference_fn=reference_convert_format_bounding_box,
reference_inputs_fn=reference_inputs_convert_format_bounding_box, reference_inputs_fn=reference_inputs_convert_format_bounding_box,
logs_usage=True, logs_usage=True,
test_marks=[
mark_framework_limitation(
("TestKernels", "test_scripted_vs_eager"),
reason=(
"The function is hybrid kernel / dispatcher. JIT unwraps a `datapoints.BoundingBox` into a "
"`torch.Tensor`, but then the kernel (rightfully) complains that neither `format` nor "
"`spatial_size` was passed"
),
condition=lambda arg_kwargs: isinstance(arg_kwargs.args[0], BoundingBoxLoader)
and arg_kwargs.kwargs.get("old_format") is None,
)
],
), ),
) )
...@@ -2014,13 +2032,10 @@ def sample_inputs_clamp_bounding_box(): ...@@ -2014,13 +2032,10 @@ def sample_inputs_clamp_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(): for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs(bounding_box_loader) yield ArgsKwargs(bounding_box_loader)
simple_tensor_loader = TensorLoader(
fn=lambda shape, dtype, device: bounding_box_loader.fn(shape, dtype, device).as_subclass(torch.Tensor),
shape=bounding_box_loader.shape,
dtype=bounding_box_loader.dtype,
)
yield ArgsKwargs( yield ArgsKwargs(
simple_tensor_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size bounding_box_loader.unwrap(),
format=bounding_box_loader.format,
spatial_size=bounding_box_loader.spatial_size,
) )
......
...@@ -572,7 +572,7 @@ class TestClampBoundingBox: ...@@ -572,7 +572,7 @@ class TestClampBoundingBox:
def test_simple_tensor_insufficient_metadata(self, metadata): def test_simple_tensor_insufficient_metadata(self, metadata):
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match="simple tensor"): with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` has to be passed")):
F.clamp_bounding_box(simple_tensor, **metadata) F.clamp_bounding_box(simple_tensor, **metadata)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -586,10 +586,37 @@ class TestClampBoundingBox: ...@@ -586,10 +586,37 @@ class TestClampBoundingBox:
def test_datapoint_explicit_metadata(self, metadata): def test_datapoint_explicit_metadata(self, metadata):
datapoint = next(make_bounding_boxes()) datapoint = next(make_bounding_boxes())
with pytest.raises(ValueError, match="bounding box datapoint"): with pytest.raises(ValueError, match=re.escape("`format` and `spatial_size` must not be passed")):
F.clamp_bounding_box(datapoint, **metadata) F.clamp_bounding_box(datapoint, **metadata)
class TestConvertFormatBoundingBox:
@pytest.mark.parametrize(
("inpt", "old_format"),
[
(next(make_bounding_boxes()), None),
(next(make_bounding_boxes()).as_subclass(torch.Tensor), datapoints.BoundingBoxFormat.XYXY),
],
)
def test_missing_new_format(self, inpt, old_format):
with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")):
F.convert_format_bounding_box(inpt, old_format)
def test_simple_tensor_insufficient_metadata(self):
simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor)
with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")):
F.convert_format_bounding_box(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH)
def test_datapoint_explicit_metadata(self):
datapoint = next(make_bounding_boxes())
with pytest.raises(ValueError, match=re.escape("`old_format` must not be passed")):
F.convert_format_bounding_box(
datapoint, old_format=datapoint.format, new_format=datapoints.BoundingBoxFormat.CXCYWH
)
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in # TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
# `prototype_transforms_kernel_infos.py` # `prototype_transforms_kernel_infos.py`
......
...@@ -19,12 +19,7 @@ class ConvertBoundingBoxFormat(Transform): ...@@ -19,12 +19,7 @@ class ConvertBoundingBoxFormat(Transform):
self.format = format self.format = format
def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox: def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
# We need to unwrap here to avoid unnecessary `__torch_function__` calls, return F.convert_format_bounding_box(inpt, new_format=self.format) # type: ignore[return-value]
# since `convert_format_bounding_box` does not have a dispatcher function that would do that for us
output = F.convert_format_bounding_box(
inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=params["format"]
)
return datapoints.BoundingBox.wrap_like(inpt, output, format=params["format"])
class ConvertDtype(Transform): class ConvertDtype(Transform):
......
...@@ -186,11 +186,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor: ...@@ -186,11 +186,9 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
return xyxy return xyxy
def convert_format_bounding_box( def _convert_format_bounding_box(
bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
) -> torch.Tensor: ) -> torch.Tensor:
if not torch.jit.is_scripting():
_log_api_usage_once(convert_format_bounding_box)
if new_format == old_format: if new_format == old_format:
return bounding_box return bounding_box
...@@ -209,6 +207,37 @@ def convert_format_bounding_box( ...@@ -209,6 +207,37 @@ def convert_format_bounding_box(
return bounding_box return bounding_box
def convert_format_bounding_box(
inpt: datapoints.InputTypeJIT,
old_format: Optional[BoundingBoxFormat] = None,
new_format: Optional[BoundingBoxFormat] = None,
inplace: bool = False,
) -> datapoints.InputTypeJIT:
# This being a kernel / dispatcher hybrid, we need an option to pass `old_format` explicitly for simple tensor
# inputs as well as extract it from `datapoints.BoundingBox` inputs. However, putting a default value on
# `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the
# default error that would be thrown if `new_format` had no default value.
if new_format is None:
raise TypeError("convert_format_bounding_box() missing 1 required argument: 'new_format'")
if not torch.jit.is_scripting():
_log_api_usage_once(convert_format_bounding_box)
if torch.jit.is_scripting() or is_simple_tensor(inpt):
if old_format is None:
raise ValueError("For simple tensor inputs, `old_format` has to be passed.")
return _convert_format_bounding_box(inpt, old_format=old_format, new_format=new_format, inplace=inplace)
elif isinstance(inpt, datapoints.BoundingBox):
if old_format is not None:
raise ValueError("For bounding box datapoint inputs, `old_format` must not be passed.")
output = _convert_format_bounding_box(inpt, old_format=inpt.format, new_format=new_format, inplace=inplace)
return datapoints.BoundingBox.wrap_like(inpt, output)
else:
raise TypeError(
f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead."
)
def _clamp_bounding_box( def _clamp_bounding_box(
bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int] bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
) -> torch.Tensor: ) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment