Unverified Commit edde8255 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Allow catch-all 'others' key in fill dicts. Avoid need for defaultdict. (#7779)

parent 312c3d32
...@@ -10,7 +10,6 @@ well as the new ``torchvision.transforms.v2`` v2 API. ...@@ -10,7 +10,6 @@ well as the new ``torchvision.transforms.v2`` v2 API.
""" """
import pathlib import pathlib
from collections import defaultdict
import PIL.Image import PIL.Image
...@@ -99,9 +98,7 @@ show(sample) ...@@ -99,9 +98,7 @@ show(sample)
transform = transforms.Compose( transform = transforms.Compose(
[ [
transforms.RandomPhotometricDistort(), transforms.RandomPhotometricDistort(),
transforms.RandomZoomOut( transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}),
fill=defaultdict(lambda: 0, {PIL.Image.Image: (123, 117, 104)})
),
transforms.RandomIoUCrop(), transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.ToImageTensor(), transforms.ToImageTensor(),
......
from collections import defaultdict
import torch import torch
...@@ -48,7 +46,7 @@ class SegmentationPresetTrain: ...@@ -48,7 +46,7 @@ class SegmentationPresetTrain:
if use_v2: if use_v2:
# We need a custom pad transform here, since the padding we want to perform here is fundamentally # We need a custom pad transform here, since the padding we want to perform here is fundamentally
# different from the padding in `RandomCrop` if `pad_if_needed=True`. # different from the padding in `RandomCrop` if `pad_if_needed=True`.
transforms += [v2_extras.PadIfSmaller(crop_size, fill=defaultdict(lambda: 0, {datapoints.Mask: 255}))] transforms += [v2_extras.PadIfSmaller(crop_size, fill={datapoints.Mask: 255, "others": 0})]
transforms += [T.RandomCrop(crop_size)] transforms += [T.RandomCrop(crop_size)]
......
...@@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform): ...@@ -8,7 +8,7 @@ class PadIfSmaller(v2.Transform):
def __init__(self, size, fill=0): def __init__(self, size, fill=0):
super().__init__() super().__init__()
self.size = size self.size = size
self.fill = v2._geometry._setup_fill_arg(fill) self.fill = v2._utils._setup_fill_arg(fill)
def _get_params(self, sample): def _get_params(self, sample):
_, height, width = v2.utils.query_chw(sample) _, height, width = v2.utils.query_chw(sample)
...@@ -20,7 +20,7 @@ class PadIfSmaller(v2.Transform): ...@@ -20,7 +20,7 @@ class PadIfSmaller(v2.Transform):
if not params["needs_padding"]: if not params["needs_padding"]:
return inpt return inpt
fill = self.fill[type(inpt)] fill = v2._utils._get_fill(self.fill, type(inpt))
fill = v2._utils._convert_fill_arg(fill) fill = v2._utils._convert_fill_arg(fill)
return v2.functional.pad(inpt, padding=params["padding"], fill=fill) return v2.functional.pad(inpt, padding=params["padding"], fill=fill)
......
...@@ -3,7 +3,6 @@ import pathlib ...@@ -3,7 +3,6 @@ import pathlib
import random import random
import textwrap import textwrap
import warnings import warnings
from collections import defaultdict
import numpy as np import numpy as np
...@@ -1475,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): ...@@ -1475,7 +1474,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
elif data_augmentation == "ssd": elif data_augmentation == "ssd":
t = [ t = [
transforms.RandomPhotometricDistort(p=1), transforms.RandomPhotometricDistort(p=1),
transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0}), p=1), transforms.RandomZoomOut(fill={"others": (123.0, 117.0, 104.0), datapoints.Mask: 0}, p=1),
transforms.RandomIoUCrop(), transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1), transforms.RandomHorizontalFlip(p=1),
to_tensor, to_tensor,
......
...@@ -4,7 +4,6 @@ import importlib.util ...@@ -4,7 +4,6 @@ import importlib.util
import inspect import inspect
import random import random
import re import re
from collections import defaultdict
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
...@@ -30,6 +29,7 @@ from torchvision._utils import sequence_to_str ...@@ -30,6 +29,7 @@ from torchvision._utils import sequence_to_str
from torchvision.transforms import functional as legacy_F from torchvision.transforms import functional as legacy_F
from torchvision.transforms.v2 import functional as prototype_F from torchvision.transforms.v2 import functional as prototype_F
from torchvision.transforms.v2._utils import _get_fill
from torchvision.transforms.v2.functional import to_image_pil from torchvision.transforms.v2.functional import to_image_pil
from torchvision.transforms.v2.utils import query_size from torchvision.transforms.v2.utils import query_size
...@@ -1181,7 +1181,7 @@ class PadIfSmaller(v2_transforms.Transform): ...@@ -1181,7 +1181,7 @@ class PadIfSmaller(v2_transforms.Transform):
if not params["needs_padding"]: if not params["needs_padding"]:
return inpt return inpt
fill = self.fill[type(inpt)] fill = _get_fill(self.fill, type(inpt))
return prototype_F.pad(inpt, padding=params["padding"], fill=fill) return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
...@@ -1243,7 +1243,7 @@ class TestRefSegTransforms: ...@@ -1243,7 +1243,7 @@ class TestRefSegTransforms:
seg_transforms.RandomCrop(size=480), seg_transforms.RandomCrop(size=480),
v2_transforms.Compose( v2_transforms.Compose(
[ [
PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {datapoints.Mask: 255})), PadIfSmaller(size=480, fill={datapoints.Mask: 255, "others": 0}),
v2_transforms.RandomCrop(size=480), v2_transforms.RandomCrop(size=480),
] ]
), ),
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from torchvision import datapoints from torchvision import datapoints
from torchvision.prototype.datapoints import Label, OneHotLabel from torchvision.prototype.datapoints import Label, OneHotLabel
from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2._utils import _setup_fill_arg, _setup_size from torchvision.transforms.v2._utils import _get_fill, _setup_fill_arg, _setup_size
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_bounding_boxes, query_size
...@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform): ...@@ -14,7 +14,7 @@ class FixedSizeCrop(Transform):
def __init__( def __init__(
self, self,
size: Union[int, Sequence[int]], size: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: str = "constant", padding_mode: str = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -119,7 +119,7 @@ class FixedSizeCrop(Transform): ...@@ -119,7 +119,7 @@ class FixedSizeCrop(Transform):
) )
if params["needs_pad"]: if params["needs_pad"]:
fill = self._fill[type(inpt)] fill = _get_fill(self._fill, type(inpt))
inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode)
return inpt return inpt
import functools
import warnings import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union from collections import defaultdict
from typing import Any, Dict, Optional, Sequence, Tuple, Type, TypeVar, Union
import torch import torch
from torchvision import datapoints from torchvision import datapoints
from torchvision.transforms.v2 import Transform from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2._utils import _get_defaultdict
from torchvision.transforms.v2.utils import is_simple_tensor from torchvision.transforms.v2.utils import is_simple_tensor
T = TypeVar("T")
def _default_arg(value: T) -> T:
return value
def _get_defaultdict(default: T) -> Dict[Any, T]:
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return defaultdict(functools.partial(_default_arg, default))
class PermuteDimensions(Transform): class PermuteDimensions(Transform):
_transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video)
......
...@@ -11,7 +11,7 @@ from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, Interp ...@@ -11,7 +11,7 @@ from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, Interp
from torchvision.transforms.v2.functional._geometry import _check_interpolation from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.functional._meta import get_size from torchvision.transforms.v2.functional._meta import get_size
from ._utils import _setup_fill_arg from ._utils import _get_fill, _setup_fill_arg
from .utils import check_type, is_simple_tensor from .utils import check_type, is_simple_tensor
...@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform): ...@@ -20,7 +20,7 @@ class _AutoAugmentBase(Transform):
self, self,
*, *,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.interpolation = _check_interpolation(interpolation) self.interpolation = _check_interpolation(interpolation)
...@@ -80,9 +80,9 @@ class _AutoAugmentBase(Transform): ...@@ -80,9 +80,9 @@ class _AutoAugmentBase(Transform):
transform_id: str, transform_id: str,
magnitude: float, magnitude: float,
interpolation: Union[InterpolationMode, int], interpolation: Union[InterpolationMode, int],
fill: Dict[Type, datapoints._FillTypeJIT], fill: Dict[Union[Type, str], datapoints._FillTypeJIT],
) -> Union[datapoints._ImageType, datapoints._VideoType]: ) -> Union[datapoints._ImageType, datapoints._VideoType]:
fill_ = fill[type(image)] fill_ = _get_fill(fill, type(image))
if transform_id == "Identity": if transform_id == "Identity":
return image return image
...@@ -214,7 +214,7 @@ class AutoAugment(_AutoAugmentBase): ...@@ -214,7 +214,7 @@ class AutoAugment(_AutoAugmentBase):
self, self,
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self.policy = policy self.policy = policy
...@@ -394,7 +394,7 @@ class RandAugment(_AutoAugmentBase): ...@@ -394,7 +394,7 @@ class RandAugment(_AutoAugmentBase):
magnitude: int = 9, magnitude: int = 9,
num_magnitude_bins: int = 31, num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self.num_ops = num_ops self.num_ops = num_ops
...@@ -467,7 +467,7 @@ class TrivialAugmentWide(_AutoAugmentBase): ...@@ -467,7 +467,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
self, self,
num_magnitude_bins: int = 31, num_magnitude_bins: int = 31,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
): ):
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins self.num_magnitude_bins = num_magnitude_bins
...@@ -550,7 +550,7 @@ class AugMix(_AutoAugmentBase): ...@@ -550,7 +550,7 @@ class AugMix(_AutoAugmentBase):
alpha: float = 1.0, alpha: float = 1.0,
all_ops: bool = True, all_ops: bool = True,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = None,
) -> None: ) -> None:
super().__init__(interpolation=interpolation, fill=fill) super().__init__(interpolation=interpolation, fill=fill)
self._PARAMETER_MAX = 10 self._PARAMETER_MAX = 10
......
...@@ -17,6 +17,7 @@ from ._utils import ( ...@@ -17,6 +17,7 @@ from ._utils import (
_check_padding_arg, _check_padding_arg,
_check_padding_mode_arg, _check_padding_mode_arg,
_check_sequence_input, _check_sequence_input,
_get_fill,
_setup_angle, _setup_angle,
_setup_fill_arg, _setup_fill_arg,
_setup_float_or_seq, _setup_float_or_seq,
...@@ -487,7 +488,7 @@ class Pad(Transform): ...@@ -487,7 +488,7 @@ class Pad(Transform):
def __init__( def __init__(
self, self,
padding: Union[int, Sequence[int]], padding: Union[int, Sequence[int]],
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -504,7 +505,7 @@ class Pad(Transform): ...@@ -504,7 +505,7 @@ class Pad(Transform):
self.padding_mode = padding_mode self.padding_mode = padding_mode
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)] fill = _get_fill(self._fill, type(inpt))
return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
...@@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -542,7 +543,7 @@ class RandomZoomOut(_RandomApplyTransform):
def __init__( def __init__(
self, self,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
side_range: Sequence[float] = (1.0, 4.0), side_range: Sequence[float] = (1.0, 4.0),
p: float = 0.5, p: float = 0.5,
) -> None: ) -> None:
...@@ -574,7 +575,7 @@ class RandomZoomOut(_RandomApplyTransform): ...@@ -574,7 +575,7 @@ class RandomZoomOut(_RandomApplyTransform):
return dict(padding=padding) return dict(padding=padding)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)] fill = _get_fill(self._fill, type(inpt))
return F.pad(inpt, **params, fill=fill) return F.pad(inpt, **params, fill=fill)
...@@ -620,7 +621,7 @@ class RandomRotation(Transform): ...@@ -620,7 +621,7 @@ class RandomRotation(Transform):
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None: ) -> None:
super().__init__() super().__init__()
self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,))
...@@ -640,7 +641,7 @@ class RandomRotation(Transform): ...@@ -640,7 +641,7 @@ class RandomRotation(Transform):
return dict(angle=angle) return dict(angle=angle)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)] fill = _get_fill(self._fill, type(inpt))
return F.rotate( return F.rotate(
inpt, inpt,
**params, **params,
...@@ -702,7 +703,7 @@ class RandomAffine(Transform): ...@@ -702,7 +703,7 @@ class RandomAffine(Transform):
scale: Optional[Sequence[float]] = None, scale: Optional[Sequence[float]] = None,
shear: Optional[Union[int, float, Sequence[float]]] = None, shear: Optional[Union[int, float, Sequence[float]]] = None,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -762,7 +763,7 @@ class RandomAffine(Transform): ...@@ -762,7 +763,7 @@ class RandomAffine(Transform):
return dict(angle=angle, translate=translate, scale=scale, shear=shear) return dict(angle=angle, translate=translate, scale=scale, shear=shear)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)] fill = _get_fill(self._fill, type(inpt))
return F.affine( return F.affine(
inpt, inpt,
**params, **params,
...@@ -840,7 +841,7 @@ class RandomCrop(Transform): ...@@ -840,7 +841,7 @@ class RandomCrop(Transform):
size: Union[int, Sequence[int]], size: Union[int, Sequence[int]],
padding: Optional[Union[int, Sequence[int]]] = None, padding: Optional[Union[int, Sequence[int]]] = None,
pad_if_needed: bool = False, pad_if_needed: bool = False,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -918,7 +919,7 @@ class RandomCrop(Transform): ...@@ -918,7 +919,7 @@ class RandomCrop(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_pad"]: if params["needs_pad"]:
fill = self._fill[type(inpt)] fill = _get_fill(self._fill, type(inpt))
inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
if params["needs_crop"]: if params["needs_crop"]:
...@@ -959,7 +960,7 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -959,7 +960,7 @@ class RandomPerspective(_RandomApplyTransform):
distortion_scale: float = 0.5, distortion_scale: float = 0.5,
p: float = 0.5, p: float = 0.5,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None: ) -> None:
super().__init__(p=p) super().__init__(p=p)
...@@ -1002,7 +1003,7 @@ class RandomPerspective(_RandomApplyTransform): ...@@ -1002,7 +1003,7 @@ class RandomPerspective(_RandomApplyTransform):
return dict(coefficients=perspective_coeffs) return dict(coefficients=perspective_coeffs)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)] fill = _get_fill(self._fill, type(inpt))
return F.perspective( return F.perspective(
inpt, inpt,
None, None,
...@@ -1061,7 +1062,7 @@ class ElasticTransform(Transform): ...@@ -1061,7 +1062,7 @@ class ElasticTransform(Transform):
alpha: Union[float, Sequence[float]] = 50.0, alpha: Union[float, Sequence[float]] = 50.0,
sigma: Union[float, Sequence[float]] = 5.0, sigma: Union[float, Sequence[float]] = 5.0,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = 0, fill: Union[datapoints._FillType, Dict[Union[Type, str], datapoints._FillType]] = 0,
) -> None: ) -> None:
super().__init__() super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2) self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
...@@ -1095,7 +1096,7 @@ class ElasticTransform(Transform): ...@@ -1095,7 +1096,7 @@ class ElasticTransform(Transform):
return dict(displacement=displacement) return dict(displacement=displacement)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
fill = self._fill[type(inpt)] fill = _get_fill(self._fill, type(inpt))
return F.elastic( return F.elastic(
inpt, inpt,
**params, **params,
......
import collections.abc import collections.abc
import functools
import numbers import numbers
from collections import defaultdict
from contextlib import suppress from contextlib import suppress
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, TypeVar, Union from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, Union
import torch import torch
...@@ -29,32 +27,15 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: ...@@ -29,32 +27,15 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
return arg return arg
def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None: def _check_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> None:
if isinstance(fill, dict): if isinstance(fill, dict):
for key, value in fill.items(): for value in fill.values():
# Check key for type
_check_fill_arg(value) _check_fill_arg(value)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
_check_fill_arg(default_value)
else: else:
if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)): if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.") raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
T = TypeVar("T")
def _default_arg(value: T) -> T:
return value
def _get_defaultdict(default: T) -> Dict[Any, T]:
# This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
# If it were possible, we could replace this with `defaultdict(lambda: default)`
return defaultdict(functools.partial(_default_arg, default))
def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT: def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517 # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
# So, we can't reassign fill to 0 # So, we can't reassign fill to 0
...@@ -68,19 +49,24 @@ def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT: ...@@ -68,19 +49,24 @@ def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
return fill # type: ignore[return-value] return fill # type: ignore[return-value]
def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]: def _setup_fill_arg(fill: Union[_FillType, Dict[Union[Type, str], _FillType]]) -> Dict[Union[Type, str], _FillTypeJIT]:
_check_fill_arg(fill) _check_fill_arg(fill)
if isinstance(fill, dict): if isinstance(fill, dict):
for k, v in fill.items(): for k, v in fill.items():
fill[k] = _convert_fill_arg(v) fill[k] = _convert_fill_arg(v)
if isinstance(fill, defaultdict) and callable(fill.default_factory):
default_value = fill.default_factory()
sanitized_default = _convert_fill_arg(default_value)
fill.default_factory = functools.partial(_default_arg, sanitized_default)
return fill # type: ignore[return-value] return fill # type: ignore[return-value]
else:
return {"others": _convert_fill_arg(fill)}
return _get_defaultdict(_convert_fill_arg(fill))
def _get_fill(fill_dict, inpt_type):
if inpt_type in fill_dict:
return fill_dict[inpt_type]
elif "others" in fill_dict:
return fill_dict["others"]
else:
RuntimeError("This should never happen, please open an issue on the torchvision repo if you hit this.")
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment