_container.py 6.11 KB
Newer Older
1
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
2
3

import torch
4
5

from torch import nn
6
from torchvision import transforms as _transforms
7
from torchvision.transforms.v2 import Transform
8
9
10


class Compose(Transform):
    """[BETA] Composes several transforms together.

    .. v2betastatus:: Compose transform

    This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        # Validate eagerly so misuse fails at construction, not at call time.
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if not transforms:
            raise ValueError("Pass at least one transform")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        # A single positional input is handed to each transform unpacked;
        # multiple inputs are threaded through as a tuple.
        unpack = len(inputs) > 1
        sample: Any = inputs
        for transform in self.transforms:
            result = transform(*sample)
            sample = result if unpack else (result,)
        return result

    def extra_repr(self) -> str:
        # One indented line per composed transform, for nn.Module repr.
        return "\n".join(f"    {transform}" for transform in self.transforms)

63

64
class RandomApply(Transform):
    """[BETA] Apply randomly a list of transformations with a given probability.

    .. v2betastatus:: RandomApply transform

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability of applying the list of transforms
    """

    # v1 counterpart used when the caller asks for the legacy transform.
    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()

        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms

        if not 0.0 <= p <= 1.0:
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # Parameters forwarded verbatim to the v1 RandomApply constructor.
        return dict(transforms=self.transforms, p=self.p)

    def forward(self, *inputs: Any) -> Any:
        unpack = len(inputs) > 1

        # With probability (1 - p), skip all transforms and pass inputs through.
        if torch.rand(1) >= self.p:
            return inputs if unpack else inputs[0]

        sample: Any = inputs
        for transform in self.transforms:
            result = transform(*sample)
            sample = result if unpack else (result,)
        return result

    def extra_repr(self) -> str:
        # One indented line per wrapped transform, for nn.Module repr.
        return "\n".join(f"    {transform}" for transform in self.transforms)
118
119
120


class RandomChoice(Transform):
    """[BETA] Apply single transformation randomly picked from a list.

    .. v2betastatus:: RandomChoice transform

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (list of floats or None, optional): probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        # None means uniform weights; otherwise the caller must supply one
        # weight per transform.
        if p is None:
            p = [1] * len(transforms)
        if len(p) != len(transforms):
            raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")

        super().__init__()

        self.transforms = transforms
        # Normalize so the stored probabilities sum to 1.
        total = sum(p)
        self.p = [weight / total for weight in p]

    def forward(self, *inputs: Any) -> Any:
        # Draw one index according to the normalized probabilities and apply
        # only that transform.
        idx = int(torch.multinomial(torch.tensor(self.p), 1))
        chosen = self.transforms[idx]
        return chosen(*inputs)


class RandomOrder(Transform):
    """[BETA] Apply a list of transformations in a random order.

    .. v2betastatus:: RandomOrder transform

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        unpack = len(inputs) > 1
        sample: Any = inputs
        # Apply every transform exactly once, in a freshly drawn random order.
        for idx in torch.randperm(len(self.transforms)):
            result = self.transforms[idx](*sample)
            sample = result if unpack else (result,)
        return result