from __future__ import annotations

import enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
from torchvision.utils import _log_api_usage_once


class Transform(nn.Module):

    # Class attribute defining the types that are transformed. Any other type is passed through without any
    # transformation. We support both types and callables that can perform further checks on the type of the input.
    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
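    # A minimal sketch of narrowing this in a subclass (hypothetical example): a transform that should only
    # act on bounding boxes could set
    #
    #   _transformed_types = (datapoints.BoundingBox,)
    #
    # so that images, videos, and plain tensors are passed through untouched.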

    def __init__(self) -> None:
        super().__init__()
        _log_api_usage_once(self)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        # Hook for subclasses to validate the flattened inputs before any transformation happens.
        pass

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # Hook for subclasses to sample random parameters once per call. The returned dict is passed to
        # `_transform` for every input that needs to be transformed.
        return dict()

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Subclasses implement the actual per-input transformation here.
        raise NotImplementedError

    def forward(self, *inputs: Any) -> Any:
        # Accept either multiple positional inputs or a single, possibly nested, input and flatten it with
        # `tree_flatten`.
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

        self._check_inputs(flat_inputs)

        needs_transform_list = self._needs_transform_list(flat_inputs)
        # Sample the parameters only from the inputs that will actually be transformed.
        params = self._get_params(
            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
        )

        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)
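
    # A minimal usage sketch (`MyRotate` stands for a hypothetical subclass that implements `_transform`):
    #
    #   sample = {"image": datapoints.Image(torch.rand(3, 32, 32)), "label": 3}
    #   out = MyRotate()(sample)
    #
    # `tree_flatten` takes the dict apart into `[image, 3]`, the image is transformed while the label is
    # passed through, and `tree_unflatten` restores the original dict structure.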

    def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
        # Below is a heuristic on how to deal with simple tensor inputs:
        # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
        #    (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
        # 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is
        #    transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
        #    of `tree_flatten`, which recurses depth-first through the input.
        #
        # This heuristic stems from two requirements:
        # 1. We need to keep BC for single input simple tensors and treat them as images.
        # 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `WIDERFace`
        #    return supplemental numerical data as tensors that cannot be transformed as images.
        #
        # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
        # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
        # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
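        #
        # For example (hypothetical inputs): for the sample `(datapoints.Image(...), torch.rand(3))` the plain
        # tensor is passed through because an explicit image is present, while for
        # `(torch.rand(3, 32, 32), torch.rand(3))` only the first tensor is transformed as an image.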

        needs_transform_list = []
        transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
        for inpt in flat_inputs:
            needs_transform = True

            if not check_type(inpt, self._transformed_types):
                needs_transform = False
            elif is_simple_tensor(inpt):
                if transform_simple_tensor:
                    transform_simple_tensor = False
                else:
                    needs_transform = False
            needs_transform_list.append(needs_transform)
        return needs_transform_list

    def extra_repr(self) -> str:
        extra = []
        for name, value in self.__dict__.items():
            if name.startswith("_") or name == "training":
                continue

            if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)):
                continue

            extra.append(f"{name}={value}")

        return ", ".join(extra)

    # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things:
    # 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on
    #    the v2 transform. See `__init_subclass__` for details.
    # 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__`
    #    for details.
    _v1_transform_cls: Optional[Type[nn.Module]] = None
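
    # A minimal sketch of how a subclass opts in (mirroring what e.g. the v2 `Resize` does):
    #
    #   from torchvision import transforms as _transforms
    #
    #   class Resize(Transform):
    #       _v1_transform_cls = _transforms.Resize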

    def __init_subclass__(cls) -> None:
        # Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
        # This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
        if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
            cls.get_params = staticmethod(cls._v1_transform_cls.get_params)  # type: ignore[attr-defined]

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
        # v2 transform instance. It does two things:
        # 1. Extract all available public attributes that are specific to that transform and not `nn.Module` in general
        # 2. If available, handle the `fill` attribute for v1 compatibility (see below for details)
        # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might
        # happen if the v2 transform introduced new parameters that are not supported by the v1 transform.
        common_attrs = nn.Module().__dict__.keys()
        params = {
            attr: value
            for attr, value in self.__dict__.items()
            if not attr.startswith("_") and attr not in common_attrs
        }

        # Transforms v2 handle the `fill` parameter in a more complex way than v1. By default, the input is parsed
        # with `prototype.transforms._utils._setup_fill_arg()`, which returns a defaultdict that holds the fill value
        # for the different datapoint types. Below we extract the value for tensors and return that together with the
        # other params.
        # This is needed for `Pad`, `ElasticTransform`, `RandomAffine`, `RandomCrop`, `RandomPerspective` and
        # `RandomRotation`
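        #
        # For example (hypothetical value): a parsed `fill` defaultdict that maps every datapoint type to `0`
        # becomes `params["fill"] = 0`, i.e. the scalar form that the v1 transforms expect.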
        if "fill" in params:
            fill_type_defaultdict = params.pop("fill")
            params["fill"] = fill_type_defaultdict[torch.Tensor]

        return params

    def __prepare_scriptable__(self) -> nn.Module:
        # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the
        # return value is scripted instead of the original object. Since the v1 transforms are JIT scriptable, and
        # we made sure that for single image inputs v1 and v2 are equivalent, we just return the equivalent v1
        # transform here. This of course only makes transforms v2 JIT scriptable as long as transforms v1 is around.
        if self._v1_transform_cls is None:
            raise RuntimeError(
                f"Transform {type(self).__name__} cannot be JIT scripted. "
                "torchscript is only supported for backward compatibility with transforms "
                "which are already in torchvision.transforms. "
                "For torchscript support (on tensors only), you can use the functional API instead."
            )

        return self._v1_transform_cls(**self._extract_params_for_v1_transform())
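
    # A minimal usage sketch (assuming a subclass such as `Resize` above that sets `_v1_transform_cls`):
    #
    #   transform = Resize(size=(224, 224))
    #   scripted = torch.jit.script(transform)  # scripts the equivalent v1 transform returned above
    #
    # Transforms without a v1 equivalent raise the `RuntimeError` above when scripted.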


class _RandomApplyTransform(Transform):
    def __init__(self, p: float = 0.5) -> None:
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")

        super().__init__()
        self.p = p
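
    # A minimal sketch of a subclass (mirroring torchvision's own `RandomInvert`): `_transform` is applied
    # with probability `p` and skipped otherwise:
    #
    #   from torchvision.transforms.v2 import functional as F
    #
    #   class RandomInvert(_RandomApplyTransform):
    #       def _transform(self, inpt, params):
    #           return F.invert(inpt)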

    def forward(self, *inputs: Any) -> Any:
        # We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return
        # early afterwards in case the random check triggers. The same result could be achieved by calling
        # `super().forward()` after the random check, but that would call `self._check_inputs` twice.

        inputs = inputs if len(inputs) > 1 else inputs[0]
        flat_inputs, spec = tree_flatten(inputs)

        self._check_inputs(flat_inputs)

        # With probability `1 - p`, skip the transformation and return the inputs unchanged.
        if torch.rand(1) >= self.p:
            return inputs

        needs_transform_list = self._needs_transform_list(flat_inputs)
        params = self._get_params(
            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
        )

        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        return tree_unflatten(flat_outputs, spec)