Unverified Commit f49edd3b authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fixed fill type in AA (#6621)



* [proto] Fixed fill type in AA

* Fixed missed typehints

* Set fill as None by default

* Another fix
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 0fcfaa13
import math
import numbers
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union
import PIL.Image
import torch
......@@ -10,7 +9,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_chw
from ._utils import _isinstance
from ._utils import _isinstance, _setup_fill_arg, FillType
K = TypeVar("K")
V = TypeVar("V")
......@@ -21,14 +20,11 @@ class _AutoAugmentBase(Transform):
self,
*,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
) -> None:
super().__init__()
self.interpolation = interpolation
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
self.fill = fill
self.fill = _setup_fill_arg(fill)
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
keys = tuple(dct.keys())
......@@ -63,19 +59,14 @@ class _AutoAugmentBase(Transform):
def _apply_image_transform(
self,
image: Any,
image: Union[torch.Tensor, PIL.Image.Image, features.Image],
transform_id: str,
magnitude: float,
interpolation: InterpolationMode,
fill: Union[int, float, Sequence[int], Sequence[float]],
fill: Union[Dict[Type, FillType], Dict[Type, None]],
) -> Any:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we have to put fill as None if fill == 0
# This is due to BC with stable API which has fill = None by default
fill_ = F._geometry._convert_fill_arg(fill)
if isinstance(fill, int) and fill == 0:
fill_ = None
fill_ = fill[type(image)]
fill_ = F._geometry._convert_fill_arg(fill_)
if transform_id == "Identity":
return image
......@@ -186,7 +177,7 @@ class AutoAugment(_AutoAugmentBase):
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
......@@ -286,7 +277,7 @@ class AutoAugment(_AutoAugmentBase):
sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)
_, height, width = get_chw(image)
policy = self._policies[int(torch.randint(len(self._policies), ()))]
......@@ -346,7 +337,7 @@ class RandAugment(_AutoAugmentBase):
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
......@@ -402,7 +393,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
self,
num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
......@@ -462,7 +453,7 @@ class AugMix(_AutoAugmentBase):
alpha: float = 1.0,
all_ops: bool = True,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
......
......@@ -11,10 +11,7 @@ from torchvision.transforms import functional as _F
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from ._utils import query_chw
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
from ._utils import DType, query_chw
class ToTensor(Transform):
......
import math
import numbers
import warnings
from collections import defaultdict
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
import PIL.Image
......@@ -14,11 +13,20 @@ from torchvision.transforms.functional import _get_perspective_coeffs
from typing_extensions import Literal
from ._transform import _RandomApplyTransform
from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, has_any, query_bounding_box, query_chw
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
FillType = Union[int, float, Sequence[int], Sequence[float]]
from ._utils import (
_check_padding_arg,
_check_padding_mode_arg,
_check_sequence_input,
_setup_angle,
_setup_fill_arg,
_setup_size,
DType,
FillType,
has_all,
has_any,
query_bounding_box,
query_chw,
)
class RandomHorizontalFlip(_RandomApplyTransform):
......@@ -201,40 +209,6 @@ class TenCrop(Transform):
return super().forward(*inputs)
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Check key for type
_check_fill_arg(value)
else:
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
_check_fill_arg(fill)
if isinstance(fill, dict):
return fill
return defaultdict(lambda: fill) # type: ignore[arg-type, return-value]
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
class Pad(Transform):
def __init__(
self,
......
from typing import Any, Callable, Tuple, Type, Union
import numbers
from collections import defaultdict
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten
from torchvision._utils import sequence_to_str
from torchvision.prototype import features
......@@ -8,6 +13,49 @@ from torchvision.prototype import features
from torchvision.prototype.transforms.functional._meta import get_chw
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from typing_extensions import Literal
# Type shortcuts:
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
FillType = Union[int, float, Sequence[int], Sequence[float]]
def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Check key for type
_check_fill_arg(value)
else:
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
def _setup_fill_arg(
fill: Optional[Union[FillType, Dict[Type, FillType]]]
) -> Union[Dict[Type, FillType], Dict[Type, None]]:
_check_fill_arg(fill)
if isinstance(fill, dict):
return fill
return defaultdict(lambda: fill) # type: ignore[return-value]
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
def query_bounding_box(sample: Any) -> features.BoundingBox:
flat_sample, _ = tree_flatten(sample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment