Unverified Commit d23a6e16 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

More mypy fixes/ignores (#8412)

parent f766d7ac
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
import PIL.Image
import torch
......@@ -94,6 +94,8 @@ class _AutoAugmentBase(Transform):
interpolation: Union[InterpolationMode, int],
fill: Dict[Union[Type, str], _FillTypeJIT],
) -> ImageOrVideo:
# Note: this cast is wrong and is only here to make mypy happy (it disagrees with torchscript)
image = cast(torch.Tensor, image)
fill_ = _get_fill(fill, type(image))
if transform_id == "Identity":
......@@ -322,7 +324,7 @@ class AutoAugment(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video)
height, width = get_size(image_or_video) # type: ignore[arg-type]
policy = self._policies[int(torch.randint(len(self._policies), ()))]
......@@ -411,7 +413,7 @@ class RandAugment(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video)
height, width = get_size(image_or_video) # type: ignore[arg-type]
for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
......@@ -480,7 +482,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video)
height, width = get_size(image_or_video) # type: ignore[arg-type]
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
......@@ -572,7 +574,7 @@ class AugMix(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(orig_image_or_video)
height, width = get_size(orig_image_or_video) # type: ignore[arg-type]
if isinstance(orig_image_or_video, torch.Tensor):
image_or_video = orig_image_or_video
......@@ -613,9 +615,7 @@ class AugMix(_AutoAugmentBase):
else:
magnitude = 0.0
aug = self._apply_image_or_video_transform(
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
)
aug = self._apply_image_or_video_transform(aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill) # type: ignore[assignment]
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
......
......@@ -730,7 +730,7 @@ def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch
@_register_kernel_internal(permute_channels, PIL.Image.Image)
def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image:
def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image.Image:
return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation))
......
......@@ -113,7 +113,7 @@ def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(vertical_flip, PIL.Image.Image)
def _vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.vflip(image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment