_container.py 5.87 KB
Newer Older
1
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
2
3

import torch
4
5

from torch import nn
6
from torchvision import transforms as _transforms
7
from torchvision.transforms.v2 import Transform
8
9
10


class Compose(Transform):
    """[BETA] Chain several transforms so that they run one after another.

    .. betastatus:: Compose transform

    This transform does not support torchscript; the note below shows a
    scriptable alternative.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        To script a pipeline of transformations, use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Only scriptable transformations can be used this way, i.e. ones that work with
        ``torch.Tensor`` and require neither `lambda` functions nor ``PIL.Image``.

    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        """Store ``transforms``; raise ``TypeError`` if it is not a ``Sequence``."""
        super().__init__()

        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        """Feed the input through every stored transform, in order.

        A single positional argument is handed to the first transform as-is;
        several positional arguments are bundled into one tuple first. Each
        transform receives the previous transform's output.
        """
        if len(inputs) > 1:
            current = inputs
        else:
            current = inputs[0]

        for step in self.transforms:
            current = step(current)
        return current

    def extra_repr(self) -> str:
        """Return one indented line per stored transform."""
        return "\n".join(f"    {step}" for step in self.transforms)

60

61
class RandomApply(Transform):
    """[BETA] Apply a list of transformations, as a whole, with a given probability.

    .. betastatus:: RandomApply transform

    .. note::
        To script this transformation, pass a ``torch.nn.ModuleList`` rather than a
        list/tuple of transforms, as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Only scriptable transformations can be used this way, i.e. ones that work with
        ``torch.Tensor`` and require neither `lambda` functions nor ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.ModuleList): list of transformations
        p (float): probability of applying the list of transforms
    """

    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        """Validate and store ``transforms`` and ``p``.

        Raises:
            TypeError: if ``transforms`` is neither a ``Sequence`` nor an ``nn.ModuleList``.
            ValueError: if ``p`` is not within ``[0.0, 1.0]``.
        """
        super().__init__()

        accepted = isinstance(transforms, Sequence) or isinstance(transforms, nn.ModuleList)
        if not accepted:
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms

        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        """Return the constructor arguments (``transforms`` and ``p``) as a dict."""
        return dict(transforms=self.transforms, p=self.p)

    def forward(self, *inputs: Any) -> Any:
        """Run all stored transforms in order, or none of them.

        A single positional argument is used as-is; several are bundled into
        one tuple. One random draw per call decides whether the whole chain is
        applied; when it is skipped the (possibly bundled) input is returned
        unchanged.
        """
        if len(inputs) > 1:
            current = inputs
        else:
            current = inputs[0]

        # A draw at or above ``p`` means "skip" for this call.
        skip = torch.rand(1) >= self.p
        if skip:
            return current

        for step in self.transforms:
            current = step(current)
        return current

    def extra_repr(self) -> str:
        """Return one indented line per stored transform."""
        return "\n".join(f"    {step}" for step in self.transforms)
114
115
116


class RandomChoice(Transform):
    """[BETA] Apply single transformation randomly picked from a list.

    .. betastatus:: RandomChoice transform

    This transform does not support torchscript.

    Args:
        transforms (sequence of callables): candidate transformations; one of them
            is picked and applied on every call.
        p (list of floats or None, optional): probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.

    Raises:
        TypeError: if ``transforms`` is not a ``Sequence``.
        ValueError: if ``p`` and ``transforms`` differ in length, or if ``p``
            contains a negative value or does not have a positive sum.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        if p is None:
            # Uniform weights; normalized below like user-supplied ones.
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(
                f"The number of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
            )

        total = sum(p)
        # Fail early with a clear message on weights that cannot be sampled from:
        # a zero sum would otherwise surface as ZeroDivisionError in the
        # normalization below, and negative entries only as an error from
        # ``torch.multinomial`` on the first ``forward`` call.
        # The empty case is deliberately left untouched.
        if len(p) > 0 and (total <= 0 or any(prob < 0 for prob in p)):
            raise ValueError("`p` should contain non-negative values with a positive sum.")

        super().__init__()

        self.transforms = transforms
        self.p = [prob / total for prob in p]

    def forward(self, *inputs: Any) -> Any:
        """Sample one transform according to ``self.p`` and apply it.

        The positional inputs are forwarded to the chosen transform unpacked,
        exactly as they were received.
        """
        idx = int(torch.multinomial(torch.tensor(self.p), 1))
        transform = self.transforms[idx]
        return transform(*inputs)


class RandomOrder(Transform):
    """[BETA] Apply a list of transformations in a random order.

    .. betastatus:: RandomOrder transform

    This transform does not support torchscript.

    Args:
        transforms (sequence of callables): list of transformations

    Raises:
        TypeError: if ``transforms`` is not a ``Sequence``.
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        """Apply every stored transform exactly once, in a freshly drawn random order.

        A single positional argument is used as-is; several are bundled into
        one tuple. Each transform receives the previous transform's output.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]
        # ``.tolist()`` turns the permutation into plain Python ints, so any
        # ``Sequence`` implementation can be indexed, not only containers that
        # accept 0-d tensors through ``__index__``.
        for idx in torch.randperm(len(self.transforms)).tolist():
            transform = self.transforms[idx]
            sample = transform(sample)
        return sample