_container.py 5.84 KB
Newer Older
1
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
2
3

import torch
4
5

from torch import nn
6
from torchvision import transforms as _transforms
7
from torchvision.transforms.v2 import Transform
8
9
10


class Compose(Transform):
    """[BETA] Composes several transforms together.

    .. v2betastatus:: Compose transform

    This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        # Fail fast on non-sequence arguments instead of erroring later inside forward().
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        # A single positional argument is threaded through as-is; several
        # arguments travel through the pipeline packed into one tuple.
        if len(inputs) > 1:
            sample: Any = inputs
        else:
            sample = inputs[0]
        for stage in self.transforms:
            sample = stage(sample)
        return sample

    def extra_repr(self) -> str:
        # One indented line per composed transform.
        return "\n".join(f"    {stage}" for stage in self.transforms)

60

61
class RandomApply(Transform):
    """[BETA] Apply randomly a list of transformations with a given probability.

    .. v2betastatus:: RandomApply transform

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability of applying the list of transforms
    """

    # Lets the v2 machinery map this transform onto its v1 counterpart.
    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()

        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms

        # Chained comparison (rather than two separate checks) also rejects NaN.
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # Constructor arguments needed to rebuild the equivalent v1 transform.
        return {"transforms": self.transforms, "p": self.p}

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

        # Single coin flip: with probability (1 - p) the input passes through untouched.
        if torch.rand(1) >= self.p:
            return sample

        for stage in self.transforms:
            sample = stage(sample)
        return sample

    def extra_repr(self) -> str:
        # One indented line per wrapped transform.
        return "\n".join(f"    {stage}" for stage in self.transforms)
114
115
116


class RandomChoice(Transform):
    """[BETA] Apply single transformation randomly picked from a list.

    .. v2betastatus:: RandomChoice transform

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (list of floats or None, optional): probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        if p is None:
            # No explicit weights: pick uniformly among the transforms.
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")

        super().__init__()

        self.transforms = transforms
        # Normalize so the stored weights form a proper probability distribution.
        total = sum(p)
        self.p = [weight / total for weight in p]

    def forward(self, *inputs: Any) -> Any:
        # Sample one index according to the stored weights and dispatch to it.
        idx = int(torch.multinomial(torch.tensor(self.p), 1))
        return self.transforms[idx](*inputs)


class RandomOrder(Transform):
    """[BETA] Apply a list of transformations in a random order.

    .. v2betastatus:: RandomOrder transform

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

        # Draw a random permutation of indices, then apply every transform
        # exactly once in that order.
        for position in torch.randperm(len(self.transforms)).tolist():
            sample = self.transforms[position](sample)
        return sample