Unverified Commit c5958862 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fix bug on prototype `pad` (#6949)

parent deba0562
...@@ -727,8 +727,11 @@ def _pad_with_scalar_fill( ...@@ -727,8 +727,11 @@ def _pad_with_scalar_fill(
shape = image.shape shape = image.shape
num_channels, height, width = shape[-3:] num_channels, height, width = shape[-3:]
if image.numel() > 0: batch_size = 1
image = image.reshape(-1, num_channels, height, width) for s in shape[:-3]:
batch_size *= s
image = image.reshape(batch_size, num_channels, height, width)
if padding_mode == "edge": if padding_mode == "edge":
# Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
...@@ -756,10 +759,6 @@ def _pad_with_scalar_fill( ...@@ -756,10 +759,6 @@ def _pad_with_scalar_fill(
image = _FT._pad_symmetric(image, torch_padding) image = _FT._pad_symmetric(image, torch_padding)
new_height, new_width = image.shape[-2:] new_height, new_width = image.shape[-2:]
else:
left, right, top, bottom = torch_padding
new_height = height + top + bottom
new_width = width + left + right
return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
...@@ -868,7 +867,24 @@ def pad( ...@@ -868,7 +867,24 @@ def pad(
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode) return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
crop_image_tensor = _FT.crop def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
h, w = image.shape[-2:]
right = left + width
bottom = top + height
if left < 0 or top < 0 or right > w or bottom > h:
image = image[..., max(top, 0) : bottom, max(left, 0) : right]
torch_padding = [
max(min(right, 0) - left, 0),
max(right - max(w, left), 0),
max(min(bottom, 0) - top, 0),
max(bottom - max(h, top), 0),
]
return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
return image[..., top:bottom, left:right]
crop_image_pil = _FP.crop crop_image_pil = _FP.crop
...@@ -893,7 +909,18 @@ def crop_bounding_box( ...@@ -893,7 +909,18 @@ def crop_bounding_box(
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image_tensor(mask, top, left, height, width) if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = crop_image_tensor(mask, top, left, height, width)
if needs_squeeze:
output = output.squeeze(0)
return output
def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment