from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import torch

from torch import nn
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import Transform


class Compose(Transform):
    """[BETA] Composes several transforms together.

    .. v2betastatus:: Compose transform

    This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        # Guard clauses: reject non-sequences first, then empty sequences.
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if not transforms:
            raise ValueError("Pass at least one transform")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        # A single positional argument is threaded through the pipeline as-is;
        # multiple positional arguments are re-unpacked for every transform.
        multiple_inputs = len(inputs) > 1
        for t in self.transforms:
            result = t(*inputs)
            inputs = result if multiple_inputs else (result,)
        return result

    def extra_repr(self) -> str:
        # One indented line per contained transform.
        return "\n".join(f"    {t}" for t in self.transforms)


class RandomApply(Transform):
    """[BETA] Apply randomly a list of transformations with a given probability.

    .. v2betastatus:: RandomApply transform

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability of applying the list of transforms
    """

    # Lets the v2 machinery fall back to the matching v1 transform.
    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()

        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms

        # NOTE: the chained-comparison form also rejects NaN, which `p < 0 or p > 1` would let through.
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        return dict(transforms=self.transforms, p=self.p)

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

        # Draw once; with probability (1 - p) the whole list is skipped.
        if torch.rand(1) >= self.p:
            return sample

        for t in self.transforms:
            sample = t(sample)
        return sample

    def extra_repr(self) -> str:
        # One indented line per contained transform.
        return "\n".join(f"    {t}" for t in self.transforms)


class RandomChoice(Transform):
    """[BETA] Apply single transformation randomly picked from a list.

    .. v2betastatus:: RandomChoice transform

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (list of floats or None, optional): probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        # Default to uniform weights; an explicit weight list must match 1:1.
        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")

        super().__init__()

        self.transforms = transforms
        # Store normalized probabilities so they always sum to 1.
        total = sum(p)
        self.p = [weight / total for weight in p]

    def forward(self, *inputs: Any) -> Any:
        # Sample one index according to the stored probabilities and delegate.
        chosen = self.transforms[int(torch.multinomial(torch.tensor(self.p), 1))]
        return chosen(*inputs)


class RandomOrder(Transform):
    """[BETA] Apply a list of transformations in a random order.

    .. v2betastatus:: RandomOrder transform

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        # Apply every transform exactly once, in a freshly shuffled order.
        for i in torch.randperm(len(self.transforms)).tolist():
            sample = self.transforms[i](sample)
        return sample