Unverified Commit d23a6e16 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

More mypy fixes/ignores (#8412)

parent f766d7ac
import math import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -94,6 +94,8 @@ class _AutoAugmentBase(Transform): ...@@ -94,6 +94,8 @@ class _AutoAugmentBase(Transform):
interpolation: Union[InterpolationMode, int], interpolation: Union[InterpolationMode, int],
fill: Dict[Union[Type, str], _FillTypeJIT], fill: Dict[Union[Type, str], _FillTypeJIT],
) -> ImageOrVideo: ) -> ImageOrVideo:
# Note: this cast is wrong and is only here to make mypy happy (it disagrees with torchscript)
image = cast(torch.Tensor, image)
fill_ = _get_fill(fill, type(image)) fill_ = _get_fill(fill, type(image))
if transform_id == "Identity": if transform_id == "Identity":
...@@ -322,7 +324,7 @@ class AutoAugment(_AutoAugmentBase): ...@@ -322,7 +324,7 @@ class AutoAugment(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video) height, width = get_size(image_or_video) # type: ignore[arg-type]
policy = self._policies[int(torch.randint(len(self._policies), ()))] policy = self._policies[int(torch.randint(len(self._policies), ()))]
...@@ -411,7 +413,7 @@ class RandAugment(_AutoAugmentBase): ...@@ -411,7 +413,7 @@ class RandAugment(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video) height, width = get_size(image_or_video) # type: ignore[arg-type]
for _ in range(self.num_ops): for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
...@@ -480,7 +482,7 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -480,7 +482,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs) flat_inputs_with_spec, image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(image_or_video) height, width = get_size(image_or_video) # type: ignore[arg-type]
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
...@@ -572,7 +574,7 @@ class AugMix(_AutoAugmentBase): ...@@ -572,7 +574,7 @@ class AugMix(_AutoAugmentBase):
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs) flat_inputs_with_spec, orig_image_or_video = self._flatten_and_extract_image_or_video(inputs)
height, width = get_size(orig_image_or_video) height, width = get_size(orig_image_or_video) # type: ignore[arg-type]
if isinstance(orig_image_or_video, torch.Tensor): if isinstance(orig_image_or_video, torch.Tensor):
image_or_video = orig_image_or_video image_or_video = orig_image_or_video
...@@ -613,9 +615,7 @@ class AugMix(_AutoAugmentBase): ...@@ -613,9 +615,7 @@ class AugMix(_AutoAugmentBase):
else: else:
magnitude = 0.0 magnitude = 0.0
aug = self._apply_image_or_video_transform( aug = self._apply_image_or_video_transform(aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill) # type: ignore[assignment]
aug, transform_id, magnitude, interpolation=self.interpolation, fill=self._fill
)
mix.add_(combined_weights[:, i].reshape(batch_dims) * aug) mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype) mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)
......
...@@ -730,7 +730,7 @@ def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch ...@@ -730,7 +730,7 @@ def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch
@_register_kernel_internal(permute_channels, PIL.Image.Image) @_register_kernel_internal(permute_channels, PIL.Image.Image)
def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image: def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image.Image:
return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation)) return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation))
......
...@@ -113,7 +113,7 @@ def vertical_flip_image(image: torch.Tensor) -> torch.Tensor: ...@@ -113,7 +113,7 @@ def vertical_flip_image(image: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(vertical_flip, PIL.Image.Image) @_register_kernel_internal(vertical_flip, PIL.Image.Image)
def _vertical_flip_image_pil(image: PIL.Image) -> PIL.Image: def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.vflip(image) return _FP.vflip(image)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment