"tests/vscode:/vscode.git/clone" did not exist on "896ecdd6df44deaf208b529c483b5bc8a118af45"
_utils.py 6.4 KB
Newer Older
1
import collections.abc
2
import functools
vfdev's avatar
vfdev committed
3
4
import numbers
from collections import defaultdict
5
6
7
8
from contextlib import suppress
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Type, TypeVar, Union

import torch
9

10
from torchvision import datapoints
Philip Meier's avatar
Philip Meier committed
11
from torchvision.datapoints._datapoint import _FillType, _FillTypeJIT
12
from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401
13

vfdev's avatar
vfdev committed
14

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
    if not isinstance(arg, (float, Sequence)):
        raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
    if isinstance(arg, Sequence) and len(arg) != req_size:
        raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
    if isinstance(arg, Sequence):
        for element in arg:
            if not isinstance(element, float):
                raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")

    if isinstance(arg, float):
        arg = [float(arg), float(arg)]
    if isinstance(arg, (list, tuple)) and len(arg) == 1:
        arg = [arg[0], arg[0]]
    return arg


Philip Meier's avatar
Philip Meier committed
32
def _check_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> None:
    """Validate a ``fill`` argument, raising ``TypeError`` if it has an unsupported type.

    A non-dict ``fill`` must be ``None``, a number, a tuple or a list. A dict ``fill`` is validated by
    recursively checking each of its values (keys are not inspected). For a ``defaultdict`` with a
    callable ``default_factory``, the factory is invoked once and its result is checked as well.
    """
    if not isinstance(fill, dict):
        is_allowed = fill is None or isinstance(fill, (numbers.Number, tuple, list))
        if not is_allowed:
            raise TypeError("Got inappropriate fill arg, only Numbers, tuples, lists and dicts are allowed.")
        return

    for per_type_fill in fill.values():
        _check_fill_arg(per_type_fill)

    factory = fill.default_factory if isinstance(fill, defaultdict) else None
    if callable(factory):
        _check_fill_arg(factory())
vfdev's avatar
vfdev committed
43
44


45
46
47
48
49
50
51
52
53
54
55
T = TypeVar("T")


def _default_arg(value: T) -> T:
    return value


def _get_defaultdict(default: T) -> Dict[Any, T]:
    # This weird looking construct only exists, since `lambda`'s cannot be serialized by pickle.
    # If it were possible, we could replace this with `defaultdict(lambda: default)`
    return defaultdict(functools.partial(_default_arg, default))
56
57


Philip Meier's avatar
Philip Meier committed
58
def _convert_fill_arg(fill: _FillType) -> _FillTypeJIT:
    """Convert a single, already-validated fill value into its normalized form.

    - ``None`` is returned unchanged.
    - Python ``int``/``float`` scalars are returned unchanged.
    - Other real scalars (e.g. NumPy integers or ``np.float32``, which pass the ``numbers.Number`` check
      in ``_check_fill_arg``) are returned as a Python ``float``.
    - Any other value is iterated and returned as a list of floats.
    """
    # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
    # So `None` must be passed through as-is instead of being replaced by 0.
    if fill is None:
        return fill

    if isinstance(fill, (int, float)):
        return fill  # type: ignore[return-value]

    if isinstance(fill, numbers.Real):
        # Such scalars are not iterable, so the sequence path below would raise; normalize instead.
        return float(fill)  # type: ignore[return-value]

    return [float(v) for v in fill]  # type: ignore[return-value,union-attr]
69
70


Philip Meier's avatar
Philip Meier committed
71
def _setup_fill_arg(fill: Union[_FillType, Dict[Type, _FillType]]) -> Dict[Type, _FillTypeJIT]:
    """Validate ``fill`` and normalize it into a per-type mapping of converted fill values.

    Args:
        fill: Either a single fill value, or a dict mapping types to fill values
            (a ``defaultdict`` may supply a default for unlisted types).

    Returns:
        A new mapping whose values have been passed through ``_convert_fill_arg``.
        - A single fill value becomes a ``defaultdict`` that yields it for every key.
        - A ``defaultdict`` input yields a ``defaultdict``; if its factory is callable, the default is converted too.
        - Any other dict input yields a plain ``dict``.
        The caller's ``fill`` object is never modified.

    Raises:
        TypeError: Propagated from ``_check_fill_arg`` for unsupported fill types.
    """
    _check_fill_arg(fill)

    if not isinstance(fill, dict):
        return _get_defaultdict(_convert_fill_arg(fill))

    # Build a new mapping instead of writing back into ``fill``: the caller still holds a reference to
    # the dict it passed in and should not see its values (or default factory) silently replaced.
    converted = {key: _convert_fill_arg(value) for key, value in fill.items()}

    if isinstance(fill, defaultdict):
        factory = fill.default_factory
        if callable(factory):
            # Evaluate the user's factory once and expose the converted result through a picklable
            # factory (``functools.partial`` rather than a lambda).
            factory = functools.partial(_default_arg, _convert_fill_arg(factory()))
        return defaultdict(factory, converted)

    return converted
vfdev's avatar
vfdev committed
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98


def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")

    if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")


# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)
# https://github.com/pytorch/vision/issues/6250
def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155


def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
    """
    This heuristic covers three cases:

    1. The input is tuple or list whose second item is a labels tensor. This happens for already batched
       classification inputs for Mixup and Cutmix (typically after the Dataloder).
    2. The input is a tuple or list whose second item is a dictionary that contains the labels tensor
       under a label-like (see below) key. This happens for the inputs of detection models.
    3. The input is a dictionary that is structured as the one from 2.

    What is "label-like" key? We first search for an case-insensitive match of 'labels' inside the keys of the
    dictionary. This is the name our detection models expect. If we can't find that, we look for a case-insensitive
    match of the term 'label' anywhere inside the key, i.e. 'FooLaBeLBar'. If we can't find that either, the dictionary
    contains no "label-like" key.
    """

    if isinstance(inputs, (tuple, list)):
        inputs = inputs[1]

    # Mixup, Cutmix
    if isinstance(inputs, torch.Tensor):
        return inputs

    if not isinstance(inputs, collections.abc.Mapping):
        raise ValueError(
            f"When using the default labels_getter, the input passed to forward must be a dictionary or a two-tuple "
            f"whose second item is a dictionary or a tensor, but got {inputs} instead."
        )

    candidate_key = None
    with suppress(StopIteration):
        candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
    if candidate_key is None:
        with suppress(StopIteration):
            candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
    if candidate_key is None:
        raise ValueError(
            "Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
            "If there are no labels in the sample by design, pass labels_getter=None."
        )

    return inputs[candidate_key]


def _parse_labels_getter(
    labels_getter: Union[str, Callable[[Any], Optional[torch.Tensor]], None]
) -> Callable[[Any], Optional[torch.Tensor]]:
    if labels_getter == "default":
        return _find_labels_default_heuristic
    elif callable(labels_getter):
        return labels_getter
    elif labels_getter is None:
        return lambda _: None
    else:
        raise ValueError(f"labels_getter should either be 'default', a callable, or None, but got {labels_getter}.")