from __future__ import annotations

import enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchvision import datapoints
from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
from torchvision.utils import _log_api_usage_once

from .functional._utils import _get_kernel


class Transform(nn.Module):
    """Base class for the v2 transforms.

    A subclass typically overrides :meth:`_transform` (applied to every
    transformable element of the input) and optionally :meth:`_get_params`
    (sampled once per ``forward`` call and shared by all elements) and
    :meth:`_check_inputs` (validation hook run on every call).
    """

    # Class attribute defining transformed types. Other types are passed-through without any transformation
    # We support both Types and callables that are able to do further checks on the type of the input.
    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)

    def __init__(self) -> None:
        super().__init__()
        _log_api_usage_once(self)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        """Validate the flattened inputs. No-op by default; subclasses may raise here."""
        pass

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the parameter dict passed to every ``_transform`` call. Empty by default."""
        return dict()

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        """Look up the kernel registered for ``functional`` and ``type(inpt)`` and call it."""
        # NOTE(review): `allow_passthrough=True` presumably returns a pass-through kernel for
        # unregistered types instead of raising -- confirm against `_get_kernel`.
        kernel = _get_kernel(functional, type(inpt), allow_passthrough=True)
        return kernel(inpt, *args, **kwargs)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Transform a single input. Must be overridden by subclasses."""
        raise NotImplementedError

    def forward(self, *inputs: Any) -> Any:
        # A single positional argument is unwrapped so that `t(sample)` and `t(a, b)` both work.
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])

        self._check_inputs(flat_inputs)

        # Params are sampled once, from the transformable inputs only, and shared by all of them.
        needs_transform_list = self._needs_transform_list(flat_inputs)
        params = self._get_params(
            [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
        )

        flat_outputs = [
            self._transform(inpt, params) if needs_transform else inpt
            for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
        ]

        # Rebuild the original container structure around the transformed leaves.
        return tree_unflatten(flat_outputs, spec)

    def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
        """Return one bool per flattened input: whether that input should be transformed."""
        # Below is a heuristic on how to deal with simple tensor inputs:
        # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image
        #    (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample.
        # 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is
        #    transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
        #    of `tree_flatten`, which recurses depth-first through the input.
        #
        # This heuristic stems from two requirements:
        # 1. We need to keep BC for single input simple tensors and treat them as images.
        # 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface`
        #    return supplemental numerical data as tensors that cannot be transformed as images.
        #
        # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
        # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images.
        # However, this case wasn't supported by transforms v1 either, so there is no BC concern.

        needs_transform_list = []
        transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image)
        for inpt in flat_inputs:
            needs_transform = True

            if not check_type(inpt, self._transformed_types):
                needs_transform = False
            elif is_simple_tensor(inpt):
                # Only the first simple tensor is transformed; the flag is cleared afterwards
                # so every following simple tensor is passed through (see heuristic above).
                if transform_simple_tensor:
                    transform_simple_tensor = False
                else:
                    needs_transform = False

            needs_transform_list.append(needs_transform)
        return needs_transform_list

    def extra_repr(self) -> str:
        """Render public scalar-like attributes (bool/int/float/str/tuple/list/enum) for ``repr``."""
        extra = []
        for name, value in self.__dict__.items():
            # Private attributes and `nn.Module`'s own `training` flag are not transform params.
            if name.startswith("_") or name == "training":
                continue

            if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)):
                continue

            extra.append(f"{name}={value}")

        return ", ".join(extra)

    # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things:
    # 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on
    #    the v2 transform. See `__init_subclass__` for details.
    # 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__`
    #    for details.
    _v1_transform_cls: Optional[Type[nn.Module]] = None

    def __init_subclass__(cls) -> None:
        # Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
        # This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
        if cls._v1_transform_cls is not None and hasattr(cls._v1_transform_cls, "get_params"):
            cls.get_params = staticmethod(cls._v1_transform_cls.get_params)  # type: ignore[attr-defined]

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
        # v2 transform instance. It extracts all available public attributes that are specific to that transform and
        # not `nn.Module` in general.
        # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
        # if the v2 transform introduced new parameters that are not support by the v1 transform.
        common_attrs = nn.Module().__dict__.keys()
        return {
            attr: value
            for attr, value in self.__dict__.items()
            if not attr.startswith("_") and attr not in common_attrs
        }

    def __prepare_scriptable__(self) -> nn.Module:
        # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
        # value is used for scripting over the original object that should have been scripted. Since the v1 transforms
        # are JIT scriptable, and we made sure that for single image inputs v1 and v2 are equivalent, we just return the
        # equivalent v1 transform here. This of course only makes transforms v2 JIT scriptable as long as transforms v1
        # is around.
        if self._v1_transform_cls is None:
            raise RuntimeError(
                f"Transform {type(self).__name__} cannot be JIT scripted. "
                "torchscript is only supported for backward compatibility with transforms "
                "which are already in torchvision.transforms. "
                "For torchscript support (on tensors only), you can use the functional API instead."
            )

        return self._v1_transform_cls(**self._extract_params_for_v1_transform())

144
145

class _RandomApplyTransform(Transform):
    """Transform base class that fires with probability ``p`` and is a no-op otherwise."""

    def __init__(self, p: float = 0.5) -> None:
        # Validate before `super().__init__()` so a bad probability fails fast.
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")

        super().__init__()
        self.p = p

    def forward(self, *inputs: Any) -> Any:
        # This deliberately mirrors `Transform.forward()`: the inputs must always be checked,
        # but we want to bail out right after the coin flip. Delegating to `super().forward()`
        # after the flip would run `self._check_inputs` a second time, so we inline the logic.
        sample = inputs if len(inputs) > 1 else inputs[0]
        flat_sample, tree_spec = tree_flatten(sample)

        self._check_inputs(flat_sample)

        # One coin flip per call; `>= p` means "skip" and the sample is returned untouched.
        if torch.rand(1) >= self.p:
            return sample

        transform_flags = self._needs_transform_list(flat_sample)
        params = self._get_params([item for item, flag in zip(flat_sample, transform_flags) if flag])

        flat_result = []
        for item, flag in zip(flat_sample, transform_flags):
            flat_result.append(self._transform(item, params) if flag else item)

        return tree_unflatten(flat_result, tree_spec)