Unverified Commit 759c5b6d authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Added typing annotations to transforms/functional_pil (#4234)



* fix

* add functional PIL typings

* fix types

* fix types

* fix a small one

* small fix

* fix type

* fix interpolation types
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 98cb4ead
import numbers import numbers
from typing import Any, List, Sequence from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -34,7 +34,7 @@ def _get_image_num_channels(img: Any) -> int: ...@@ -34,7 +34,7 @@ def _get_image_num_channels(img: Any) -> int:
@torch.jit.unused @torch.jit.unused
def hflip(img): def hflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -42,7 +42,7 @@ def hflip(img): ...@@ -42,7 +42,7 @@ def hflip(img):
@torch.jit.unused @torch.jit.unused
def vflip(img): def vflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -50,7 +50,7 @@ def vflip(img): ...@@ -50,7 +50,7 @@ def vflip(img):
@torch.jit.unused @torch.jit.unused
def adjust_brightness(img, brightness_factor): def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -60,7 +60,7 @@ def adjust_brightness(img, brightness_factor): ...@@ -60,7 +60,7 @@ def adjust_brightness(img, brightness_factor):
@torch.jit.unused @torch.jit.unused
def adjust_contrast(img, contrast_factor): def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -70,7 +70,7 @@ def adjust_contrast(img, contrast_factor): ...@@ -70,7 +70,7 @@ def adjust_contrast(img, contrast_factor):
@torch.jit.unused @torch.jit.unused
def adjust_saturation(img, saturation_factor): def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -80,7 +80,7 @@ def adjust_saturation(img, saturation_factor): ...@@ -80,7 +80,7 @@ def adjust_saturation(img, saturation_factor):
@torch.jit.unused @torch.jit.unused
def adjust_hue(img, hue_factor): def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
if not(-0.5 <= hue_factor <= 0.5): if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
...@@ -104,7 +104,12 @@ def adjust_hue(img, hue_factor): ...@@ -104,7 +104,12 @@ def adjust_hue(img, hue_factor):
@torch.jit.unused @torch.jit.unused
def adjust_gamma(img, gamma, gain=1): def adjust_gamma(
img: Image.Image,
gamma: float,
gain: float = 1.0,
) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -121,7 +126,13 @@ def adjust_gamma(img, gamma, gain=1): ...@@ -121,7 +126,13 @@ def adjust_gamma(img, gamma, gain=1):
@torch.jit.unused @torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"): def pad(
img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
padding_mode: str = "constant",
) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img))) raise TypeError("img should be PIL Image. Got {}".format(type(img)))
...@@ -196,7 +207,14 @@ def pad(img, padding, fill=0, padding_mode="constant"): ...@@ -196,7 +207,14 @@ def pad(img, padding, fill=0, padding_mode="constant"):
@torch.jit.unused @torch.jit.unused
def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Image.Image: def crop(
img: Image.Image,
top: int,
left: int,
height: int,
width: int,
) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -204,7 +222,13 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag ...@@ -204,7 +222,13 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag
@torch.jit.unused @torch.jit.unused
def resize(img, size, interpolation=Image.BILINEAR, max_size=None): def resize(
img: Image.Image,
size: Union[Sequence[int], int],
interpolation: int = Image.BILINEAR,
max_size: Optional[int] = None,
) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))): if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
...@@ -242,7 +266,12 @@ def resize(img, size, interpolation=Image.BILINEAR, max_size=None): ...@@ -242,7 +266,12 @@ def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
@torch.jit.unused @torch.jit.unused
def _parse_fill(fill, img, name="fillcolor"): def _parse_fill(
fill: Optional[Union[float, List[float], Tuple[float, ...]]],
img: Image.Image,
name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
# Process fill color for affine transforms # Process fill color for affine transforms
num_bands = len(img.getbands()) num_bands = len(img.getbands())
if fill is None: if fill is None:
...@@ -261,7 +290,13 @@ def _parse_fill(fill, img, name="fillcolor"): ...@@ -261,7 +290,13 @@ def _parse_fill(fill, img, name="fillcolor"):
@torch.jit.unused @torch.jit.unused
def affine(img, matrix, interpolation=0, fill=None): def affine(
img: Image.Image,
matrix: List[float],
interpolation: int = Image.NEAREST,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -271,7 +306,15 @@ def affine(img, matrix, interpolation=0, fill=None): ...@@ -271,7 +306,15 @@ def affine(img, matrix, interpolation=0, fill=None):
@torch.jit.unused @torch.jit.unused
def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None): def rotate(
img: Image.Image,
angle: float,
interpolation: int = Image.NEAREST,
expand: bool = False,
center: Optional[Tuple[int, int]] = None,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img))) raise TypeError("img should be PIL Image. Got {}".format(type(img)))
...@@ -280,7 +323,13 @@ def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None): ...@@ -280,7 +323,13 @@ def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
@torch.jit.unused @torch.jit.unused
def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None): def perspective(
img: Image.Image,
perspective_coeffs: float,
interpolation: int = Image.BICUBIC,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -290,7 +339,7 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None) ...@@ -290,7 +339,7 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None)
@torch.jit.unused @torch.jit.unused
def to_grayscale(img, num_output_channels): def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -308,28 +357,28 @@ def to_grayscale(img, num_output_channels): ...@@ -308,28 +357,28 @@ def to_grayscale(img, num_output_channels):
@torch.jit.unused @torch.jit.unused
def invert(img): def invert(img: Image.Image) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.invert(img) return ImageOps.invert(img)
@torch.jit.unused @torch.jit.unused
def posterize(img, bits): def posterize(img: Image.Image, bits: int) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.posterize(img, bits) return ImageOps.posterize(img, bits)
@torch.jit.unused @torch.jit.unused
def solarize(img, threshold): def solarize(img: Image.Image, threshold: int) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.solarize(img, threshold) return ImageOps.solarize(img, threshold)
@torch.jit.unused @torch.jit.unused
def adjust_sharpness(img, sharpness_factor): def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
...@@ -339,14 +388,14 @@ def adjust_sharpness(img, sharpness_factor): ...@@ -339,14 +388,14 @@ def adjust_sharpness(img, sharpness_factor):
@torch.jit.unused @torch.jit.unused
def autocontrast(img): def autocontrast(img: Image.Image) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.autocontrast(img) return ImageOps.autocontrast(img)
@torch.jit.unused @torch.jit.unused
def equalize(img): def equalize(img: Image.Image) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.equalize(img) return ImageOps.equalize(img)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment