"src/vscode:/vscode.git/clone" did not exist on "bdecc3cffd054c9c30520ef8551215867a0940ee"
Unverified Commit f49edd3b authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Fixed fill type in AA (#6621)



* [proto] Fixed fill type in AA

* Fixed missed typehints

* Set fill as None by default

* Another fix
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 0fcfaa13
import math import math
import numbers from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -10,7 +9,7 @@ from torchvision.prototype import features ...@@ -10,7 +9,7 @@ from torchvision.prototype import features
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_chw from torchvision.prototype.transforms.functional._meta import get_chw
from ._utils import _isinstance from ._utils import _isinstance, _setup_fill_arg, FillType
K = TypeVar("K") K = TypeVar("K")
V = TypeVar("V") V = TypeVar("V")
...@@ -21,14 +20,11 @@ class _AutoAugmentBase(Transform): ...@@ -21,14 +20,11 @@ class _AutoAugmentBase(Transform):
self, self,
*, *,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.interpolation = interpolation self.interpolation = interpolation
self.fill = _setup_fill_arg(fill)
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
self.fill = fill
def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
keys = tuple(dct.keys()) keys = tuple(dct.keys())
...@@ -63,19 +59,14 @@ class _AutoAugmentBase(Transform): ...@@ -63,19 +59,14 @@ class _AutoAugmentBase(Transform):
def _apply_image_transform( def _apply_image_transform(
self, self,
image: Any, image: Union[torch.Tensor, PIL.Image.Image, features.Image],
transform_id: str, transform_id: str,
magnitude: float, magnitude: float,
interpolation: InterpolationMode, interpolation: InterpolationMode,
fill: Union[int, float, Sequence[int], Sequence[float]], fill: Union[Dict[Type, FillType], Dict[Type, None]],
) -> Any: ) -> Any:
fill_ = fill[type(image)]
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 fill_ = F._geometry._convert_fill_arg(fill_)
# So, we have to put fill as None if fill == 0
# This is due to BC with stable API which has fill = None by default
fill_ = F._geometry._convert_fill_arg(fill)
if isinstance(fill, int) and fill == 0:
fill_ = None
if transform_id == "Identity": if transform_id == "Identity":
return image return image
...@@ -186,7 +177,7 @@ class AutoAugment(_AutoAugmentBase): ...@@ -186,7 +177,7 @@ class AutoAugment(_AutoAugmentBase):
self, self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy self.policy = policy
...@@ -286,7 +277,7 @@ class AutoAugment(_AutoAugmentBase): ...@@ -286,7 +277,7 @@ class AutoAugment(_AutoAugmentBase):
sample = inputs if len(inputs) > 1 else inputs[0] sample = inputs if len(inputs) > 1 else inputs[0]
id, image = self._extract_image(sample) id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image) _, height, width = get_chw(image)
policy = self._policies[int(torch.randint(len(self._policies), ()))] policy = self._policies[int(torch.randint(len(self._policies), ()))]
...@@ -346,7 +337,7 @@ class RandAugment(_AutoAugmentBase): ...@@ -346,7 +337,7 @@ class RandAugment(_AutoAugmentBase):
magnitude: int = 9, magnitude: int = 9,
num_magnitude_bins: int = 31, num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops self.num_ops = num_ops
...@@ -402,7 +393,7 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -402,7 +393,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
self, self,
num_magnitude_bins: int = 31, num_magnitude_bins: int = 31,
interpolation: InterpolationMode = InterpolationMode.NEAREST, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
): ):
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins self.num_magnitude_bins = num_magnitude_bins
...@@ -462,7 +453,7 @@ class AugMix(_AutoAugmentBase): ...@@ -462,7 +453,7 @@ class AugMix(_AutoAugmentBase):
alpha: float = 1.0, alpha: float = 1.0,
all_ops: bool = True, all_ops: bool = True,
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Union[int, float, Sequence[int], Sequence[float]] = 0, fill: Optional[Union[FillType, Dict[Type, FillType]]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10 self._PARAMETER_MAX = 10
......
...@@ -11,10 +11,7 @@ from torchvision.transforms import functional as _F ...@@ -11,10 +11,7 @@ from torchvision.transforms import functional as _F
from typing_extensions import Literal from typing_extensions import Literal
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from ._utils import query_chw from ._utils import DType, query_chw
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
class ToTensor(Transform): class ToTensor(Transform):
......
import math import math
import numbers import numbers
import warnings import warnings
from collections import defaultdict
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
import PIL.Image import PIL.Image
...@@ -14,11 +13,20 @@ from torchvision.transforms.functional import _get_perspective_coeffs ...@@ -14,11 +13,20 @@ from torchvision.transforms.functional import _get_perspective_coeffs
from typing_extensions import Literal from typing_extensions import Literal
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from ._utils import _check_sequence_input, _setup_angle, _setup_size, has_all, has_any, query_bounding_box, query_chw from ._utils import (
_check_padding_arg,
_check_padding_mode_arg,
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature] _check_sequence_input,
FillType = Union[int, float, Sequence[int], Sequence[float]] _setup_angle,
_setup_fill_arg,
_setup_size,
DType,
FillType,
has_all,
has_any,
query_bounding_box,
query_chw,
)
class RandomHorizontalFlip(_RandomApplyTransform): class RandomHorizontalFlip(_RandomApplyTransform):
...@@ -201,40 +209,6 @@ class TenCrop(Transform): ...@@ -201,40 +209,6 @@ class TenCrop(Transform):
return super().forward(*inputs) return super().forward(*inputs)
def _check_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Check key for type
_check_fill_arg(value)
else:
if not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
def _setup_fill_arg(fill: Union[FillType, Dict[Type, FillType]]) -> Dict[Type, FillType]:
    """Normalize ``fill`` into a per-type mapping after validating it.

    A dict is passed through unchanged; a scalar/sequence fill is wrapped in a
    ``defaultdict`` so that a lookup with *any* type yields that same value.
    """
    _check_fill_arg(fill)
    if isinstance(fill, dict):
        return fill
    # Bind to a fresh local so the closure unambiguously captures the value.
    shared = fill
    return defaultdict(lambda: shared)  # type: ignore[arg-type, return-value]
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
class Pad(Transform): class Pad(Transform):
def __init__( def __init__(
self, self,
......
from typing import Any, Callable, Tuple, Type, Union import numbers
from collections import defaultdict
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
import PIL.Image import PIL.Image
import torch
from torch.utils._pytree import tree_flatten from torch.utils._pytree import tree_flatten
from torchvision._utils import sequence_to_str from torchvision._utils import sequence_to_str
from torchvision.prototype import features from torchvision.prototype import features
...@@ -8,6 +13,49 @@ from torchvision.prototype import features ...@@ -8,6 +13,49 @@ from torchvision.prototype import features
from torchvision.prototype.transforms.functional._meta import get_chw from torchvision.prototype.transforms.functional._meta import get_chw
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from typing_extensions import Literal
# Type shortcuts:
# DType: any image-carrying input these transforms operate on — a raw tensor,
# a PIL image, or a prototype feature wrapper.
DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
# FillType: a fill value for padded/uncovered pixels — a single number or a
# sequence of numbers (presumably one per channel — confirm against callers).
FillType = Union[int, float, Sequence[int], Sequence[float]]
def _check_fill_arg(fill: Optional[Union[FillType, Dict[Type, FillType]]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Check key for type
_check_fill_arg(value)
else:
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg")
def _setup_fill_arg(
    fill: Optional[Union[FillType, Dict[Type, FillType]]]
) -> Union[Dict[Type, FillType], Dict[Type, None]]:
    """Normalize ``fill`` into a per-type mapping after validating it.

    A dict is passed through unchanged; anything else (including ``None``) is
    wrapped in a ``defaultdict`` so a lookup with *any* type yields that value.
    """
    _check_fill_arg(fill)
    if isinstance(fill, dict):
        return fill
    # Bind to a fresh local so the closure unambiguously captures the value.
    shared = fill
    return defaultdict(lambda: shared)  # type: ignore[return-value]
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
def query_bounding_box(sample: Any) -> features.BoundingBox: def query_bounding_box(sample: Any) -> features.BoundingBox:
flat_sample, _ = tree_flatten(sample) flat_sample, _ = tree_flatten(sample)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment