Unverified Commit 4d085f2e authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

remove unnecessary checks from pad_image_tensor (#6894)

* remove unnecessary changes from pad_image_tensor

* cleanup

* fix fill=None workaround

* address review comments

* remove more xfails
parent f1b840d5
......@@ -234,7 +234,6 @@ DISPATCHER_INFOS = [
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
),
xfail_jit_python_scalar_arg("padding"),
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
......
......@@ -1146,7 +1146,6 @@ KERNEL_INFOS.extend(
reference_inputs_fn=reference_inputs_pad_image_tensor,
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
test_marks=[
xfail_jit_python_scalar_arg("padding"),
xfail_jit_tuple_instead_of_list("padding"),
xfail_jit_tuple_instead_of_list("fill"),
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
......@@ -1159,7 +1158,6 @@ KERNEL_INFOS.extend(
reference_fn=reference_pad_bounding_box,
reference_inputs_fn=reference_inputs_pad_bounding_box,
test_marks=[
xfail_jit_python_scalar_arg("padding"),
xfail_jit_tuple_instead_of_list("padding"),
],
),
......
......@@ -4,7 +4,8 @@ from typing import List, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
from torch.nn.functional import interpolate
from torch.nn.functional import interpolate, pad as torch_pad
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import (
......@@ -15,7 +16,6 @@ from torchvision.transforms.functional import (
pil_to_tensor,
to_pil_image,
)
from torchvision.transforms.functional_tensor import _parse_pad_padding
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
......@@ -663,7 +663,28 @@ def rotate(
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
pad_image_pil = _FP.pad
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
elif isinstance(padding, (tuple, list)):
if len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
elif len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]
else:
raise ValueError(
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
)
else:
raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")
return [pad_left, pad_right, pad_top, pad_bottom]
def pad_image_tensor(
......@@ -672,50 +693,86 @@ def pad_image_tensor(
fill: features.FillTypeJIT = None,
padding_mode: str = "constant",
) -> torch.Tensor:
# Be aware that while `padding` has order `[left, top, right, bottom]` has order, `torch_padding` uses
# `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad`
# internally.
torch_padding = _parse_pad_padding(padding)
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError(
f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
f"but got `'{padding_mode}'`."
)
if fill is None:
# This is a JIT workaround
return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode)
elif isinstance(fill, (int, float)) or len(fill) == 1:
fill_number = fill[0] if isinstance(fill, list) else fill
return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode)
fill = 0
if isinstance(fill, (int, float)):
return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
elif len(fill) == 1:
return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
else:
return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode)
return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
def _pad_with_scalar_fill(
image: torch.Tensor,
padding: Union[int, List[int]],
fill: Union[int, float, None],
padding_mode: str = "constant",
torch_padding: List[int],
fill: Union[int, float],
padding_mode: str,
) -> torch.Tensor:
shape = image.shape
num_channels, height, width = shape[-3:]
if image.numel() > 0:
image = _FT.pad(
img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
)
image = image.reshape(-1, num_channels, height, width)
if padding_mode == "edge":
# Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
# the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
# name.
padding_mode = "replicate"
if padding_mode == "constant":
image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
elif padding_mode in ("reflect", "replicate"):
# `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
# TODO: See https://github.com/pytorch/pytorch/issues/40763
dtype = image.dtype
if not image.is_floating_point():
needs_cast = True
image = image.to(torch.float32)
else:
needs_cast = False
image = torch_pad(image, torch_padding, mode=padding_mode)
if needs_cast:
image = image.to(dtype)
else: # padding_mode == "symmetric"
image = _FT._pad_symmetric(image, torch_padding)
new_height, new_width = image.shape[-2:]
else:
left, right, top, bottom = _FT._parse_pad_padding(padding)
left, right, top, bottom = torch_padding
new_height = height + top + bottom
new_width = width + left + right
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
# TODO: This should be removed once pytorch pad supports non-scalar padding values
# TODO: This should be removed once torch_pad supports non-scalar padding values
def _pad_with_vector_fill(
image: torch.Tensor,
padding: Union[int, List[int]],
torch_padding: List[int],
fill: List[float],
padding_mode: str = "constant",
padding_mode: str,
) -> torch.Tensor:
if padding_mode != "constant":
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
left, right, top, bottom = _parse_pad_padding(padding)
output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
left, right, top, bottom = torch_padding
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)
if top > 0:
......@@ -729,6 +786,9 @@ def _pad_with_vector_fill(
return output
pad_image_pil = _FP.pad
def pad_mask(
mask: torch.Tensor,
padding: Union[int, List[int]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment