Unverified Commit 759c5b6d authored by Aditya Oke's avatar Aditya Oke Committed by GitHub
Browse files

Added typing annotations to transforms/functional_pil (#4234)



* fix

* add functional PIL typings

* fix types

* fix types

* fix a small one

* small fix

* fix type

* fix interpolation types
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 98cb4ead
import numbers
from typing import Any, List, Sequence
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
......@@ -34,7 +34,7 @@ def _get_image_num_channels(img: Any) -> int:
@torch.jit.unused
def hflip(img):
def hflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -42,7 +42,7 @@ def hflip(img):
@torch.jit.unused
def vflip(img):
def vflip(img: Image.Image) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -50,7 +50,7 @@ def vflip(img):
@torch.jit.unused
def adjust_brightness(img, brightness_factor):
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -60,7 +60,7 @@ def adjust_brightness(img, brightness_factor):
@torch.jit.unused
def adjust_contrast(img, contrast_factor):
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -70,7 +70,7 @@ def adjust_contrast(img, contrast_factor):
@torch.jit.unused
def adjust_saturation(img, saturation_factor):
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -80,7 +80,7 @@ def adjust_saturation(img, saturation_factor):
@torch.jit.unused
def adjust_hue(img, hue_factor):
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
......@@ -104,7 +104,12 @@ def adjust_hue(img, hue_factor):
@torch.jit.unused
def adjust_gamma(img, gamma, gain=1):
def adjust_gamma(
img: Image.Image,
gamma: float,
gain: float = 1.0,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -121,7 +126,13 @@ def adjust_gamma(img, gamma, gain=1):
@torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"):
def pad(
img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
padding_mode: str = "constant",
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
......@@ -196,7 +207,14 @@ def pad(img, padding, fill=0, padding_mode="constant"):
@torch.jit.unused
def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Image.Image:
def crop(
img: Image.Image,
top: int,
left: int,
height: int,
width: int,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -204,7 +222,13 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag
@torch.jit.unused
def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
def resize(
img: Image.Image,
size: Union[Sequence[int], int],
interpolation: int = Image.BILINEAR,
max_size: Optional[int] = None,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
......@@ -242,7 +266,12 @@ def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
@torch.jit.unused
def _parse_fill(fill, img, name="fillcolor"):
def _parse_fill(
fill: Optional[Union[float, List[float], Tuple[float, ...]]],
img: Image.Image,
name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:
# Process fill color for affine transforms
num_bands = len(img.getbands())
if fill is None:
......@@ -261,7 +290,13 @@ def _parse_fill(fill, img, name="fillcolor"):
@torch.jit.unused
def affine(img, matrix, interpolation=0, fill=None):
def affine(
img: Image.Image,
matrix: List[float],
interpolation: int = Image.NEAREST,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -271,7 +306,15 @@ def affine(img, matrix, interpolation=0, fill=None):
@torch.jit.unused
def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
def rotate(
img: Image.Image,
angle: float,
interpolation: int = Image.NEAREST,
expand: bool = False,
center: Optional[Tuple[int, int]] = None,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))
......@@ -280,7 +323,13 @@ def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
@torch.jit.unused
def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None):
def perspective(
img: Image.Image,
perspective_coeffs: float,
interpolation: int = Image.BICUBIC,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -290,7 +339,7 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None)
@torch.jit.unused
def to_grayscale(img, num_output_channels):
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -308,28 +357,28 @@ def to_grayscale(img, num_output_channels):
@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
    """Invert (negate) the colors of the given PIL Image.

    Args:
        img: PIL Image to invert.

    Returns:
        Color-inverted PIL Image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.invert(img)
@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
    """Posterize the given PIL Image by reducing the bits per color channel.

    Args:
        img: PIL Image to posterize.
        bits: Number of bits to keep for each channel.

    Returns:
        Posterized PIL Image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.posterize(img, bits)
@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
    """Solarize the given PIL Image by inverting all pixel values above a threshold.

    Args:
        img: PIL Image to solarize.
        threshold: Pixels at or above this value are inverted.

    Returns:
        Solarized PIL Image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.solarize(img, threshold)
@torch.jit.unused
def adjust_sharpness(img, sharpness_factor):
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
......@@ -339,14 +388,14 @@ def adjust_sharpness(img, sharpness_factor):
@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
    """Maximize the contrast of the given PIL Image by remapping its histogram.

    Args:
        img: PIL Image to adjust.

    Returns:
        Contrast-maximized PIL Image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.autocontrast(img)
@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
    """Equalize the histogram of the given PIL Image.

    Args:
        img: PIL Image to equalize.

    Returns:
        Histogram-equalized PIL Image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.equalize(img)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment