Unverified Commit edde8255 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Allow catch-all 'others' key in fill dicts. Avoid need for defaultdict. (#7779)

parent 312c3d32
......@@ -10,7 +10,6 @@ well as the new ``torchvision.transforms.v2`` v2 API.
"""
import pathlib
from collections import defaultdict
import PIL.Image
......@@ -99,9 +98,7 @@ show(sample)
transform = transforms.Compose(
[
transforms.RandomPhotometricDistort(),
transforms.RandomZoomOut(
fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
),
transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(),
transforms.ToImageTensor(),
......
from collections import defaultdict
import torch
......@@ -48,7 +46,7 @@ class SegmentationPresetTrain:
if use_v2:
# We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`.
transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))]
transforms += [v2_extras.PadIfSmaller(crop_size, fill={datapoints.Mask: 255, "others": 0})]
transforms += [T.RandomCrop(crop_size)]
......
......@@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform):
def __init__(self, size, fill=0):
super().__init__()
self.size = size
self.fill = v2._geometry._setup_fill_arg(fill)
self.fill = v2._utils._setup_fill_arg(fill)
def _get_params(self, sample):
_, height, width = v2.utils.query_chw(sample)
......@@ -20,7 +20,7 @@ class PadIfSmaller(v2.Transform):
if not params["needs_padding"]:
return inpt
fill = self.fill[type(inpt)]
fill = v2._utils._get_fill(self.fill, type(inpt))
fill = v2._utils._convert_fill_arg(fill)
return v2.functional.pad(inpt, padding=params["padding"], fill=fill)
......
......@@ -3,7 +3,6 @@ import pathlib
import random
import textwrap
import warnings
from collections import defaultdict
import numpy as np
......@@ -1475,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
elif data_augmentation == "ssd":
t = [
transforms.RandomPhotometricDistort(p=1),
transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1),
transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1),
to_tensor,
......
......@@ -4,7 +4,6 @@ import importlib.util
import inspect
import random
import re
from collections import defaultdict
from pathlib import Path
import numpy as np
......@@ -30,6 +29,7 @@ from torchvision._utils import sequence_to_str
from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2._utils import _get_fill
from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2.utils import query_size
......@@ -1181,7 +1181,7 @@ class PadIfSmaller(v2_transforms.Transform):
if not params["needs_padding"]:
return inpt
fill = self.fill[type(inpt)]
fill = _get_fill(self.fill, type(inpt))
return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
......@@ -1243,7 +1243,7 @@ class TestRefSegTransforms:
seg_transforms.RandomCrop(size=480),
v2_transforms.Compose(
[
PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})),
PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
v2_transforms.RandomCrop(size=480),
]
),
......
......@@ -6,7 +6,7 @@ import torch
from torchvision import datapoints
from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size
from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
......@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
......@@ -119,7 +119,7 @@ class FixedSizeCrop(Transform):
)
if params["needs_pad"]:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt
import functools
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union
from collections import defaultdict
from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
import torch
from torchvision import datapoints
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2._utils import _get_defaultdict
from torchvision.transforms.v2.utils import is_simple_tensor
T = TypeVar("T")
def _default_arg(value: T) -> T:
return value
def _get_defaultdict(default: T) -> Dict[Any, T]:
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return defaultdict(functools.partial(_default_arg, default))
class PermuteDimensions(Transform):
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
......
......@@ -11,7 +11,7 @@ from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, Interp
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size
from ._utils import _setup_fill_arg
from ._utils import _get_fill, _setup_fill_arg
from .utils import check_type, is_simple_tensor
......@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform):
self,
*,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__()
self.interpolation = _check_interpolation(interpolation)
......@@ -80,9 +80,9 @@ class _AutoAugmentBase(Transform):
transform_id: str,
magnitude: float,
interpolation: Union[InterpolationMode, int],
fill: Dict[Type, datapoints._FillTypeJIT],
fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
) -> Union[datapoints._ImageType, datapoints._VideoType]:
fill_ = fill[type(image)]
fill_ = _get_fill(fill, type(image))
if transform_id == "Identity":
return image
......@@ -214,7 +214,7 @@ class AutoAugment(_AutoAugmentBase):
self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy
......@@ -394,7 +394,7 @@ class RandAugment(_AutoAugmentBase):
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops
......@@ -467,7 +467,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
self,
num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
):
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins
......@@ -550,7 +550,7 @@ class AugMix(_AutoAugmentBase):
alpha: float = 1.0,
all_ops: bool = True,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10
......
......@@ -17,6 +17,7 @@ from ._utils import (
_check_padding_arg,
_check_padding_mode_arg,
_check_sequence_input,
_get_fill,
_setup_angle,
_setup_fill_arg,
_setup_float_or_seq,
......@@ -487,7 +488,7 @@ class Pad(Transform):
def __init__(
self,
padding: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
......@@ -504,7 +505,7 @@ class Pad(Transform):
self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
......@@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):
def __init__(
self,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5,
) -> None:
......@@ -574,7 +575,7 @@ class RandomZoomOut(_RandomApplyTransform):
return dict(padding=padding)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.pad(inpt, **params, fill=fill)
......@@ -620,7 +621,7 @@ class RandomRotation(Transform):
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
......@@ -640,7 +641,7 @@ class RandomRotation(Transform):
return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.rotate(
inpt,
**params,
......@@ -702,7 +703,7 @@ class RandomAffine(Transform):
scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
center: Optional[List[float]] = None,
) -> None:
super().__init__()
......@@ -762,7 +763,7 @@ class RandomAffine(Transform):
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.affine(
inpt,
**params,
......@@ -840,7 +841,7 @@ class RandomCrop(Transform):
size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None:
super().__init__()
......@@ -918,7 +919,7 @@ class RandomCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]:
......@@ -959,7 +960,7 @@ class RandomPerspective(_RandomApplyTransform):
distortion_scale: float = 0.5,
p: float = 0.5,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__(p=p)
......@@ -1002,7 +1003,7 @@ class RandomPerspective(_RandomApplyTransform):
return dict(coefficients=perspective_coeffs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.perspective(
inpt,
None,
......@@ -1061,7 +1062,7 @@ class ElasticTransform(Transform):
alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0,
fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
......@@ -1095,7 +1096,7 @@ class ElasticTransform(Transform):
return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)]
fill = _get_fill(self._fill, type(inpt))
return F.elastic(
inpt,
**params,
......
import collections.abc
import functools
import numbers
from collections import defaultdict
from contextlib import suppress
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, TypeVar, Union
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
import torch
......@@ -29,32 +27,15 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
return arg
def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> None:
if isinstance(fill, dict):
for key, value in fill.items():
# Check key for type
for value in fill.values():
_check_fill_arg(value)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
_check_fill_arg(default_value)
else:
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
T = TypeVar("T")
def _default_arg(value: T) -> T:
return value
def _get_defaultdict(default: T) -> Dict[Any, T]:
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return defaultdict(functools.partial(_default_arg, default))
def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0
......@@ -68,19 +49,24 @@ def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
return fill # type: ignore[return-value]
def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
def _setup_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> Dict[Union[Type, str], _FillTypeJIT]:
_check_fill_arg(fill)
if isinstance(fill, dict):
for k, v in fill.items():
fill[k] = _convert_fill_arg(v)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
sanitized_default = _convert_fill_arg(default_value)
fill.default_factory = functools.partial(_default_arg, sanitized_default)
return fill # type: ignore[return-value]
else:
return {"others": _convert_fill_arg(fill)}
return _get_defaultdict(_convert_fill_arg(fill))
def _get_fill(fill_dict, inpt_type):
if inpt_type in fill_dict:
return fill_dict[inpt_type]
elif "others" in fill_dict:
return fill_dict["others"]
else:
RuntimeError("This should never happen, please open an issue on the torchvision repo if you hit this.")
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment