Unverified Commit 4d085f2e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

remove unnecessary checks from pad_image_tensor (#6894)

* remove unnecessary changes from pad_image_tensor

* cleanup

* fix fill=None workaround

* address review comments

* remove more xfails
parent f1b840d5
...@@ -234,7 +234,6 @@ DISPATCHER_INFOS = [ ...@@ -234,7 +234,6 @@ DISPATCHER_INFOS = [
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs) condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant", and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
), ),
xfail_jit_python_scalar_arg("padding"),
xfail_jit_tuple_instead_of_list("padding"), xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"), xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok # TODO: check if this is a regression since it seems that should be supported if `int` is ok
......
...@@ -1146,7 +1146,6 @@ KERNEL_INFOS.extend( ...@@ -1146,7 +1146,6 @@ KERNEL_INFOS.extend(
reference_inputs_fn=reference_inputs_pad_image_tensor, reference_inputs_fn=reference_inputs_pad_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("padding"),
xfail_jit_tuple_instead_of_list("padding"), xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"), xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok # TODO: check if this is a regression since it seems that should be supported if `int` is ok
...@@ -1159,7 +1158,6 @@ KERNEL_INFOS.extend( ...@@ -1159,7 +1158,6 @@ KERNEL_INFOS.extend(
reference_fn=reference_pad_bounding_box, reference_fn=reference_pad_bounding_box,
reference_inputs_fn=reference_inputs_pad_bounding_box, reference_inputs_fn=reference_inputs_pad_bounding_box,
test_marks=[ test_marks=[
xfail_jit_python_scalar_arg("padding"),
xfail_jit_tuple_instead_of_list("padding"), xfail_jit_tuple_instead_of_list("padding"),
], ],
), ),
......
...@@ -4,7 +4,8 @@ from typing import List, Optional, Sequence, Tuple, Union ...@@ -4,7 +4,8 @@ from typing import List, Optional, Sequence, Tuple, Union
import PIL.Image import PIL.Image
import torch import torch
from torch.nn.functional import interpolate from torch.nn.functional import interpolate, pad as torch_pad
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import ( from torchvision.transforms.functional import (
...@@ -15,7 +16,6 @@ from torchvision.transforms.functional import ( ...@@ -15,7 +16,6 @@ from torchvision.transforms.functional import (
pil_to_tensor, pil_to_tensor,
to_pil_image, to_pil_image,
) )
from torchvision.transforms.functional_tensor import _parse_pad_padding
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
...@@ -663,7 +663,28 @@ def rotate( ...@@ -663,7 +663,28 @@ def rotate(
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
pad_image_pil = _FP.pad def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif isinstance(padding, (tuple, list)):
if len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
elif len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
else:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
else:
raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")
return [pad_left, pad_right, pad_top, pad_bottom]
def pad_image_tensor( def pad_image_tensor(
...@@ -672,50 +693,86 @@ def pad_image_tensor( ...@@ -672,50 +693,86 @@ def pad_image_tensor(
fill: features.FillTypeJIT = None, fill: features.FillTypeJIT = None,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> torch.Tensor: ) -> torch.Tensor:
# Be aware that while `padding` has order `[left, top, right, bottom]`, `torch_padding` uses
# `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad`
# internally.
torch_padding = _parse_pad_padding(padding)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError(
f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
f"but got `'{padding_mode}'`."
)
if fill is None: if fill is None:
# This is a JIT workaround fill = 0
return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode)
elif isinstance(fill, (int, float)) or len(fill) == 1: if isinstance(fill, (int, float)):
fill_number = fill[0] if isinstance(fill, list) else fill return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode) elif len(fill) == 1:
return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
else: else:
return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode) return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
def _pad_with_scalar_fill( def _pad_with_scalar_fill(
image: torch.Tensor, image: torch.Tensor,
padding: Union[int, List[int]], torch_padding: List[int],
fill: Union[int, float, None], fill: Union[int, float],
padding_mode: str = "constant", padding_mode: str,
) -> torch.Tensor: ) -> torch.Tensor:
shape = image.shape shape = image.shape
num_channels, height, width = shape[-3:] num_channels, height, width = shape[-3:]
if image.numel() > 0: if image.numel() > 0:
image = _FT.pad( image = image.reshape(-1, num_channels, height, width)
img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
) if padding_mode == "edge":
# Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map
# the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
# name.
padding_mode = "replicate"
if padding_mode == "constant":
image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
elif padding_mode in ("reflect", "replicate"):
# `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
# TODO: See https://github.com/pytorch/pytorch/issues/40763
dtype = image.dtype
if not image.is_floating_point():
needs_cast = True
image = image.to(torch.float32)
else:
needs_cast = False
image = torch_pad(image, torch_padding, mode=padding_mode)
if needs_cast:
image = image.to(dtype)
else: # padding_mode == "symmetric"
image = _FT._pad_symmetric(image, torch_padding)
new_height, new_width = image.shape[-2:] new_height, new_width = image.shape[-2:]
else: else:
left, right, top, bottom = _FT._parse_pad_padding(padding) left, right, top, bottom = torch_padding
new_height = height + top + bottom new_height = height + top + bottom
new_width = width + left + right new_width = width + left + right
return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
# TODO: This should be removed once pytorch pad supports non-scalar padding values # TODO: This should be removed once torch_pad supports non-scalar padding values
def _pad_with_vector_fill( def _pad_with_vector_fill(
image: torch.Tensor, image: torch.Tensor,
padding: Union[int, List[int]], torch_padding: List[int],
fill: List[float], fill: List[float],
padding_mode: str = "constant", padding_mode: str,
) -> torch.Tensor: ) -> torch.Tensor:
if padding_mode != "constant": if padding_mode != "constant":
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar") raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant") output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
left, right, top, bottom = _parse_pad_padding(padding) left, right, top, bottom = torch_padding
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1) fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)
if top > 0: if top > 0:
...@@ -729,6 +786,9 @@ def _pad_with_vector_fill( ...@@ -729,6 +786,9 @@ def _pad_with_vector_fill(
return output return output
pad_image_pil = _FP.pad
def pad_mask( def pad_mask(
mask: torch.Tensor, mask: torch.Tensor,
padding: Union[int, List[int]], padding: Union[int, List[int]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment