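# _container.py (3.8 KB)
# NOTE: "Newer Older" and the bare integers scattered through this file are residue from a web code viewer, not source code.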
1
import warnings
2
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
3
4

import torch
5
6

from torch import nn
7
from torchvision import transforms as _transforms
8
from torchvision.transforms.v2 import Transform
9
10
11


class Compose(Transform):
    """Run a sequence of transforms one after another.

    Args:
        transforms: Sequence of callables, applied in the order given.

    Raises:
        TypeError: If ``transforms`` is not a ``Sequence``.
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()

        # Only real sequences are accepted; other iterables are rejected here.
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        """Feed the inputs through every transform in order.

        Args:
            *inputs: One or more positional inputs. Several inputs are packed
                into a single tuple that is handed to each transform as ONE
                argument; a lone input is passed through as-is.

        Returns:
            The output of the last transform, or the (packed) input itself
            when ``transforms`` is empty.
        """
        if len(inputs) > 1:
            current = inputs
        else:
            current = inputs[0]

        # Each transform's output becomes the next transform's input.
        for step in self.transforms:
            current = step(current)
        return current

    def extra_repr(self) -> str:
        """Return one line per transform, each indented by four spaces."""
        return "\n".join(f"    {step}" for step in self.transforms)

30

31
class RandomApply(Transform):
    """Apply a whole list of transforms with probability ``p``, or none of them.

    Args:
        transforms: Sequence of callables or an ``nn.ModuleList``.
        p: Probability of applying the transforms. Defaults to 0.5.

    Raises:
        TypeError: If ``transforms`` is neither a ``Sequence`` nor an
            ``nn.ModuleList``.
        ValueError: If ``p`` is not within ``[0.0, 1.0]``.
    """

    _v1_transform_cls = _transforms.RandomApply

    def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
        super().__init__()

        # Validate both arguments first (transforms before p), then store them.
        if not isinstance(transforms, (Sequence, nn.ModuleList)):
            raise TypeError("Argument transforms should be a sequence of callables or a `nn.ModuleList`")
        # The chained comparison is kept as-is: it also rejects NaN.
        if not (0.0 <= p <= 1.0):
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")

        self.transforms = transforms
        self.p = p

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        """Return the ``transforms`` and ``p`` attributes as a keyword dict."""
        return dict(transforms=self.transforms, p=self.p)

    def forward(self, *inputs: Any) -> Any:
        """Either run every transform in order or return the input untouched.

        Args:
            *inputs: One or more positional inputs. Several inputs are packed
                into a single tuple that is handed to each transform as ONE
                argument; a lone input is passed through as-is.

        Returns:
            The (packed) input unchanged when the random draw is ``>= p``,
            otherwise the output of the last transform.
        """
        if len(inputs) > 1:
            current = inputs
        else:
            current = inputs[0]

        # One uniform draw decides for the whole list of transforms.
        skip = torch.rand(1) >= self.p
        if skip:
            return current

        for step in self.transforms:
            current = step(current)
        return current

    def extra_repr(self) -> str:
        """Return one line per transform, each indented by four spaces."""
        return "\n".join(f"    {step}" for step in self.transforms)
63
64
65


class RandomChoice(Transform):
    """Apply a single transform picked at random from a sequence.

    Args:
        transforms: Sequence of callables to choose from.
        probabilities: Optional weights, one per transform. They are divided
            by their sum before use, so they do not have to add up to 1.
            When omitted, every transform gets the same weight.
        p: Deprecated alias of ``probabilities``. Passing it emits a warning
            and its value takes precedence over ``probabilities``.

    Raises:
        TypeError: If ``transforms`` is not a ``Sequence``.
        ValueError: If the number of weights differs from the number of
            transforms, or if the weights sum to zero.
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        probabilities: Optional[List[float]] = None,
        p: Optional[List[float]] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        if p is not None:
            # stacklevel=2 attributes the warning to the line that constructed
            # the transform instead of to this constructor.
            warnings.warn(
                "Argument p is deprecated and will be removed in a future release. "
                "Please use probabilities argument instead.",
                stacklevel=2,
            )
            probabilities = p

        if probabilities is None:
            probabilities = [1] * len(transforms)
        elif len(probabilities) != len(transforms):
            raise ValueError(
                "The number of probabilities doesn't match the number of transforms: "
                f"{len(probabilities)} != {len(transforms)}"
            )

        total = sum(probabilities)
        # A zero total cannot be normalised; fail with a clear message instead
        # of a bare ZeroDivisionError from the division below. An empty
        # ``transforms`` (where the total is trivially 0) is left untouched.
        if len(transforms) > 0 and total == 0:
            raise ValueError("The probabilities must not sum to zero.")

        super().__init__()

        self.transforms = transforms
        self.probabilities = [prob / total for prob in probabilities]

    def forward(self, *inputs: Any) -> Any:
        """Draw one transform according to the weights and apply it.

        Args:
            *inputs: Positional inputs, forwarded unpacked to the chosen
                transform.

        Returns:
            Whatever the chosen transform returns.
        """
        # Sample a single index from the normalised weights.
        idx = int(torch.multinomial(torch.tensor(self.probabilities), 1))
        transform = self.transforms[idx]
        return transform(*inputs)


class RandomOrder(Transform):
    """Apply every transform exactly once, in a freshly shuffled order per call.

    Args:
        transforms: Sequence of callables to apply.

    Raises:
        TypeError: If ``transforms`` is not a ``Sequence``.
    """

    def __init__(self, transforms: Sequence[Callable]) -> None:
        # Validation happens before the base-class initialisation.
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        super().__init__()
        self.transforms = transforms

    def forward(self, *inputs: Any) -> Any:
        """Run all transforms in a random permutation of their positions.

        Args:
            *inputs: One or more positional inputs. Several inputs are packed
                into a single tuple that is handed to each transform as ONE
                argument; a lone input is passed through as-is.

        Returns:
            The output of the last transform applied, or the (packed) input
            itself when ``transforms`` is empty.
        """
        if len(inputs) > 1:
            current = inputs
        else:
            current = inputs[0]

        # A new permutation of the indices is drawn on every call.
        order = torch.randperm(len(self.transforms))
        for position in order:
            step = self.transforms[position]
            current = step(current)
        return current