Unverified Commit 38175edb authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added typing annotations to transforms/autoaugment (#4226)



* style: Added typing annotations

* style: Fixed typing

* style: Fixed typing

* Remove unnecessary any.

* Update mypy.ini
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: default avatarVasilis Vryniotis <vvryniotis@fb.com>
parent 091cbcda
......@@ -36,10 +36,6 @@ ignore_errors = True
ignore_errors = True
[mypy-torchvision.transforms.autoaugment.*]
ignore_errors = True
[mypy-PIL.*]
ignore_missing_imports = True
......
......@@ -3,7 +3,7 @@ import torch
from enum import Enum
from torch import Tensor
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Dict
from . import functional as F, InterpolationMode
......@@ -19,7 +19,9 @@ class AutoAugmentPolicy(Enum):
SVHN = "svhn"
def _get_transforms(policy: AutoAugmentPolicy):
def _get_transforms( # type: ignore[return]
policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
if policy == AutoAugmentPolicy.IMAGENET:
return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
......@@ -106,7 +108,7 @@ def _get_transforms(policy: AutoAugmentPolicy):
]
def _get_magnitudes():
def _get_magnitudes() -> Dict[str, Tuple[Optional[Tensor], Optional[bool]]]:
_BINS = 10
return {
# name: (magnitudes, signed)
......@@ -144,8 +146,12 @@ class AutoAugment(torch.nn.Module):
image. If given a number, the value is used for all bands respectively.
"""
def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None):
def __init__(
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None
) -> None:
super().__init__()
self.policy = policy
self.interpolation = interpolation
......@@ -163,7 +169,7 @@ class AutoAugment(torch.nn.Module):
Returns:
params required by the autoaugment transformation
"""
policy_id = torch.randint(transform_num, (1,)).item()
policy_id = int(torch.randint(transform_num, (1,)).item())
probs = torch.rand((2,))
signs = torch.randint(2, (2,))
......@@ -172,7 +178,7 @@ class AutoAugment(torch.nn.Module):
def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]:
return self._op_meta[name]
def forward(self, img: Tensor):
def forward(self, img: Tensor) -> Tensor:
"""
img (PIL Image or Tensor): Image to be transformed.
......@@ -233,5 +239,5 @@ class AutoAugment(torch.nn.Module):
return img
def __repr__(self):
def __repr__(self) -> str:
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment