"...tools/git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "fe851fbc27e4aebbbf1bd39b8538fc8807504bc9"
Unverified Commit 38175edb authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added typing annotations to transforms/autoaugment (#4226)



* style: Added typing annotations

* style: Fixed typing

* style: Fixed typing

* Remove unnecessary any.

* Update mypy.ini
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: default avatarVasilis Vryniotis <vvryniotis@fb.com>
parent 091cbcda
...@@ -36,10 +36,6 @@ ignore_errors = True ...@@ -36,10 +36,6 @@ ignore_errors = True
ignore_errors = True ignore_errors = True
[mypy-torchvision.transforms.autoaugment.*]
ignore_errors = True
[mypy-PIL.*] [mypy-PIL.*]
ignore_missing_imports = True ignore_missing_imports = True
......
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from enum import Enum from enum import Enum
from torch import Tensor from torch import Tensor
from typing import List, Tuple, Optional from typing import List, Tuple, Optional, Dict
from . import functional as F, InterpolationMode from . import functional as F, InterpolationMode
...@@ -19,7 +19,9 @@ class AutoAugmentPolicy(Enum): ...@@ -19,7 +19,9 @@ class AutoAugmentPolicy(Enum):
SVHN = "svhn" SVHN = "svhn"
def _get_transforms(policy: AutoAugmentPolicy): def _get_transforms( # type: ignore[return]
policy: AutoAugmentPolicy
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
if policy == AutoAugmentPolicy.IMAGENET: if policy == AutoAugmentPolicy.IMAGENET:
return [ return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
...@@ -106,7 +108,7 @@ def _get_transforms(policy: AutoAugmentPolicy): ...@@ -106,7 +108,7 @@ def _get_transforms(policy: AutoAugmentPolicy):
] ]
def _get_magnitudes(): def _get_magnitudes() -> Dict[str, Tuple[Optional[Tensor], Optional[bool]]]:
_BINS = 10 _BINS = 10
return { return {
# name: (magnitudes, signed) # name: (magnitudes, signed)
...@@ -144,8 +146,12 @@ class AutoAugment(torch.nn.Module): ...@@ -144,8 +146,12 @@ class AutoAugment(torch.nn.Module):
image. If given a number, the value is used for all bands respectively. image. If given a number, the value is used for all bands respectively.
""" """
def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, def __init__(
interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None): self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None
) -> None:
super().__init__() super().__init__()
self.policy = policy self.policy = policy
self.interpolation = interpolation self.interpolation = interpolation
...@@ -163,7 +169,7 @@ class AutoAugment(torch.nn.Module): ...@@ -163,7 +169,7 @@ class AutoAugment(torch.nn.Module):
Returns: Returns:
params required by the autoaugment transformation params required by the autoaugment transformation
""" """
policy_id = torch.randint(transform_num, (1,)).item() policy_id = int(torch.randint(transform_num, (1,)).item())
probs = torch.rand((2,)) probs = torch.rand((2,))
signs = torch.randint(2, (2,)) signs = torch.randint(2, (2,))
...@@ -172,7 +178,7 @@ class AutoAugment(torch.nn.Module): ...@@ -172,7 +178,7 @@ class AutoAugment(torch.nn.Module):
def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]: def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]:
return self._op_meta[name] return self._op_meta[name]
def forward(self, img: Tensor): def forward(self, img: Tensor) -> Tensor:
""" """
img (PIL Image or Tensor): Image to be transformed. img (PIL Image or Tensor): Image to be transformed.
...@@ -233,5 +239,5 @@ class AutoAugment(torch.nn.Module): ...@@ -233,5 +239,5 @@ class AutoAugment(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment