Unverified commit ae83c9fd, authored by Philip Meier and committed by GitHub

[PoC] move metadata computation from prototype features into kernels (#6646)

* move metadata computation from prototype features into kernels

* fix tests

* fix no_inplace test

* mypy

* add perf TODO
parent 2907c494
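
Before the diff itself, a minimal sketch of the pattern this PoC introduces, with simplified names and shapes rather than the exact torchvision signatures: each bounding box kernel now computes its output metadata (here the new image_size) and returns it alongside the tensor, so the prototype feature wrappers simply unpack it instead of re-deriving it.

from typing import List, Tuple

import torch


def resize_bounding_box(
    bounding_box: torch.Tensor,
    image_size: Tuple[int, int],
    size: List[int],
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    # Toy version of the new kernel contract: compute the metadata here and
    # return it next to the tensor. For this sketch, assume `size` is already
    # the full (new_height, new_width) pair.
    old_height, old_width = image_size
    new_height, new_width = size
    ratios = torch.tensor(
        (new_width / old_width, new_height / old_height), device=bounding_box.device
    )
    # Scale XYXY boxes point-wise and cast back to the input dtype.
    output = bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape)
    return output, (new_height, new_width)


# The feature wrapper then only unpacks the metadata (compare BoundingBox.resize below):
#     output, image_size = self._F.resize_bounding_box(self, image_size=self.image_size, size=size)
#     return BoundingBox.new_like(self, output, image_size=image_size)
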
@@ -709,7 +709,7 @@ def sample_inputs_crop_bounding_box():
for bounding_box_loader, params in itertools.product(
make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
):
yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, top=params["top"], left=params["left"])
yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)
def sample_inputs_crop_mask():
@@ -856,7 +856,9 @@ def sample_inputs_pad_bounding_box():
if params["padding_mode"] != "constant":
continue
yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)
yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size, **params
)
def sample_inputs_pad_mask():
@@ -1552,8 +1554,6 @@ KERNEL_INFOS.extend(
skips=[
skip_integer_size_jit(),
Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."),
Skip("test_no_inplace", reason="Output of five_crop_image_tensor is not a tensor."),
Skip("test_dtype_and_device_consistency", reason="Output of five_crop_image_tensor is not a tensor."),
],
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
),
@@ -1565,8 +1565,6 @@ KERNEL_INFOS.extend(
skips=[
skip_integer_size_jit(),
Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."),
Skip("test_no_inplace", reason="Output of ten_crop_image_tensor is not a tensor."),
Skip("test_dtype_and_device_consistency", reason="Output of ten_crop_image_tensor is not a tensor."),
],
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
),
@@ -68,17 +68,22 @@ class TestKernels:
assert_close(actual, expected, **info.closeness_kwargs)
def _unbind_batch_dims(self, batched_tensor, *, data_dims):
if batched_tensor.ndim == data_dims:
return batched_tensor
return [self._unbind_batch_dims(t, data_dims=data_dims) for t in batched_tensor.unbind(0)]
def _unbatch(self, batch, *, data_dims):
if isinstance(batch, torch.Tensor):
batched_tensor = batch
metadata = ()
else:
batched_tensor, *metadata = batch
def _stack_batch_dims(self, unbound_tensor):
if isinstance(unbound_tensor[0], torch.Tensor):
return torch.stack(unbound_tensor)
if batched_tensor.ndim == data_dims:
return batch
return torch.stack([self._stack_batch_dims(t) for t in unbound_tensor])
return [
self._unbatch(unbatched, data_dims=data_dims)
for unbatched in (
batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
)
]
@sample_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
@@ -106,11 +111,11 @@ class TestKernels:
elif not all(batched_input.shape[:-data_dims]):
pytest.skip("Input has a degenerate batch shape.")
actual = info.kernel(batched_input, *other_args, **kwargs)
batched_output = info.kernel(batched_input, *other_args, **kwargs)
actual = self._unbatch(batched_output, data_dims=data_dims)
single_inputs = self._unbind_batch_dims(batched_input, data_dims=data_dims)
single_outputs = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)
expected = self._stack_batch_dims(single_outputs)
single_inputs = self._unbatch(batched_input, data_dims=data_dims)
expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)
assert_close(actual, expected, **info.closeness_kwargs)
@@ -123,9 +128,9 @@ class TestKernels:
pytest.skip("The input has a degenerate shape.")
input_version = input._version
output = info.kernel(input, *other_args, **kwargs)
info.kernel(input, *other_args, **kwargs)
assert output is not input or output._version == input_version
assert input._version == input_version
@sample_inputs
@needs_cuda
@@ -144,6 +149,9 @@ class TestKernels:
(input, *other_args), kwargs = args_kwargs.load(device)
output = info.kernel(input, *other_args, **kwargs)
# Most kernels just return a tensor, but some also return some additional metadata
if not isinstance(output, torch.Tensor):
output, *_ = output
assert output.dtype == input.dtype
assert output.device == input.device
@@ -324,7 +332,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
affine_matrix = affine_matrix[:2, :]
image_size = bbox.image_size
height, width = bbox.image_size
bbox_xyxy = convert_format_bounding_box(
bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
)
@@ -336,9 +344,9 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
[bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
# image frame
[0.0, 0.0, 1.0],
[0.0, image_size[0], 1.0],
[image_size[1], image_size[0], 1.0],
[image_size[1], 0.0, 1.0],
[0.0, height, 1.0],
[width, height, 1.0],
[width, 0.0, 1.0],
]
)
transformed_points = np.matmul(points, affine_matrix.T)
@@ -356,18 +364,21 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
out_bbox[2] -= tr_x
out_bbox[3] -= tr_y
# image_size should be updated, but it is OK here to skip its computation
# as we do not compute it in F.rotate_bounding_box
height = int(height - 2 * tr_y)
width = int(width - 2 * tr_x)
out_bbox = features.BoundingBox(
out_bbox,
format=features.BoundingBoxFormat.XYXY,
image_size=image_size,
image_size=(height, width),
dtype=bbox.dtype,
device=bbox.device,
)
return convert_format_bounding_box(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
return (
convert_format_bounding_box(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
),
(height, width),
)
image_size = (32, 38)
@@ -376,7 +387,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size
output_bboxes = F.rotate_bounding_box(
output_bboxes, output_image_size = F.rotate_bounding_box(
bboxes,
bboxes_format,
image_size=bboxes_image_size,
@@ -395,12 +406,14 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
expected_bboxes = []
for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_))
expected_bbox, expected_image_size = _compute_expected_bbox(bbox, -angle, expand, center_)
expected_bboxes.append(expected_bbox)
if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes)
else:
expected_bboxes = expected_bboxes[0]
torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0)
torch.testing.assert_close(output_image_size, expected_image_size, atol=1, rtol=0)
@pytest.mark.parametrize("device", cpu_and_gpu())
@@ -445,7 +458,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
[18.36396103, 1.07968978, 46.64823228, 29.36396103],
]
output_boxes = F.rotate_bounding_box(
output_boxes, _ = F.rotate_bounding_box(
in_boxes,
in_boxes.format,
in_boxes.image_size,
@@ -510,17 +523,20 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
if format != features.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)
output_boxes = F.crop_bounding_box(
output_boxes, output_image_size = F.crop_bounding_box(
in_boxes,
format,
top,
left,
size[0],
size[1],
)
if format != features.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
torch.testing.assert_close(output_image_size, size)
@pytest.mark.parametrize("device", cpu_and_gpu())
@@ -585,12 +601,13 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
if format != features.BoundingBoxFormat.XYXY:
in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)
output_boxes = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
output_boxes, output_image_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
if format != features.BoundingBoxFormat.XYXY:
output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_image_size, size)
def _parse_padding(padding):
@@ -627,12 +644,21 @@ def test_correctness_pad_bounding_box(device, padding):
bbox = bbox.to(bbox_dtype)
return bbox
def _compute_expected_image_size(bbox, padding_):
pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
height, width = bbox.image_size
return height + pad_up + pad_down, width + pad_left + pad_right
for bboxes in make_bounding_boxes():
bboxes = bboxes.to(device)
bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size
output_boxes = F.pad_bounding_box(bboxes, format=bboxes_format, padding=padding)
output_boxes, output_image_size = F.pad_bounding_box(
bboxes, format=bboxes_format, image_size=bboxes_image_size, padding=padding
)
torch.testing.assert_close(output_image_size, _compute_expected_image_size(bboxes, padding))
if bboxes.ndim < 2 or bboxes.shape[0] == 0:
bboxes = [bboxes]
@@ -781,7 +807,9 @@ def test_correctness_center_crop_bounding_box(device, output_size):
bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size
output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, bboxes_image_size, output_size)
output_boxes, output_image_size = F.center_crop_bounding_box(
bboxes, bboxes_format, bboxes_image_size, output_size
)
if bboxes.ndim < 2:
bboxes = [bboxes]
@@ -796,6 +824,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
else:
expected_bboxes = expected_bboxes[0]
torch.testing.assert_close(output_boxes, expected_bboxes)
torch.testing.assert_close(output_image_size, output_size)
@pytest.mark.parametrize("device", cpu_and_gpu())
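
Earlier in this test diff, the new _unbatch helper replaces the old _unbind_batch_dims / _stack_batch_dims pair so that the batched-vs-single comparison also works for kernels whose output is a (tensor, metadata) tuple. A self-contained sketch of that recursion on toy data, not the actual test fixtures:

import torch


def unbatch(batch, *, data_dims):
    # `batch` is either a plain tensor or a (tensor, *metadata) tuple, mirroring
    # kernels that now also return metadata such as the output image size.
    if isinstance(batch, torch.Tensor):
        batched_tensor, metadata = batch, ()
    else:
        batched_tensor, *metadata = batch

    if batched_tensor.ndim == data_dims:
        return batch

    return [
        unbatch(unbatched, data_dims=data_dims)
        for unbatched in (
            batched_tensor.unbind(0)
            if not metadata
            else [(t, *metadata) for t in batched_tensor.unbind(0)]
        )
    ]


# A batch of two XYXY boxes plus a shared image size unbinds into two
# (box, image_size) pairs; here the last dimension is the data dimension.
boxes = torch.tensor([[0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]])
samples = unbatch((boxes, (4, 4)), data_dims=1)
assert len(samples) == 2 and samples[0][1] == (4, 4)
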
@@ -83,23 +83,19 @@ class BoundingBox(_Feature):
max_size: Optional[int] = None,
antialias: bool = False,
) -> BoundingBox:
output = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
if isinstance(size, int):
size = [size]
image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
output, image_size = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
return BoundingBox.new_like(self, output, image_size=image_size)
def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
output = self._F.crop_bounding_box(self, self.format, top, left)
return BoundingBox.new_like(self, output, image_size=(height, width))
output, image_size = self._F.crop_bounding_box(
self, self.format, top=top, left=left, height=height, width=width
)
return BoundingBox.new_like(self, output, image_size=image_size)
def center_crop(self, output_size: List[int]) -> BoundingBox:
output = self._F.center_crop_bounding_box(
output, image_size = self._F.center_crop_bounding_box(
self, format=self.format, image_size=self.image_size, output_size=output_size
)
if isinstance(output_size, int):
output_size = [output_size]
image_size = (output_size[0], output_size[0]) if len(output_size) == 1 else (output_size[0], output_size[1])
return BoundingBox.new_like(self, output, image_size=image_size)
def resized_crop(
@@ -112,9 +108,8 @@ class BoundingBox(_Feature):
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias: bool = False,
) -> BoundingBox:
output = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
output, image_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
return BoundingBox.new_like(self, output, image_size=image_size)
def pad(
self,
@@ -122,19 +117,10 @@ class BoundingBox(_Feature):
fill: FillTypeJIT = None,
padding_mode: str = "constant",
) -> BoundingBox:
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
output = self._F.pad_bounding_box(self, format=self.format, padding=padding, padding_mode=padding_mode)
# Update output image size:
left, right, top, bottom = self._F._geometry._parse_pad_padding(padding)
height, width = self.image_size
height += top + bottom
width += left + right
return BoundingBox.new_like(self, output, image_size=(height, width))
output, image_size = self._F.pad_bounding_box(
self, format=self.format, image_size=self.image_size, padding=padding, padding_mode=padding_mode
)
return BoundingBox.new_like(self, output, image_size=image_size)
def rotate(
self,
@@ -144,23 +130,10 @@ class BoundingBox(_Feature):
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
output = self._F.rotate_bounding_box(
output, image_size = self._F.rotate_bounding_box(
self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
)
image_size = self.image_size
if expand:
# The way we recompute image_size is not optimal due to redundant computations of
# - rotation matrix (_get_inverse_affine_matrix)
# - points dot matrix (_compute_affine_output_size)
# Alternatively, we could return new image size by self._F.rotate_bounding_box
height, width = image_size
rotation_matrix = self._F._geometry._get_inverse_affine_matrix(
[0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0]
)
new_width, new_height = self._F._geometry._FT._compute_affine_output_size(rotation_matrix, width, height)
image_size = (new_height, new_width)
return BoundingBox.new_like(self, output, dtype=output.dtype, image_size=image_size)
return BoundingBox.new_like(self, output, image_size=image_size)
def affine(
self,
@@ -7,7 +7,7 @@ import torch
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import (
_compute_resized_output_size,
_compute_resized_output_size as __compute_resized_output_size,
_get_inverse_affine_matrix,
InterpolationMode,
pil_modes_mapping,
@@ -95,6 +95,14 @@ hflip = horizontal_flip
vflip = vertical_flip
def _compute_resized_output_size(
image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> List[int]:
if isinstance(size, int):
size = [size]
return __compute_resized_output_size(image_size, size=size, max_size=max_size)
def resize_image_tensor(
image: torch.Tensor,
size: List[int],
@@ -102,8 +110,6 @@ def resize_image_tensor(
max_size: Optional[int] = None,
antialias: bool = False,
) -> torch.Tensor:
if isinstance(size, int):
size = [size]
num_channels, old_height, old_width = get_dimensions_image_tensor(image)
new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
extra_dims = image.shape[:-3]
@@ -148,11 +154,7 @@ def resize_image_pil(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
) -> PIL.Image.Image:
if isinstance(size, int):
size = [size, size]
# Explicitly cast size to list otherwise mypy issue: incompatible type "Sequence[int]"; expected "List[int]"
size: List[int] = list(size)
size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size)
size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size) # type: ignore[arg-type]
return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation])
@@ -173,13 +175,14 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
def resize_bounding_box(
bounding_box: torch.Tensor, image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
) -> torch.Tensor:
if isinstance(size, int):
size = [size]
) -> Tuple[torch.Tensor, Tuple[int, int]]:
old_height, old_width = image_size
new_height, new_width = _compute_resized_output_size(image_size, size=size, max_size=max_size)
ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
return bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape)
return (
bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape),
(new_height, new_width),
)
def resize(
@@ -545,7 +548,7 @@ def rotate_bounding_box(
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Tuple[int, int]]:
if center is not None and expand:
warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None
@@ -566,9 +569,20 @@ def rotate_bounding_box(
expand=expand,
)
return convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape)
if expand:
# TODO: Move this computation inside of `_affine_bounding_box_xyxy` to avoid computing the rotation and points
# matrix twice
height, width = image_size
rotation_matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0])
new_width, new_height = _FT._compute_affine_output_size(rotation_matrix, width, height)
image_size = (new_height, new_width)
return (
convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape),
image_size,
)
def rotate_mask(
@@ -710,14 +724,15 @@ def pad_mask(
def pad_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
image_size: Tuple[int, int],
padding: Union[int, List[int]],
padding_mode: str = "constant",
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Tuple[int, int]]:
if padding_mode not in ["constant"]:
# TODO: add support of other padding modes
raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")
left, _, top, _ = _parse_pad_padding(padding)
left, right, top, bottom = _parse_pad_padding(padding)
bounding_box = bounding_box.clone()
@@ -727,7 +742,12 @@ def pad_bounding_box(
if format == features.BoundingBoxFormat.XYXY:
bounding_box[..., 2] += left
bounding_box[..., 3] += top
return bounding_box
height, width = image_size
height += top + bottom
width += left + right
return bounding_box, (height, width)
def pad(
@@ -754,7 +774,9 @@ def crop_bounding_box(
format: features.BoundingBoxFormat,
top: int,
left: int,
) -> torch.Tensor:
height: int,
width: int,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
bounding_box = convert_format_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
)
@@ -763,8 +785,11 @@ def crop_bounding_box(
bounding_box[..., 0::2] -= left
bounding_box[..., 1::2] -= top
return convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
return (
convert_format_bounding_box(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
),
(height, width),
)
@@ -1081,10 +1106,10 @@ def center_crop_bounding_box(
format: features.BoundingBoxFormat,
image_size: Tuple[int, int],
output_size: List[int],
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Tuple[int, int]]:
crop_height, crop_width = _center_crop_parse_output_size(output_size)
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size)
return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
@@ -1147,8 +1172,8 @@ def resized_crop_bounding_box(
height: int,
width: int,
size: List[int],
) -> torch.Tensor:
bounding_box = crop_bounding_box(bounding_box, format, top, left)
) -> Tuple[torch.Tensor, Tuple[int, int]]:
bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
return resize_bounding_box(bounding_box, (height, width), size)
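
Taken together, callers of these bounding box kernels now unpack a (boxes, image_size) tuple. A usage sketch against the prototype API as changed in this PR; the functional import path is an assumption, since it is not shown in the diff, and the prototype namespace was experimental at the time:

import torch
from torchvision.prototype import features
# Assumption: the prototype functional namespace is imported as `F`, as in the tests.
from torchvision.prototype.transforms import functional as F

boxes = features.BoundingBox(
    torch.tensor([[10.0, 15.0, 25.0, 35.0]]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 38),
)

# The kernel returns the cropped boxes together with the new image size, which is
# simply the requested crop height and width.
output_boxes, output_image_size = F.crop_bounding_box(
    boxes, boxes.format, top=0, left=5, height=30, width=30
)
assert output_image_size == (30, 30)
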