"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a4b233e5b5092e8ff861b5f5d3ac646fcba9ba79"
Unverified commit 30b879fc, authored by Philip Meier and committed by GitHub

Cleanup prototype kernel signatures (#6648)

* pass metadata directly after input in prototype kernels

* rename img to image
parent dc07ac2a
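
The convention applied throughout the diff below is that each kernel takes its input first, then any metadata describing that input (bounding box format, image size), then the operation parameters; image kernels also name their first parameter image instead of img. A minimal, self-contained sketch of the resulting parameter order (a toy stand-in for illustration only, not the torchvision kernel; it ignores max_size and assumes XYXY boxes):

from typing import List, Optional, Tuple

import torch


def resize_bounding_box(
    bounding_box: torch.Tensor,
    image_size: Tuple[int, int],  # metadata about the input comes directly after it
    size: List[int],
    max_size: Optional[int] = None,  # ignored in this toy sketch
) -> torch.Tensor:
    # Scale XYXY boxes by the ratio between the requested size and the current image size.
    old_height, old_width = image_size
    new_height, new_width = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
    ratios = [new_width / old_width, new_height / old_height]
    return bounding_box * torch.tensor(ratios * 2, dtype=bounding_box.dtype)


boxes = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
print(resize_bounding_box(boxes, image_size=(100, 200), size=[50, 100]))
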
@@ -632,7 +632,7 @@ def test_correctness_pad_bounding_box(device, padding):
     bboxes_format = bboxes.format
     bboxes_image_size = bboxes.image_size
-    output_boxes = F.pad_bounding_box(bboxes, padding, format=bboxes_format)
+    output_boxes = F.pad_bounding_box(bboxes, format=bboxes_format, padding=padding)
     if bboxes.ndim < 2 or bboxes.shape[0] == 0:
         bboxes = [bboxes]
@@ -781,7 +781,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
     bboxes_format = bboxes.format
     bboxes_image_size = bboxes.image_size
-    output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, output_size, bboxes_image_size)
+    output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, bboxes_image_size, output_size)
     if bboxes.ndim < 2:
         bboxes = [bboxes]
......
@@ -83,7 +83,7 @@ class BoundingBox(_Feature):
         max_size: Optional[int] = None,
         antialias: bool = False,
     ) -> BoundingBox:
-        output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size)
+        output = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
         if isinstance(size, int):
             size = [size]
         image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
@@ -95,7 +95,7 @@ class BoundingBox(_Feature):
     def center_crop(self, output_size: List[int]) -> BoundingBox:
         output = self._F.center_crop_bounding_box(
-            self, format=self.format, output_size=output_size, image_size=self.image_size
+            self, format=self.format, image_size=self.image_size, output_size=output_size
         )
         if isinstance(output_size, int):
             output_size = [output_size]
@@ -126,7 +126,7 @@ class BoundingBox(_Feature):
         if not isinstance(padding, int):
             padding = list(padding)
-        output = self._F.pad_bounding_box(self, padding, format=self.format, padding_mode=padding_mode)
+        output = self._F.pad_bounding_box(self, format=self.format, padding=padding, padding_mode=padding_mode)
         # Update output image size:
         left, right, top, bottom = self._F._geometry._parse_pad_padding(padding)
......
@@ -10,11 +10,11 @@ erase_image_tensor = _FT.erase

 @torch.jit.unused
 def erase_image_pil(
-    img: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
+    image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
 ) -> PIL.Image.Image:
-    t_img = pil_to_tensor(img)
+    t_img = pil_to_tensor(image)
     output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
-    return to_pil_image(output, mode=img.mode)
+    return to_pil_image(output, mode=image.mode)

 def erase(
......
@@ -116,7 +116,7 @@ def resize_image_tensor(

 @torch.jit.unused
 def resize_image_pil(
-    img: PIL.Image.Image,
+    image: PIL.Image.Image,
     size: Union[Sequence[int], int],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
@@ -125,8 +125,8 @@ def resize_image_pil(
         size = [size, size]
     # Explicitly cast size to list otherwise mypy issue: incompatible type "Sequence[int]"; expected "List[int]"
     size: List[int] = list(size)
-    size = _compute_resized_output_size(img.size[::-1], size=size, max_size=max_size)
-    return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])
+    size = _compute_resized_output_size(image.size[::-1], size=size, max_size=max_size)
+    return _FP.resize(image, size, interpolation=pil_modes_mapping[interpolation])

 def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
@@ -145,7 +145,7 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N

 def resize_bounding_box(
-    bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int], max_size: Optional[int] = None
+    bounding_box: torch.Tensor, image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
 ) -> torch.Tensor:
     if isinstance(size, int):
         size = [size]
@@ -228,7 +228,7 @@ def _affine_parse_args(

 def affine_image_tensor(
-    img: torch.Tensor,
+    image: torch.Tensor,
     angle: Union[int, float],
     translate: List[float],
     scale: float,
@@ -237,12 +237,12 @@ def affine_image_tensor(
     fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    if img.numel() == 0:
-        return img
+    if image.numel() == 0:
+        return image

-    num_channels, height, width = img.shape[-3:]
-    extra_dims = img.shape[:-3]
-    img = img.view(-1, num_channels, height, width)
+    num_channels, height, width = image.shape[-3:]
+    extra_dims = image.shape[:-3]
+    image = image.view(-1, num_channels, height, width)

     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -254,13 +254,13 @@ def affine_image_tensor(
     translate_f = [1.0 * t for t in translate]
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

-    output = _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill)
+    output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
     return output.view(extra_dims + (num_channels, height, width))

 @torch.jit.unused
 def affine_image_pil(
-    img: PIL.Image.Image,
+    image: PIL.Image.Image,
     angle: Union[int, float],
     translate: List[float],
     scale: float,
@@ -275,11 +275,11 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        _, height, width = get_dimensions_image_pil(img)
+        _, height, width = get_dimensions_image_pil(image)
         center = [width * 0.5, height * 0.5]

     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
-    return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
+    return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)

 def _affine_bounding_box_xyxy(
@@ -456,15 +456,15 @@ def affine(

 def rotate_image_tensor(
-    img: torch.Tensor,
+    image: torch.Tensor,
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
     fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    num_channels, height, width = img.shape[-3:]
-    extra_dims = img.shape[:-3]
+    num_channels, height, width = image.shape[-3:]
+    extra_dims = image.shape[:-3]

     center_f = [0.0, 0.0]
     if center is not None:
@@ -478,24 +478,24 @@ def rotate_image_tensor(
     # we need to set -angle.
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

-    if img.numel() > 0:
-        img = _FT.rotate(
-            img.view(-1, num_channels, height, width),
+    if image.numel() > 0:
+        image = _FT.rotate(
+            image.view(-1, num_channels, height, width),
             matrix,
             interpolation=interpolation.value,
             expand=expand,
             fill=fill,
         )
-        new_height, new_width = img.shape[-2:]
+        new_height, new_width = image.shape[-2:]
     else:
         new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)

-    return img.view(extra_dims + (num_channels, new_height, new_width))
+    return image.view(extra_dims + (num_channels, new_height, new_width))

 @torch.jit.unused
 def rotate_image_pil(
-    img: PIL.Image.Image,
+    image: PIL.Image.Image,
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
@@ -507,7 +507,7 @@ def rotate_image_pil(
         center = None

     return _FP.rotate(
-        img, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
+        image, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
     )
@@ -592,46 +592,46 @@ pad_image_pil = _FP.pad

 def pad_image_tensor(
-    img: torch.Tensor,
+    image: torch.Tensor,
     padding: Union[int, List[int]],
     fill: features.FillTypeJIT = None,
     padding_mode: str = "constant",
 ) -> torch.Tensor:
     if fill is None:
         # This is a JIT workaround
-        return _pad_with_scalar_fill(img, padding, fill=None, padding_mode=padding_mode)
+        return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode)
     elif isinstance(fill, (int, float)) or len(fill) == 1:
         fill_number = fill[0] if isinstance(fill, list) else fill
-        return _pad_with_scalar_fill(img, padding, fill=fill_number, padding_mode=padding_mode)
+        return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode)
     else:
-        return _pad_with_vector_fill(img, padding, fill=fill, padding_mode=padding_mode)
+        return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode)

 def _pad_with_scalar_fill(
-    img: torch.Tensor,
+    image: torch.Tensor,
     padding: Union[int, List[int]],
     fill: Union[int, float, None],
     padding_mode: str = "constant",
 ) -> torch.Tensor:
-    num_channels, height, width = img.shape[-3:]
-    extra_dims = img.shape[:-3]
+    num_channels, height, width = image.shape[-3:]
+    extra_dims = image.shape[:-3]

-    if img.numel() > 0:
-        img = _FT.pad(
-            img=img.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
+    if image.numel() > 0:
+        image = _FT.pad(
+            img=image.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
         )
-        new_height, new_width = img.shape[-2:]
+        new_height, new_width = image.shape[-2:]
     else:
         left, right, top, bottom = _FT._parse_pad_padding(padding)
         new_height = height + top + bottom
         new_width = width + left + right

-    return img.view(extra_dims + (num_channels, new_height, new_width))
+    return image.view(extra_dims + (num_channels, new_height, new_width))

 # TODO: This should be removed once pytorch pad supports non-scalar padding values
 def _pad_with_vector_fill(
-    img: torch.Tensor,
+    image: torch.Tensor,
     padding: Union[int, List[int]],
     fill: List[float],
     padding_mode: str = "constant",
@@ -639,9 +639,9 @@ def _pad_with_vector_fill(
     if padding_mode != "constant":
         raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")

-    output = _pad_with_scalar_fill(img, padding, fill=0, padding_mode="constant")
+    output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
     left, right, top, bottom = _parse_pad_padding(padding)
-    fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)
+    fill = torch.tensor(fill, dtype=image.dtype, device=image.device).view(-1, 1, 1)

     if top > 0:
         output[..., :top, :] = fill
@@ -672,7 +672,7 @@ def pad_mask(
     else:
         needs_squeeze = False

-    output = pad_image_tensor(img=mask, padding=padding, fill=fill, padding_mode=padding_mode)
+    output = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode)

     if needs_squeeze:
         output = output.squeeze(0)
@@ -682,8 +682,8 @@ def pad_mask(

 def pad_bounding_box(
     bounding_box: torch.Tensor,
-    padding: Union[int, List[int]],
     format: features.BoundingBoxFormat,
+    padding: Union[int, List[int]],
     padding_mode: str = "constant",
 ) -> torch.Tensor:
     if padding_mode not in ["constant"]:
@@ -755,22 +755,22 @@ def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: i

 def perspective_image_tensor(
-    img: torch.Tensor,
+    image: torch.Tensor,
     perspective_coeffs: List[float],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
-    return _FT.perspective(img, perspective_coeffs, interpolation=interpolation.value, fill=fill)
+    return _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)

 @torch.jit.unused
 def perspective_image_pil(
-    img: PIL.Image.Image,
+    image: PIL.Image.Image,
     perspective_coeffs: List[float],
     interpolation: InterpolationMode = InterpolationMode.BICUBIC,
     fill: features.FillTypeJIT = None,
 ) -> PIL.Image.Image:
-    return _FP.perspective(img, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)
+    return _FP.perspective(image, perspective_coeffs, interpolation=pil_modes_mapping[interpolation], fill=fill)

 def perspective_bounding_box(
@@ -894,24 +894,24 @@ def perspective(

 def elastic_image_tensor(
-    img: torch.Tensor,
+    image: torch.Tensor,
     displacement: torch.Tensor,
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
-    return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill)
+    return _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)

 @torch.jit.unused
 def elastic_image_pil(
-    img: PIL.Image.Image,
+    image: PIL.Image.Image,
     displacement: torch.Tensor,
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: features.FillTypeJIT = None,
 ) -> PIL.Image.Image:
-    t_img = pil_to_tensor(img)
+    t_img = pil_to_tensor(image)
     output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill)
-    return to_pil_image(output, mode=img.mode)
+    return to_pil_image(output, mode=image.mode)

 def elastic_bounding_box(
@@ -1016,44 +1016,44 @@ def _center_crop_compute_crop_anchor(
     return crop_top, crop_left

-def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
+def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_dimensions_image_tensor(img)
+    _, image_height, image_width = get_dimensions_image_tensor(image)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        img = pad_image_tensor(img, padding_ltrb, fill=0)
+        image = pad_image_tensor(image, padding_ltrb, fill=0)

-        _, image_height, image_width = get_dimensions_image_tensor(img)
+        _, image_height, image_width = get_dimensions_image_tensor(image)
         if crop_width == image_width and crop_height == image_height:
-            return img
+            return image

     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
-    return crop_image_tensor(img, crop_top, crop_left, crop_height, crop_width)
+    return crop_image_tensor(image, crop_top, crop_left, crop_height, crop_width)

 @torch.jit.unused
-def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
+def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_dimensions_image_pil(img)
+    _, image_height, image_width = get_dimensions_image_pil(image)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        img = pad_image_pil(img, padding_ltrb, fill=0)
+        image = pad_image_pil(image, padding_ltrb, fill=0)

-        _, image_height, image_width = get_dimensions_image_pil(img)
+        _, image_height, image_width = get_dimensions_image_pil(image)
         if crop_width == image_width and crop_height == image_height:
-            return img
+            return image

     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
-    return crop_image_pil(img, crop_top, crop_left, crop_height, crop_width)
+    return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)

 def center_crop_bounding_box(
     bounding_box: torch.Tensor,
     format: features.BoundingBoxFormat,
-    output_size: List[int],
     image_size: Tuple[int, int],
+    output_size: List[int],
 ) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size)
@@ -1067,7 +1067,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor
     else:
         needs_squeeze = False

-    output = center_crop_image_tensor(img=mask, output_size=output_size)
+    output = center_crop_image_tensor(image=mask, output_size=output_size)

     if needs_squeeze:
         output = output.squeeze(0)
@@ -1085,7 +1085,7 @@ def center_crop(inpt: features.InputTypeJIT, output_size: List[int]) -> features

 def resized_crop_image_tensor(
-    img: torch.Tensor,
+    image: torch.Tensor,
     top: int,
     left: int,
     height: int,
@@ -1094,13 +1094,13 @@ def resized_crop_image_tensor(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     antialias: bool = False,
 ) -> torch.Tensor:
-    img = crop_image_tensor(img, top, left, height, width)
-    return resize_image_tensor(img, size, interpolation=interpolation, antialias=antialias)
+    image = crop_image_tensor(image, top, left, height, width)
+    return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias)

 @torch.jit.unused
 def resized_crop_image_pil(
-    img: PIL.Image.Image,
+    image: PIL.Image.Image,
     top: int,
     left: int,
     height: int,
@@ -1108,8 +1108,8 @@ def resized_crop_image_pil(
     size: List[int],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
 ) -> PIL.Image.Image:
-    img = crop_image_pil(img, top, left, height, width)
-    return resize_image_pil(img, size, interpolation=interpolation)
+    image = crop_image_pil(image, top, left, height, width)
+    return resize_image_pil(image, size, interpolation=interpolation)

 def resized_crop_bounding_box(
@@ -1122,7 +1122,7 @@ def resized_crop_bounding_box(
     size: List[int],
 ) -> torch.Tensor:
     bounding_box = crop_bounding_box(bounding_box, format, top, left)
-    return resize_bounding_box(bounding_box, size, (height, width))
+    return resize_bounding_box(bounding_box, (height, width), size)

 def resized_crop_mask(
@@ -1172,40 +1172,40 @@ def _parse_five_crop_size(size: List[int]) -> List[int]:

 def five_crop_image_tensor(
-    img: torch.Tensor, size: List[int]
+    image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    _, image_height, image_width = get_dimensions_image_tensor(img)
+    _, image_height, image_width = get_dimensions_image_tensor(image)

     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
         raise ValueError(msg.format(size, (image_height, image_width)))

-    tl = crop_image_tensor(img, 0, 0, crop_height, crop_width)
-    tr = crop_image_tensor(img, 0, image_width - crop_width, crop_height, crop_width)
-    bl = crop_image_tensor(img, image_height - crop_height, 0, crop_height, crop_width)
-    br = crop_image_tensor(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
-    center = center_crop_image_tensor(img, [crop_height, crop_width])
+    tl = crop_image_tensor(image, 0, 0, crop_height, crop_width)
+    tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop_image_tensor(image, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop_image_tensor(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = center_crop_image_tensor(image, [crop_height, crop_width])

     return tl, tr, bl, br, center

 @torch.jit.unused
 def five_crop_image_pil(
-    img: PIL.Image.Image, size: List[int]
+    image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    _, image_height, image_width = get_dimensions_image_pil(img)
+    _, image_height, image_width = get_dimensions_image_pil(image)

     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
         raise ValueError(msg.format(size, (image_height, image_width)))

-    tl = crop_image_pil(img, 0, 0, crop_height, crop_width)
-    tr = crop_image_pil(img, 0, image_width - crop_width, crop_height, crop_width)
-    bl = crop_image_pil(img, image_height - crop_height, 0, crop_height, crop_width)
-    br = crop_image_pil(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
-    center = center_crop_image_pil(img, [crop_height, crop_width])
+    tl = crop_image_pil(image, 0, 0, crop_height, crop_width)
+    tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = center_crop_image_pil(image, [crop_height, crop_width])

     return tl, tr, bl, br, center
@@ -1225,29 +1225,29 @@ def five_crop(
     return five_crop_image_pil(inpt, size)

-def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
-    tl, tr, bl, br, center = five_crop_image_tensor(img, size)
+def ten_crop_image_tensor(image: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
+    tl, tr, bl, br, center = five_crop_image_tensor(image, size)

     if vertical_flip:
-        img = vertical_flip_image_tensor(img)
+        image = vertical_flip_image_tensor(image)
     else:
-        img = horizontal_flip_image_tensor(img)
+        image = horizontal_flip_image_tensor(image)

-    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(img, size)
+    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(image, size)
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]

 @torch.jit.unused
-def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
-    tl, tr, bl, br, center = five_crop_image_pil(img, size)
+def ten_crop_image_pil(image: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
+    tl, tr, bl, br, center = five_crop_image_pil(image, size)

     if vertical_flip:
-        img = vertical_flip_image_pil(img)
+        image = vertical_flip_image_pil(image)
     else:
-        img = horizontal_flip_image_pil(img)
+        image = horizontal_flip_image_pil(image)

-    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)
+    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(image, size)
     return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
......
@@ -21,7 +21,7 @@ def normalize(

 def gaussian_blur_image_tensor(
-    img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
+    image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
     # TODO: consider deprecating integers from sigma on the future
     if isinstance(kernel_size, int):
@@ -47,16 +47,16 @@ def gaussian_blur_image_tensor(
         if s <= 0.0:
             raise ValueError(f"sigma should have positive values. Got {sigma}")

-    return _FT.gaussian_blur(img, kernel_size, sigma)
+    return _FT.gaussian_blur(image, kernel_size, sigma)

 @torch.jit.unused
 def gaussian_blur_image_pil(
-    img: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
+    image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> PIL.Image.Image:
-    t_img = pil_to_tensor(img)
+    t_img = pil_to_tensor(image)
     output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma)
-    return to_pil_image(output, mode=img.mode)
+    return to_pil_image(output, mode=image.mode)

 def gaussian_blur(
......