Unverified commit ae83c9fd, authored by Philip Meier, committed by GitHub

[PoC] move metadata computation from prototype features into kernels (#6646)

* move metadata computation from prototype features into kernels

* fix tests

* fix no_inplace test

* mypy

* add perf TODO
parent 2907c494
@@ -709,7 +709,7 @@ def sample_inputs_crop_bounding_box():
     for bounding_box_loader, params in itertools.product(
         make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
     ):
-        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, top=params["top"], left=params["left"])
+        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)


 def sample_inputs_crop_mask():
@@ -856,7 +856,9 @@ def sample_inputs_pad_bounding_box():
         if params["padding_mode"] != "constant":
             continue

-        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)
+        yield ArgsKwargs(
+            bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size, **params
+        )


 def sample_inputs_pad_mask():
@@ -1552,8 +1554,6 @@ KERNEL_INFOS.extend(
         skips=[
             skip_integer_size_jit(),
             Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."),
-            Skip("test_no_inplace", reason="Output of five_crop_image_tensor is not a tensor."),
-            Skip("test_dtype_and_device_consistency", reason="Output of five_crop_image_tensor is not a tensor."),
         ],
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
     ),
@@ -1565,8 +1565,6 @@ KERNEL_INFOS.extend(
         skips=[
             skip_integer_size_jit(),
             Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."),
-            Skip("test_no_inplace", reason="Output of ten_crop_image_tensor is not a tensor."),
-            Skip("test_dtype_and_device_consistency", reason="Output of ten_crop_image_tensor is not a tensor."),
         ],
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
     ),
...
@@ -68,17 +68,22 @@ class TestKernels:
         assert_close(actual, expected, **info.closeness_kwargs)

-    def _unbind_batch_dims(self, batched_tensor, *, data_dims):
-        if batched_tensor.ndim == data_dims:
-            return batched_tensor
-
-        return [self._unbind_batch_dims(t, data_dims=data_dims) for t in batched_tensor.unbind(0)]
-
-    def _stack_batch_dims(self, unbound_tensor):
-        if isinstance(unbound_tensor[0], torch.Tensor):
-            return torch.stack(unbound_tensor)
-
-        return torch.stack([self._stack_batch_dims(t) for t in unbound_tensor])
+    def _unbatch(self, batch, *, data_dims):
+        if isinstance(batch, torch.Tensor):
+            batched_tensor = batch
+            metadata = ()
+        else:
+            batched_tensor, *metadata = batch
+
+        if batched_tensor.ndim == data_dims:
+            return batch
+
+        return [
+            self._unbatch(unbatched, data_dims=data_dims)
+            for unbatched in (
+                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+            )
+        ]

     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
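For readers outside the test suite, a minimal standalone sketch of the unbatching idea in the hunk above (plain torch, hypothetical helper name `unbatch`, not the test-class method): a kernel output is either a bare tensor or a `(tensor, *metadata)` tuple, and the shared metadata is re-attached to every unbound element while recursing.

import torch

def unbatch(batch, *, data_dims):
    # A kernel output is either a tensor or a (tensor, *metadata) tuple.
    if isinstance(batch, torch.Tensor):
        tensor, metadata = batch, ()
    else:
        tensor, *metadata = batch

    # Once only the per-sample dimensions remain, stop recursing.
    if tensor.ndim == data_dims:
        return batch

    # Unbind one batch dimension, re-attaching the shared metadata to every element.
    elements = tensor.unbind(0) if not metadata else [(t, *metadata) for t in tensor.unbind(0)]
    return [unbatch(element, data_dims=data_dims) for element in elements]

# Example: a batch of 2x4 boxes plus an image size, unbatched into per-box entries.
boxes = torch.arange(8, dtype=torch.float32).reshape(2, 4)
print(unbatch((boxes, (32, 38)), data_dims=1))
# -> [(tensor([0., 1., 2., 3.]), (32, 38)), (tensor([4., 5., 6., 7.]), (32, 38))]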
@@ -106,11 +111,11 @@ class TestKernels:
         elif not all(batched_input.shape[:-data_dims]):
             pytest.skip("Input has a degenerate batch shape.")

-        actual = info.kernel(batched_input, *other_args, **kwargs)
+        batched_output = info.kernel(batched_input, *other_args, **kwargs)
+        actual = self._unbatch(batched_output, data_dims=data_dims)

-        single_inputs = self._unbind_batch_dims(batched_input, data_dims=data_dims)
-        single_outputs = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)
-        expected = self._stack_batch_dims(single_outputs)
+        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
+        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

         assert_close(actual, expected, **info.closeness_kwargs)
@@ -123,9 +128,9 @@ class TestKernels:
             pytest.skip("The input has a degenerate shape.")

         input_version = input._version
-        output = info.kernel(input, *other_args, **kwargs)
+        info.kernel(input, *other_args, **kwargs)

-        assert output is not input or output._version == input_version
+        assert input._version == input_version

     @sample_inputs
     @needs_cuda
@@ -144,6 +149,9 @@ class TestKernels:
         (input, *other_args), kwargs = args_kwargs.load(device)

         output = info.kernel(input, *other_args, **kwargs)
+        # Most kernels just return a tensor, but some also return some additional metadata
+        if not isinstance(output, torch.Tensor):
+            output, *_ = output

         assert output.dtype == input.dtype
         assert output.device == input.device
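A standalone sketch of the normalization the consistency test now performs (hypothetical helper name `primary_tensor`): when a kernel may return either a tensor or a `(tensor, *metadata)` tuple, only the leading tensor matters for dtype/device checks.

import torch

def primary_tensor(output):
    # Kernels touched by this PR return either a tensor or a (tensor, *metadata)
    # tuple; dtype/device checks only look at the leading tensor.
    if isinstance(output, torch.Tensor):
        return output
    first, *_ = output
    return first

# Example with a (boxes, image_size) style output.
boxes = torch.zeros(3, 4, dtype=torch.float64)
assert primary_tensor((boxes, (32, 38))).dtype == torch.float64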
@@ -324,7 +332,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
         affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
         affine_matrix = affine_matrix[:2, :]

-        image_size = bbox.image_size
+        height, width = bbox.image_size

         bbox_xyxy = convert_format_bounding_box(
             bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
         )
@@ -336,9 +344,9 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
                 [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
                 # image frame
                 [0.0, 0.0, 1.0],
-                [0.0, image_size[0], 1.0],
-                [image_size[1], image_size[0], 1.0],
-                [image_size[1], 0.0, 1.0],
+                [0.0, height, 1.0],
+                [width, height, 1.0],
+                [width, 0.0, 1.0],
             ]
         )
         transformed_points = np.matmul(points, affine_matrix.T)
@@ -356,18 +364,21 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
            out_bbox[2] -= tr_x
            out_bbox[3] -= tr_y

-            # image_size should be updated, but it is OK here to skip its computation
-            # as we do not compute it in F.rotate_bounding_box
+            height = int(height - 2 * tr_y)
+            width = int(width - 2 * tr_x)

        out_bbox = features.BoundingBox(
            out_bbox,
            format=features.BoundingBoxFormat.XYXY,
-            image_size=image_size,
+            image_size=(height, width),
            dtype=bbox.dtype,
            device=bbox.device,
        )
-        return convert_format_bounding_box(
-            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+        return (
+            convert_format_bounding_box(
+                out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+            ),
+            (height, width),
        )

    image_size = (32, 38)
@@ -376,7 +387,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
    bboxes_format = bboxes.format
    bboxes_image_size = bboxes.image_size

-    output_bboxes = F.rotate_bounding_box(
+    output_bboxes, output_image_size = F.rotate_bounding_box(
        bboxes,
        bboxes_format,
        image_size=bboxes_image_size,
@@ -395,12 +406,14 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
    expected_bboxes = []
    for bbox in bboxes:
        bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
-        expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_))
+        expected_bbox, expected_image_size = _compute_expected_bbox(bbox, -angle, expand, center_)
+        expected_bboxes.append(expected_bbox)

    if len(expected_bboxes) > 1:
        expected_bboxes = torch.stack(expected_bboxes)
    else:
        expected_bboxes = expected_bboxes[0]
    torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0)
+    torch.testing.assert_close(output_image_size, expected_image_size, atol=1, rtol=0)

 @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -445,7 +458,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
        [18.36396103, 1.07968978, 46.64823228, 29.36396103],
    ]

-    output_boxes = F.rotate_bounding_box(
+    output_boxes, _ = F.rotate_bounding_box(
        in_boxes,
        in_boxes.format,
        in_boxes.image_size,
@@ -510,17 +523,20 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
    if format != features.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)

-    output_boxes = F.crop_bounding_box(
+    output_boxes, output_image_size = F.crop_bounding_box(
        in_boxes,
        format,
        top,
        left,
+        size[0],
+        size[1],
    )

    if format != features.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
+    torch.testing.assert_close(output_image_size, size)

 @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -585,12 +601,13 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
    if format != features.BoundingBoxFormat.XYXY:
        in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)

-    output_boxes = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
+    output_boxes, output_image_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)

    if format != features.BoundingBoxFormat.XYXY:
        output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)

    torch.testing.assert_close(output_boxes, expected_bboxes)
+    torch.testing.assert_close(output_image_size, size)

 def _parse_padding(padding):
@@ -627,12 +644,21 @@ def test_correctness_pad_bounding_box(device, padding):
            bbox = bbox.to(bbox_dtype)
        return bbox

+    def _compute_expected_image_size(bbox, padding_):
+        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
+        height, width = bbox.image_size
+        return height + pad_up + pad_down, width + pad_left + pad_right
+
    for bboxes in make_bounding_boxes():
        bboxes = bboxes.to(device)
        bboxes_format = bboxes.format
        bboxes_image_size = bboxes.image_size

-        output_boxes = F.pad_bounding_box(bboxes, format=bboxes_format, padding=padding)
+        output_boxes, output_image_size = F.pad_bounding_box(
+            bboxes, format=bboxes_format, image_size=bboxes_image_size, padding=padding
+        )
+
+        torch.testing.assert_close(output_image_size, _compute_expected_image_size(bboxes, padding))

        if bboxes.ndim < 2 or bboxes.shape[0] == 0:
            bboxes = [bboxes]
@@ -781,7 +807,9 @@ def test_correctness_center_crop_bounding_box(device, output_size):
        bboxes_format = bboxes.format
        bboxes_image_size = bboxes.image_size

-        output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, bboxes_image_size, output_size)
+        output_boxes, output_image_size = F.center_crop_bounding_box(
+            bboxes, bboxes_format, bboxes_image_size, output_size
+        )

        if bboxes.ndim < 2:
            bboxes = [bboxes]
@@ -796,6 +824,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
        else:
            expected_bboxes = expected_bboxes[0]
        torch.testing.assert_close(output_boxes, expected_bboxes)
+        torch.testing.assert_close(output_image_size, output_size)

 @pytest.mark.parametrize("device", cpu_and_gpu())
...
@@ -83,23 +83,19 @@ class BoundingBox(_Feature):
        max_size: Optional[int] = None,
        antialias: bool = False,
    ) -> BoundingBox:
-        output = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
-        if isinstance(size, int):
-            size = [size]
-        image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
-        return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
+        output, image_size = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
+        return BoundingBox.new_like(self, output, image_size=image_size)

    def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
-        output = self._F.crop_bounding_box(self, self.format, top, left)
-        return BoundingBox.new_like(self, output, image_size=(height, width))
+        output, image_size = self._F.crop_bounding_box(
+            self, self.format, top=top, left=left, height=height, width=width
+        )
+        return BoundingBox.new_like(self, output, image_size=image_size)

    def center_crop(self, output_size: List[int]) -> BoundingBox:
-        output = self._F.center_crop_bounding_box(
+        output, image_size = self._F.center_crop_bounding_box(
            self, format=self.format, image_size=self.image_size, output_size=output_size
        )
-        if isinstance(output_size, int):
-            output_size = [output_size]
-        image_size = (output_size[0], output_size[0]) if len(output_size) == 1 else (output_size[0], output_size[1])
        return BoundingBox.new_like(self, output, image_size=image_size)

    def resized_crop(
@@ -112,9 +108,8 @@ class BoundingBox(_Feature):
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        antialias: bool = False,
    ) -> BoundingBox:
-        output = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
-        image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
-        return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
+        output, image_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
+        return BoundingBox.new_like(self, output, image_size=image_size)

    def pad(
        self,
@@ -122,19 +117,10 @@ class BoundingBox(_Feature):
        fill: FillTypeJIT = None,
        padding_mode: str = "constant",
    ) -> BoundingBox:
-        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
-        if not isinstance(padding, int):
-            padding = list(padding)
-
-        output = self._F.pad_bounding_box(self, format=self.format, padding=padding, padding_mode=padding_mode)
-
-        # Update output image size:
-        left, right, top, bottom = self._F._geometry._parse_pad_padding(padding)
-        height, width = self.image_size
-        height += top + bottom
-        width += left + right
-
-        return BoundingBox.new_like(self, output, image_size=(height, width))
+        output, image_size = self._F.pad_bounding_box(
+            self, format=self.format, image_size=self.image_size, padding=padding, padding_mode=padding_mode
+        )
+        return BoundingBox.new_like(self, output, image_size=image_size)

    def rotate(
        self,
@@ -144,23 +130,10 @@ class BoundingBox(_Feature):
        fill: FillTypeJIT = None,
        center: Optional[List[float]] = None,
    ) -> BoundingBox:
-        output = self._F.rotate_bounding_box(
+        output, image_size = self._F.rotate_bounding_box(
            self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
        )
-        image_size = self.image_size
-        if expand:
-            # The way we recompute image_size is not optimal due to redundant computations of
-            # - rotation matrix (_get_inverse_affine_matrix)
-            # - points dot matrix (_compute_affine_output_size)
-            # Alternatively, we could return new image size by self._F.rotate_bounding_box
-            height, width = image_size
-            rotation_matrix = self._F._geometry._get_inverse_affine_matrix(
-                [0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0]
-            )
-            new_width, new_height = self._F._geometry._FT._compute_affine_output_size(rotation_matrix, width, height)
-            image_size = (new_height, new_width)
-        return BoundingBox.new_like(self, output, dtype=output.dtype, image_size=image_size)
+        return BoundingBox.new_like(self, output, image_size=image_size)

    def affine(
        self,
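The hunks above all follow one pattern: the kernel now returns `(data, image_size)` and the feature wrapper simply re-wraps the pair instead of recomputing the metadata itself. A minimal standalone sketch of that pattern, with hypothetical stand-ins (`pad_bounding_box_kernel`, `Boxes`) rather than the real torchvision classes:

from typing import Tuple
import torch

def pad_bounding_box_kernel(
    boxes: torch.Tensor, image_size: Tuple[int, int], left: int, top: int, right: int, bottom: int
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    # Kernel side: transform the data and compute the new metadata in one place.
    boxes = boxes.clone()
    boxes[..., 0::2] += left  # shift x coordinates (XYXY layout assumed)
    boxes[..., 1::2] += top   # shift y coordinates
    height, width = image_size
    return boxes, (height + top + bottom, width + left + right)

class Boxes:
    def __init__(self, data: torch.Tensor, image_size: Tuple[int, int]):
        self.data, self.image_size = data, image_size

    def pad(self, left: int, top: int, right: int, bottom: int) -> "Boxes":
        # Feature side: no metadata recomputation, just unpack and re-wrap.
        data, image_size = pad_bounding_box_kernel(self.data, self.image_size, left, top, right, bottom)
        return Boxes(data, image_size)

boxes = Boxes(torch.tensor([[1.0, 2.0, 5.0, 6.0]]), image_size=(32, 38))
padded = boxes.pad(left=2, top=3, right=4, bottom=5)
print(padded.data, padded.image_size)  # tensor([[3., 5., 7., 9.]]) (40, 44)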
...
@@ -7,7 +7,7 @@ import torch
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 from torchvision.transforms.functional import (
-    _compute_resized_output_size,
+    _compute_resized_output_size as __compute_resized_output_size,
    _get_inverse_affine_matrix,
    InterpolationMode,
    pil_modes_mapping,
@@ -95,6 +95,14 @@ hflip = horizontal_flip
 vflip = vertical_flip


+def _compute_resized_output_size(
+    image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
+) -> List[int]:
+    if isinstance(size, int):
+        size = [size]
+    return __compute_resized_output_size(image_size, size=size, max_size=max_size)
+
+
 def resize_image_tensor(
    image: torch.Tensor,
    size: List[int],
@@ -102,8 +110,6 @@ def resize_image_tensor(
    max_size: Optional[int] = None,
    antialias: bool = False,
 ) -> torch.Tensor:
-    if isinstance(size, int):
-        size = [size]
    num_channels, old_height, old_width = get_dimensions_image_tensor(image)
    new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
    extra_dims = image.shape[:-3]
@@ -148,11 +154,7 @@ def resize_image_pil(
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    max_size: Optional[int] = None,
 ) -> PIL.Image.Image:
-    if isinstance(size, int):
-        size = [size, size]
-    # Explicitly cast size to list otherwise mypy issue: incompatible type "Sequence[int]"; expected "List[int]"
-    size: List[int] = list(size)
-    size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size)
+    size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size)  # type: ignore[arg-type]
    return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation])

@@ -173,13 +175,14 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N
 def resize_bounding_box(
    bounding_box: torch.Tensor, image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
-) -> torch.Tensor:
-    if isinstance(size, int):
-        size = [size]
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
    old_height, old_width = image_size
    new_height, new_width = _compute_resized_output_size(image_size, size=size, max_size=max_size)
    ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
-    return bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape)
+    return (
+        bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape),
+        (new_height, new_width),
+    )

 def resize(
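A standalone re-enactment of the scaling math in `resize_bounding_box` above (hypothetical helper name `resize_boxes`, plain torch, not the torchvision API): boxes are viewed as pairs of (x, y) points and scaled by the per-axis ratio, and the new canvas size is returned alongside the data.

import torch

def resize_boxes(boxes: torch.Tensor, image_size, new_size):
    old_height, old_width = image_size
    new_height, new_width = new_size
    # Per-point scale is (width ratio, height ratio) because points are (x, y).
    ratios = torch.tensor((new_width / old_width, new_height / old_height))
    return boxes.view(-1, 2, 2).mul(ratios).to(boxes.dtype).view(boxes.shape), (new_height, new_width)

boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # XYXY on a 100x200 (h x w) image
scaled, image_size = resize_boxes(boxes, image_size=(100, 200), new_size=(50, 100))
print(scaled, image_size)  # tensor([[ 5., 10., 15., 20.]]) (50, 100)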
@@ -545,7 +548,7 @@ def rotate_bounding_box(
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if center is not None and expand:
        warnings.warn("The provided center argument has no effect on the result if expand is True")
        center = None
@@ -566,9 +569,20 @@ def rotate_bounding_box(
        expand=expand,
    )

-    return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-    ).view(original_shape)
+    if expand:
+        # TODO: Move this computation inside of `_affine_bounding_box_xyxy` to avoid computing the rotation and points
+        # matrix twice
+        height, width = image_size
+        rotation_matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0])
+        new_width, new_height = _FT._compute_affine_output_size(rotation_matrix, width, height)
+        image_size = (new_height, new_width)
+
+    return (
+        convert_format_bounding_box(
+            out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        ).view(original_shape),
+        image_size,
+    )

 def rotate_mask(
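For intuition, a standalone sketch of what the `expand=True` branch has to produce: the size of the smallest canvas that contains the rotated image frame. This uses the corner-point approach the test above uses (hypothetical helper, stdlib math), not torchvision's `_compute_affine_output_size`.

import math

def expanded_canvas_size(height: int, width: int, angle_degrees: float):
    # Rotate the four corners of the image frame about its center and take the
    # extent of the result; that extent is the expanded canvas size.
    theta = math.radians(angle_degrees)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    cx, cy = width * 0.5, height * 0.5
    corners = [(0.0, 0.0), (width, 0.0), (width, height), (0.0, height)]
    xs, ys = [], []
    for x, y in corners:
        dx, dy = x - cx, y - cy
        xs.append(cos_t * dx - sin_t * dy)
        ys.append(sin_t * dx + cos_t * dy)
    new_width = int(round(max(xs) - min(xs)))
    new_height = int(round(max(ys) - min(ys)))
    return new_height, new_width

print(expanded_canvas_size(32, 38, 90))   # (38, 32): a 90 degree rotation swaps the sides
print(expanded_canvas_size(100, 100, 45))  # roughly (141, 141)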
@@ -710,14 +724,15 @@ def pad_mask(
 def pad_bounding_box(
    bounding_box: torch.Tensor,
    format: features.BoundingBoxFormat,
+    image_size: Tuple[int, int],
    padding: Union[int, List[int]],
    padding_mode: str = "constant",
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
    if padding_mode not in ["constant"]:
        # TODO: add support of other padding modes
        raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes")

-    left, _, top, _ = _parse_pad_padding(padding)
+    left, right, top, bottom = _parse_pad_padding(padding)

    bounding_box = bounding_box.clone()
@@ -727,7 +742,12 @@ def pad_bounding_box(
    if format == features.BoundingBoxFormat.XYXY:
        bounding_box[..., 2] += left
        bounding_box[..., 3] += top
-    return bounding_box
+
+    height, width = image_size
+    height += top + bottom
+    width += left + right
+
+    return bounding_box, (height, width)

 def pad(
@@ -754,7 +774,9 @@ def crop_bounding_box(
    format: features.BoundingBoxFormat,
    top: int,
    left: int,
-) -> torch.Tensor:
+    height: int,
+    width: int,
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
    bounding_box = convert_format_bounding_box(
        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
    )
@@ -763,8 +785,11 @@ def crop_bounding_box(
    bounding_box[..., 0::2] -= left
    bounding_box[..., 1::2] -= top

-    return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+    return (
+        convert_format_bounding_box(
+            bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        ),
+        (height, width),
    )
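A standalone sketch of the crop kernel's coordinate and metadata logic above (hypothetical helper `crop_boxes`; boxes assumed to be XYXY already): coordinates are shifted into the crop window, and the window size becomes the new image size.

import torch

def crop_boxes(boxes: torch.Tensor, top: int, left: int, height: int, width: int):
    boxes = boxes.clone()
    boxes[..., 0::2] -= left  # shift x coordinates into the crop window
    boxes[..., 1::2] -= top   # shift y coordinates into the crop window
    # The crop window itself becomes the new canvas, so its size is the new image_size.
    return boxes, (height, width)

boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
cropped, image_size = crop_boxes(boxes, top=5, left=8, height=50, width=60)
print(cropped, image_size)  # tensor([[ 2., 15., 22., 35.]]) (50, 60)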
@@ -1081,10 +1106,10 @@ def center_crop_bounding_box(
    format: features.BoundingBoxFormat,
    image_size: Tuple[int, int],
    output_size: List[int],
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
    crop_height, crop_width = _center_crop_parse_output_size(output_size)
    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size)
-    return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
+    return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width)

 def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
@@ -1147,8 +1172,8 @@ def resized_crop_bounding_box(
    height: int,
    width: int,
    size: List[int],
-) -> torch.Tensor:
-    bounding_box = crop_bounding_box(bounding_box, format, top, left)
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
    return resize_bounding_box(bounding_box, (height, width), size)
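A standalone sketch of the composition in `resized_crop_bounding_box` above (hypothetical helper `resized_crop_boxes`, plain torch): the crop step's intermediate `(height, width)` is discarded, and the resize step reports the final canvas size.

import torch

def resized_crop_boxes(boxes: torch.Tensor, top: int, left: int, height: int, width: int, size):
    # Crop: shift coordinates into the crop window (intermediate metadata not needed).
    boxes = boxes.clone()
    boxes[..., 0::2] -= left
    boxes[..., 1::2] -= top
    # Resize: scale from the crop window to the requested size and return that size.
    new_height, new_width = size
    ratios = torch.tensor((new_width / width, new_height / height))
    boxes = boxes.view(-1, 2, 2).mul(ratios).to(boxes.dtype).view(boxes.shape)
    return boxes, (new_height, new_width)

boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
out, image_size = resized_crop_boxes(boxes, top=5, left=8, height=50, width=60, size=(100, 120))
print(out, image_size)  # tensor([[ 4., 30., 44., 70.]]) (100, 120)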
...