Unverified Commit c5958862 authored by Vasilis Vryniotis, committed by GitHub

Fix bug on prototype `pad` (#6949)

parent deba0562
@@ -727,39 +727,38 @@ def _pad_with_scalar_fill(
     shape = image.shape
     num_channels, height, width = shape[-3:]
-    if image.numel() > 0:
-        image = image.reshape(-1, num_channels, height, width)
-
-        if padding_mode == "edge":
-            # Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus,
-            # we map the PIL name for the padding mode, which we are also using for our API, to the corresponding
-            # `torch_pad` name.
-            padding_mode = "replicate"
-
-        if padding_mode == "constant":
-            image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
-        elif padding_mode in ("reflect", "replicate"):
-            # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
-            # TODO: See https://github.com/pytorch/pytorch/issues/40763
-            dtype = image.dtype
-            if not image.is_floating_point():
-                needs_cast = True
-                image = image.to(torch.float32)
-            else:
-                needs_cast = False
-
-            image = torch_pad(image, torch_padding, mode=padding_mode)
-
-            if needs_cast:
-                image = image.to(dtype)
-        else:  # padding_mode == "symmetric"
-            image = _FT._pad_symmetric(image, torch_padding)
-
-        new_height, new_width = image.shape[-2:]
-    else:
-        left, right, top, bottom = torch_padding
-        new_height = height + top + bottom
-        new_width = width + left + right
+
+    batch_size = 1
+    for s in shape[:-3]:
+        batch_size *= s
+
+    image = image.reshape(batch_size, num_channels, height, width)
+
+    if padding_mode == "edge":
+        # Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we
+        # map the PIL name for the padding mode, which we are also using for our API, to the corresponding
+        # `torch_pad` name.
+        padding_mode = "replicate"
+
+    if padding_mode == "constant":
+        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
+    elif padding_mode in ("reflect", "replicate"):
+        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
+        # TODO: See https://github.com/pytorch/pytorch/issues/40763
+        dtype = image.dtype
+        if not image.is_floating_point():
+            needs_cast = True
+            image = image.to(torch.float32)
+        else:
+            needs_cast = False
+
+        image = torch_pad(image, torch_padding, mode=padding_mode)
+
+        if needs_cast:
+            image = image.to(dtype)
+    else:  # padding_mode == "symmetric"
+        image = _FT._pad_symmetric(image, torch_padding)
+
+    new_height, new_width = image.shape[-2:]
 
     return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
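
For context: `reshape(-1, ...)` cannot infer the batch dimension of a tensor with zero elements when the remaining dimensions also multiply to zero, which is what the removed `numel() > 0` guard worked around. The fix multiplies out the leading dimensions explicitly, so degenerate inputs take the same padding path as everything else. A minimal sketch of the failure and the workaround, with illustrative shapes that are not taken from the commit:

import torch
import torch.nn.functional as F

# A degenerate image with zero width, e.g. produced by cropping a region
# that lies entirely outside the input; numel() is 0 here.
image = torch.zeros(2, 3, 4, 0)

# The old flattening strategy fails: with 0 elements and a zero-sized
# trailing block, -1 cannot be resolved to a unique batch size.
try:
    image.reshape(-1, 3, 4, 0)
except RuntimeError as exc:
    print(exc)  # "... the unspecified dimension size -1 can be ambiguous"

# The fixed code computes the batch size from the leading dims instead,
# which is well defined even for empty tensors.
batch_size = 1
for s in image.shape[:-3]:
    batch_size *= s
flat = image.reshape(batch_size, 3, 4, 0)  # shape (2, 3, 4, 0)

# Constant padding of the empty tensor then yields a non-empty result
# filled with the pad value, so the final reshape succeeds.
padded = F.pad(flat, [1, 1, 0, 0], mode="constant", value=0.0)
print(padded.shape)  # torch.Size([2, 3, 4, 2])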
@@ -868,7 +867,24 @@ def pad(
     return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
 
 
-crop_image_tensor = _FT.crop
+def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
+    h, w = image.shape[-2:]
+
+    right = left + width
+    bottom = top + height
+
+    if left < 0 or top < 0 or right > w or bottom > h:
+        image = image[..., max(top, 0) : bottom, max(left, 0) : right]
+        torch_padding = [
+            max(min(right, 0) - left, 0),
+            max(right - max(w, left), 0),
+            max(min(bottom, 0) - top, 0),
+            max(bottom - max(h, top), 0),
+        ]
+        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
+
+    return image[..., top:bottom, left:right]
 
 
 crop_image_pil = _FP.crop
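
The new `crop_image_tensor` handles crop windows that extend past the image borders by slicing the in-bounds region and zero-padding the rest, instead of delegating to `_FT.crop`. A quick illustrative check of the padding arithmetic for a window hanging off the top-left corner; the numbers are hypothetical, not from the commit:

# 4x4 image, 4x4 crop window anchored at top=-2, left=-2.
h = w = 4
top = left = -2
height = width = 4
right, bottom = left + width, top + height

# Same expressions as in the diff above; order is [left, right, top, bottom]
# as expected by torch_pad.
torch_padding = [
    max(min(right, 0) - left, 0),  # 2: window extends 2 px past the left edge
    max(right - max(w, left), 0),  # 0: window ends inside the image
    max(min(bottom, 0) - top, 0),  # 2: window extends 2 px past the top edge
    max(bottom - max(h, top), 0),  # 0
]
print(torch_padding)  # [2, 0, 2, 0]

# The kept slice is image[..., 0:2, 0:2]; after padding, that top-left 2x2
# corner of the input lands at the bottom-right of the 4x4 output, and the
# out-of-bounds area is filled with zeros.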
@@ -893,7 +909,18 @@ def crop_bounding_box(
 
 
 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
-    return crop_image_tensor(mask, top, left, height, width)
+    if mask.ndim < 3:
+        mask = mask.unsqueeze(0)
+        needs_squeeze = True
+    else:
+        needs_squeeze = False
+
+    output = crop_image_tensor(mask, top, left, height, width)
+
+    if needs_squeeze:
+        output = output.squeeze(0)
+
+    return output
 
 
 def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
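
`crop_mask` now funnels through the padded `crop_image_tensor`, whose out-of-bounds branch calls `_pad_with_scalar_fill` and therefore expects at least a (C, H, W) layout; plain 2D masks are temporarily given a channel dimension. A small usage sketch for an in-bounds crop, with illustrative shapes:

import torch

mask = torch.randint(0, 2, (10, 10))  # 2D segmentation mask, no channel dim

# What crop_mask(mask, top=2, left=2, height=4, width=4) effectively does
# for an in-bounds window: add a channel dim, crop, drop the channel dim.
out = mask.unsqueeze(0)[..., 2:6, 2:6].squeeze(0)
print(out.shape)  # torch.Size([4, 4])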