import functools
import numbers
from collections import defaultdict
from typing import Any, Dict, Literal, Sequence, Type, TypeVar, Union

from torchvision import datapoints
from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401


def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
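    # Validate a float or a sequence of `req_size` floats; a scalar (or a singleton
    # sequence) is expanded by repeating its value.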
    if not isinstance(arg, (float, Sequence)):
        raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
    if isinstance(arg, Sequence) and len(arg) != req_size:
        raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
    if isinstance(arg, Sequence):
        for element in arg:
            if not isinstance(element, float):
                raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")

    if isinstance(arg, float):
        arg = [float(arg), float(arg)]
    if isinstance(arg, (list, tuple)) and len(arg) == 1:
        arg = [arg[0], arg[0]]
    return arg


def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
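    # `fill` is either a single fill value or a mapping from datapoint type to fill value.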
    if isinstance(fill, dict):
        for value in fill.values():
            # Recursively validate each per-type fill value
            _check_fill_arg(value)
        if isinstance(fill, defaultdict) and callable(fill.default_factory):
            default_value = fill.default_factory()
            _check_fill_arg(default_value)
    else:
        if fill is not None and not isinstance(fill, (numbers.Number, tuple, list)):
            raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")


T = TypeVar("T")


def _default_arg(value: T) -> T:
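    # Picklable replacement for `lambda: value`, used as the `defaultdict` factory below.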
    return value


def _get_defaultdict(default: T) -> Dict[Any, T]:
    # This weird-looking construct only exists because `lambda`s cannot be serialized by
    # pickle. If they could, we would simply use `defaultdict(lambda: default)` instead.
    return defaultdict(functools.partial(_default_arg, default))


def _convert_fill_arg(fill: datapoints._FillType) -> datapoints._FillTypeJIT:
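    # Convert a fill value into its JIT-scriptable form: None and plain numbers pass
    # through, while sequences are converted to lists of floats.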
    # fill=0 is not equivalent to None (see https://github.com/pytorch/vision/issues/6517),
    # so we cannot simply reassign fill to 0 here:
    # if fill is None:
    #     fill = 0
    if fill is None:
        return fill

    if not isinstance(fill, (int, float)):
        fill = [float(v) for v in list(fill)]
    return fill  # type: ignore[return-value]


def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
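    # Normalize `fill` into a mapping from datapoint type to a JIT-compatible fill value.
    # A non-dict fill is wrapped in a defaultdict so that every type resolves to the same
    # converted value, e.g. `_setup_fill_arg(0)` maps any type to 0, while a dict such as
    # `{datapoints.Image: (1.0, 2.0, 3.0)}` has its values converted in place.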
    _check_fill_arg(fill)

    if isinstance(fill, dict):
        for k, v in fill.items():
            fill[k] = _convert_fill_arg(v)
        if isinstance(fill, defaultdict) and callable(fill.default_factory):
            default_value = fill.default_factory()
            sanitized_default = _convert_fill_arg(default_value)
            fill.default_factory = functools.partial(_default_arg, sanitized_default)
        return fill  # type: ignore[return-value]

    return _get_defaultdict(_convert_fill_arg(fill))


def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
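    # Padding is accepted either as a single number or as a sequence of 1, 2, or 4 values.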
    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")

    if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")


# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")