import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import torch

from torch import nn
from torchvision import transforms as _transforms
from torchvision.transforms.v2 import Transform


class Compose(Transform):
    """[BETA] Composes several transforms together.

    .. betastatus:: Compose transform

    This transform does not support torchscript.
    Please see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
        require `lambda` functions or ``PIL.Image``.

    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            sample = transform(sample)
        return sample

    def extra_repr(self) -> str:
        format_string = []
        for t in self.transforms:
            format_string.append(f"    {t}")
        return "\n".join(format_string)
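

# A minimal usage sketch for Compose: a small pipeline run over a dummy image tensor.
# The particular transforms (RandomHorizontalFlip, RandomCrop) and the input shape are
# arbitrary choices for illustration, not part of this module's API.
def _compose_usage_sketch() -> torch.Tensor:
    from torchvision.transforms import v2  # imported locally to keep the sketch self-contained

    pipeline = Compose(
        [
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomCrop(24),
        ]
    )
    image = torch.rand(3, 32, 32)  # dummy CHW image
    # forward() also accepts multiple inputs, e.g. pipeline(image, target): they are
    # kept together as a tuple and passed through every transform in order.
    return pipeline(image)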


class RandomApply(Transform):
    """[BETA] Apply randomly a list of transformations with a given probability.

    .. betastatus:: RandomApply transform

    .. note::
        In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of
        transforms as shown below:

        >>> transforms = transforms.RandomApply(torch.nn.ModuleList([
        >>>     transforms.ColorJitter(),
        >>> ]), p=0.3)
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
        require `lambda` functions or ``PIL.Image``.

    Args:
        transforms (sequence of callables or ``torch.nn.ModuleList``): list of transformations to apply.
        p (float): probability of applying the list of transforms.
    """

    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()

        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        self.transforms = transforms
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        return {"transforms": self.transforms, "p": self.p}

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

        if torch.rand(1) >= self.p:
            return sample

        for transform in self.transforms:
            sample = transform(sample)
        return sample

    def extra_repr(self) -> str:
        format_string = []
        for t in self.transforms:
            format_string.append(f"    {t}")
        return "\n".join(format_string)
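

# A minimal usage sketch for RandomApply: the whole transform list is applied with
# probability p, otherwise the input is returned unchanged. The ColorJitter setting
# below is an arbitrary choice for illustration.
def _random_apply_usage_sketch() -> torch.Tensor:
    from torchvision.transforms import v2  # imported locally to keep the sketch self-contained

    maybe_jitter = RandomApply([v2.ColorJitter(brightness=0.2)], p=0.3)
    image = torch.rand(3, 32, 32)  # dummy CHW image
    return maybe_jitter(image)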


class RandomChoice(Transform):
    """[BETA] Apply single transformation randomly picked from a list.

    .. betastatus:: RandomChoice transform

    This transform does not support torchscript."""

    def __init__(
        self,
        transforms: Sequence[Callable],
        probabilities: Optional[List[float]] = None,
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if p is not None:
            warnings.warn(
                "Argument p is deprecated and will be removed in a future release. "
                "Please use probabilities argument instead."
            )
            probabilities = p

        if probabilities is None:
            probabilities = [1] * len(transforms)
        elif len(probabilities) != len(transforms):
            raise ValueError(
                f"The number of probabilities doesn't match the number of transforms: "
                f"{len(probabilities)} != {len(transforms)}"
            )

        super().__init__()

        self.transforms = transforms
        total = sum(probabilities)
        self.probabilities = [prob / total for prob in probabilities]

    def forward(self, *inputs: Any) -> Any:
        idx = int(torch.multinomial(torch.tensor(self.probabilities), 1))
        transform = self.transforms[idx]
        return transform(*inputs)
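

# A minimal usage sketch for RandomChoice: exactly one transform is picked per call.
# The probabilities below are arbitrary; they need not sum to 1, since the constructor
# normalizes them (here to [0.75, 0.25]).
def _random_choice_usage_sketch() -> torch.Tensor:
    from torchvision.transforms import v2  # imported locally to keep the sketch self-contained

    pick_one = RandomChoice(
        [v2.RandomHorizontalFlip(p=1.0), v2.RandomVerticalFlip(p=1.0)],
        probabilities=[3, 1],
    )
    image = torch.rand(3, 32, 32)  # dummy CHW image
    return pick_one(image)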


class RandomOrder(Transform):
    """[BETA] Apply a list of transformations in a random order.

    .. betastatus:: RandomOrder transform

    This transform does not support torchscript.
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        for idx in torch.randperm(len(self.transforms)):
            transform = self.transforms[idx]
            sample = transform(sample)
        return sample
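

# A minimal usage sketch for RandomOrder: every transform is applied exactly once, in a
# freshly shuffled order on each call. The transform choices are arbitrary.
def _random_order_usage_sketch() -> torch.Tensor:
    from torchvision.transforms import v2  # imported locally to keep the sketch self-contained

    shuffled = RandomOrder([v2.RandomHorizontalFlip(p=0.5), v2.ColorJitter(contrast=0.2)])
    image = torch.rand(3, 32, 32)  # dummy CHW image
    return shuffled(image)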