_container.py 5.91 KB
Newer Older
1
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
2
3

import torch
4
5

from torch import nn
6
from torchvision import transforms as _transforms
7
from torchvision.transforms.v2 import Transform
8
9
10


class Compose(Transform):
    """Composes several transforms together.

    This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        # Reject anything that is not a sequence up front, and refuse an
        # empty pipeline so forward() always has at least one transform to run.
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if not transforms:
            raise ValueError("Pass at least one transform")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        # With a single positional input, each transform receives and returns
        # the bare value; with several, the tuple of inputs is threaded through.
        unpack = len(inputs) > 1
        for transform in self.transforms:
            result = transform(*inputs)
            inputs = result if unpack else (result,)
        return result

    def extra_repr(self) -> str:
        # One indented line per composed transform, mirroring nn.Module repr style.
        return "\n".join(f"    {t}" for t in self.transforms)
class RandomApply(Transform):
    """Apply randomly a list of transformations with a given probability.

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (float): probability of applying the list of transforms
    """

    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()

        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms

        # Chained comparison also rejects NaN (the comparison is False for NaN).
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # Parameters forwarded when falling back to the v1 implementation.
        return {"transforms": self.transforms, "p": self.p}

    def forward(self, *inputs: Any) -> Any:
        unpack = len(inputs) > 1

        # Bernoulli gate: with probability (1 - p) skip the whole list and
        # return the inputs untouched. rand(1) is in [0, 1), so p=1.0 always
        # applies and p=0.0 never does.
        if torch.rand(1) >= self.p:
            return inputs if unpack else inputs[0]

        for transform in self.transforms:
            result = transform(*inputs)
            inputs = result if unpack else (result,)
        return result

    def extra_repr(self) -> str:
        # One indented line per wrapped transform, mirroring nn.Module repr style.
        return "\n".join(f"    {t}" for t in self.transforms)
class RandomChoice(Transform):
    """Apply single transformation randomly picked from a list.

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (list of floats or None, optional): probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
    ) -> None:
        # Validate before touching nn.Module machinery so bad arguments fail fast.
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        if p is None:
            # Uniform weights; normalization below turns them into probabilities.
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}")

        super().__init__()

        self.transforms = transforms
        # Normalize so the stored weights sum to 1.
        total = sum(p)
        self.p = [prob / total for prob in p]

    def forward(self, *inputs: Any) -> Any:
        # Draw one index according to the normalized weights and apply
        # only that transform.
        idx = int(torch.multinomial(torch.tensor(self.p), 1))
        return self.transforms[idx](*inputs)
class RandomOrder(Transform):
    """Apply a list of transformations in a random order.

    This transform does not support torchscript.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        # Same packing convention as Compose: a lone input travels unwrapped.
        unpack = len(inputs) > 1
        # randperm yields a uniformly random permutation of the indices.
        for idx in torch.randperm(len(self.transforms)):
            result = self.transforms[idx](*inputs)
            inputs = result if unpack else (result,)
        return result