"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "32e16805a17401f5ef5ec825c808d645f5c26509"
Unverified commit d95fbaf1 authored by Vasilis Vryniotis, committed by GitHub

[prototype] Optimize Center Crop performance (#6880)

* Reduce unnecessary method calls

* Optimize the pad branch

* Remove an unnecessary call to crop_image_tensor

* Fix linter

Co-authored-by: vfdev <vfdev.5@gmail.com>
parent 72c59526
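
The diff below applies one recurring pattern: drop per-call shape helpers (get_num_channels_image_tensor, get_dimensions_image_tensor, get_spatial_size_image_tensor) in favor of direct image.shape indexing, cache the shape tuple before the tensor is reassigned, and crop via a plain slice instead of an extra kernel dispatch. As a rough illustration of how such a change could be timed, here is a minimal micro-benchmark sketch; the tensor sizes, iteration count, and the center_crop_slice helper are illustrative assumptions, not part of this PR:

import torch
from torch.utils import benchmark

# Illustrative input: a single 3 x 512 x 512 image, cropped to 224 x 224.
image = torch.rand(3, 512, 512)
crop_height, crop_width = 224, 224

def center_crop_slice(img: torch.Tensor) -> torch.Tensor:
    # Crop implemented as a plain slice: no helper dispatch, returns a view.
    h, w = img.shape[-2:]
    top = (h - crop_height) // 2
    left = (w - crop_width) // 2
    return img[..., top : top + crop_height, left : left + crop_width]

timer = benchmark.Timer(
    stmt="center_crop_slice(image)",
    globals={"center_crop_slice": center_crop_slice, "image": image},
)
print(timer.timeit(1000))  # wall-clock stats over 1000 runs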
@@ -2,7 +2,7 @@ import torch
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
-from ._meta import _rgb_to_gray, convert_dtype_image_tensor, get_dimensions_image_tensor, get_num_channels_image_tensor
+from ._meta import _rgb_to_gray, convert_dtype_image_tensor
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -45,7 +45,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
     if saturation_factor < 0:
         raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
-    c = get_num_channels_image_tensor(image)
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
@@ -75,7 +75,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
     if contrast_factor < 0:
         raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
-    c = get_num_channels_image_tensor(image)
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
@@ -101,7 +101,7 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> features.InputTypeJIT:
 def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
-    num_channels, height, width = get_dimensions_image_tensor(image)
+    num_channels, height, width = image.shape[-3:]
     if num_channels not in (1, 3):
         raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
@@ -210,8 +210,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
     if not (-0.5 <= hue_factor <= 0.5):
         raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
-    c = get_num_channels_image_tensor(image)
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
@@ -342,8 +341,7 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTypeJIT:
 def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    c = get_num_channels_image_tensor(image)
+    c = image.shape[-3]
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
@@ -16,12 +16,7 @@ from torchvision.transforms.functional import (
 )
 from torchvision.transforms.functional_tensor import _parse_pad_padding
-from ._meta import (
-    convert_format_bounding_box,
-    get_dimensions_image_tensor,
-    get_spatial_size_image_pil,
-    get_spatial_size_image_tensor,
-)
+from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -120,9 +115,9 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: bool = False,
 ) -> torch.Tensor:
-    num_channels, old_height, old_width = get_dimensions_image_tensor(image)
+    shape = image.shape
+    num_channels, old_height, old_width = shape[-3:]
     new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
-    extra_dims = image.shape[:-3]
     if image.numel() > 0:
         image = image.reshape(-1, num_channels, old_height, old_width)
@@ -134,7 +129,7 @@ def resize_image_tensor(
             antialias=antialias,
         )
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 @torch.jit.unused
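
resize_image_tensor flattens arbitrary leading batch dimensions into one, runs the 4D kernel, and restores them at the end. Caching image.shape once matters because image is reassigned mid-function, so reading image.shape[:-3] afterwards would return the flattened shape. A self-contained sketch of this flatten/restore pattern, with torch.nn.functional.interpolate standing in for the internal resize kernel (the interpolation settings are illustrative assumptions):

import torch

def resize_batched(image: torch.Tensor, new_height: int, new_width: int) -> torch.Tensor:
    # Cache the full shape once, before `image` is reassigned below.
    shape = image.shape
    num_channels, old_height, old_width = shape[-3:]

    # Collapse arbitrary leading dims into one batch dim for the 4D kernel.
    image = image.reshape(-1, num_channels, old_height, old_width)
    image = torch.nn.functional.interpolate(
        image, size=[new_height, new_width], mode="bilinear", antialias=False
    )

    # Restore the original leading dims around the new spatial size.
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))

out = resize_batched(torch.rand(2, 5, 3, 64, 64), 32, 32)
assert out.shape == (2, 5, 3, 32, 32)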
@@ -270,8 +265,8 @@ def affine_image_tensor(
     if image.numel() == 0:
         return image
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
     image = image.reshape(-1, num_channels, height, width)
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
@@ -285,7 +280,7 @@ def affine_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
     output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
-    return output.reshape(extra_dims + (num_channels, height, width))
+    return output.reshape(shape)
 @torch.jit.unused
@@ -511,8 +506,8 @@ def rotate_image_tensor(
     fill: features.FillTypeJIT = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
     center_f = [0.0, 0.0]
     if center is not None:
@@ -538,7 +533,7 @@ def rotate_image_tensor(
     else:
         new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 @torch.jit.unused
@@ -675,8 +670,8 @@ def _pad_with_scalar_fill(
     fill: Union[int, float, None],
     padding_mode: str = "constant",
 ) -> torch.Tensor:
-    num_channels, height, width = image.shape[-3:]
-    extra_dims = image.shape[:-3]
+    shape = image.shape
+    num_channels, height, width = shape[-3:]
     if image.numel() > 0:
         image = _FT.pad(
@@ -688,7 +683,7 @@ def _pad_with_scalar_fill(
     new_height = height + top + bottom
     new_width = width + left + right
-    return image.reshape(extra_dims + (num_channels, new_height, new_width))
+    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
 # TODO: This should be removed once pytorch pad supports non-scalar padding values
@@ -1130,7 +1125,8 @@ elastic_transform = elastic
 def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
     if isinstance(output_size, numbers.Number):
-        return [int(output_size), int(output_size)]
+        s = int(output_size)
+        return [s, s]
     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
         return [output_size[0], output_size[0]]
     else:
@@ -1156,18 +1152,21 @@ def _center_crop_compute_crop_anchor(
 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_height, image_width = get_spatial_size_image_tensor(image)
+    shape = image.shape
+    if image.numel() == 0:
+        return image.reshape(shape[:-2] + (crop_height, crop_width))
+    image_height, image_width = shape[-2:]
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        image = pad_image_tensor(image, padding_ltrb, fill=0)
-        image_height, image_width = get_spatial_size_image_tensor(image)
+        image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
+        image_height, image_width = image.shape[-2:]
     if crop_width == image_width and crop_height == image_height:
         return image
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width)
-    return crop_image_tensor(image, crop_top, crop_left, crop_height, crop_width)
+    return image[..., crop_top : (crop_top + crop_height), crop_left : (crop_left + crop_width)]
 @torch.jit.unused
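
This hunk is the heart of the PR: the optimized kernel handles the empty-tensor case up front, pads with the low-level torch_pad (an alias for torch.nn.functional.pad in functional_tensor) only when the crop exceeds the image, and then crops with a plain slice instead of dispatching to crop_image_tensor. F.pad takes a flat (left, right, top, bottom) sequence applied from the last dimension inward, which is the layout _parse_pad_padding produces. A self-contained sketch of the optimized flow; the function name is a local stand-in and the anchor arithmetic is simplified to floor division rather than the library's exact rounding:

import torch

def center_crop_sketch(image: torch.Tensor, crop_height: int, crop_width: int) -> torch.Tensor:
    image_height, image_width = image.shape[-2:]

    if crop_height > image_height or crop_width > image_width:
        # F.pad padding order is (left, right, top, bottom), applied from
        # the last dimension inward.
        pad_w = max(crop_width - image_width, 0)
        pad_h = max(crop_height - image_height, 0)
        image = torch.nn.functional.pad(
            image, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=0.0
        )
        image_height, image_width = image.shape[-2:]

    if crop_width == image_width and crop_height == image_height:
        return image

    crop_top = (image_height - crop_height) // 2
    crop_left = (image_width - crop_width) // 2
    # A plain slice replaces the extra crop kernel dispatch and returns a view.
    return image[..., crop_top : crop_top + crop_height, crop_left : crop_left + crop_width]

out = center_crop_sketch(torch.rand(3, 500, 400), 224, 224)
assert out.shape == (3, 224, 224)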
@@ -1332,7 +1331,7 @@ def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    image_height, image_width = get_spatial_size_image_tensor(image)
+    image_height, image_width = image.shape[-2:]
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"