from typing import Any, Sequence
import torch


class AbstractTransform(torch.nn.Module):
    """
    Base class for transforms.

    Wraps the forward pass of :class:`torch.nn.Module` in a gradient
    context so that subclasses run with or without autograd depending
    on :attr:`grad`.
    """

    def __init__(self, grad: bool = False, **kwargs):
        """
        Args:
            grad: enable gradient computation inside transformation
            **kwargs: ignored; accepted for subclass/interface compatibility
        """
        super().__init__()
        self.grad = grad

    def __call__(self, *args, **kwargs) -> Any:
        """
        Call super class with correct torch context

        Args:
            *args: forwarded positional arguments
            **kwargs: forwarded keyword arguments

        Returns:
            Any: transformed data

        """
        # set_grad_enabled doubles as a context manager and covers both the
        # enable_grad and no_grad cases in a single expression
        with torch.set_grad_enabled(self.grad):
            return super().__call__(*args, **kwargs)


class Compose(AbstractTransform):
    def __init__(self, *transforms):
        """
        Chain multiple transforms together into a single transform

        Args:
            transforms: the transforms to apply, either as separate
                positional arguments or as a single sequence
        """
        super().__init__(grad=False)
        # support both Compose(t1, t2) and Compose([t1, t2])
        if len(transforms) == 1 and isinstance(transforms[0], Sequence):
            transforms = transforms[0]

        self.transforms = torch.nn.ModuleList(list(transforms))

    def forward(self, **batch):
        """
        Augment batch
        """
        for transform in self.transforms:
            batch = transform(**batch)
        return batch