# presets.py: torchvision transform presets for classification training and evaluation.
1
import torch
2
from torchvision.transforms.functional import InterpolationMode
3
4


5
6
7
8
9
10
11
12
13
14
15
16
def get_module(use_v2):
    """Return the torchvision transforms namespace used to build a pipeline.

    Args:
        use_v2: when truthy, return ``torchvision.transforms.v2``; otherwise
            return the classic ``torchvision.transforms`` module.

    The imports live inside the function on purpose: importing the V2 API
    emits a warning, so it is only imported when a caller actually asks for it.
    """
    if not use_v2:
        import torchvision.transforms

        return torchvision.transforms

    import torchvision.transforms.v2

    return torchvision.transforms.v2


17
class ClassificationPresetTrain:
    """Training-time augmentation pipeline for image classification.

    The composed pipeline is, in order:
      1. ``PILToTensor`` (only when ``backend == "tensor"``),
      2. ``RandomResizedCrop``,
      3. ``RandomHorizontalFlip`` (only when ``hflip_prob > 0``),
      4. an optional auto-augment transform,
      5. ``PILToTensor`` (only when ``backend == "pil"``),
      6. ``ConvertImageDtype(torch.float)`` and ``Normalize``,
      7. ``RandomErasing`` (only when ``random_erase_prob > 0``).

    NOTE(review): both backends apply ``PILToTensor`` exactly once (either first
    or after the augmentations), so the input to ``__call__`` appears to be
    expected as a PIL image regardless of ``backend``. Confirm against the
    dataset that feeds this preset.
    """

    def __init__(
        self,
        *,
        crop_size,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        hflip_prob=0.5,
        auto_augment_policy=None,
        ra_magnitude=9,
        augmix_severity=3,
        random_erase_prob=0.0,
        backend="pil",
        use_v2=False,
    ):
        """Build the training transform pipeline.

        Args:
            crop_size: output size passed to ``RandomResizedCrop``.
            mean: per-channel mean passed to ``Normalize``.
            std: per-channel std passed to ``Normalize``.
            interpolation: interpolation mode for the resized crop and for the
                auto-augment transforms.
            hflip_prob: flip probability for ``RandomHorizontalFlip``; the flip
                is omitted entirely when this is not positive.
            auto_augment_policy: ``None`` disables auto-augmentation. ``"ra"``
                selects ``RandAugment``, ``"ta_wide"`` selects
                ``TrivialAugmentWide`` and ``"augmix"`` selects ``AugMix``. Any
                other value is converted with ``AutoAugmentPolicy(...)`` and
                used with ``AutoAugment``; presumably that conversion raises for
                unknown policy names (confirm).
            ra_magnitude: ``magnitude`` for ``RandAugment`` (used only with "ra").
            augmix_severity: ``severity`` for ``AugMix`` (used only with "augmix").
            random_erase_prob: ``p`` for ``RandomErasing``; the transform is
                omitted when this is not positive.
            backend: ``"pil"`` or ``"tensor"`` (case-insensitive). This selects
                whether the geometric/augment steps run before or after
                ``PILToTensor``.
            use_v2: build the pipeline from ``torchvision.transforms.v2`` instead
                of ``torchvision.transforms`` (see ``get_module``).

        Raises:
            ValueError: if ``backend`` is neither ``"pil"`` nor ``"tensor"``.
        """
        module = get_module(use_v2)

        transforms = []

        backend = backend.lower()
        if backend == "tensor":
            # Tensor backend: convert up front so every later step sees a tensor.
            transforms.append(module.PILToTensor())
        elif backend != "pil":
            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

        transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
        if hflip_prob > 0:
            transforms.append(module.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            if auto_augment_policy == "ra":
                transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
            elif auto_augment_policy == "ta_wide":
                transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
            elif auto_augment_policy == "augmix":
                transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
            else:
                # Remaining names are treated as AutoAugment policy identifiers.
                aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
                transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))

        if backend == "pil":
            # PIL backend: augmentations above ran on the PIL image; convert now.
            transforms.append(module.PILToTensor())

        transforms.extend(
            [
                # Normalize operates on floating-point tensors, so convert first.
                module.ConvertImageDtype(torch.float),
                module.Normalize(mean=mean, std=std),
            ]
        )
        if random_erase_prob > 0:
            # Erasing is applied last, on the normalized float tensor.
            transforms.append(module.RandomErasing(p=random_erase_prob))

        self.transforms = module.Compose(transforms)

    def __call__(self, img):
        """Apply the composed training pipeline to ``img`` and return the result."""
        return self.transforms(img)


class ClassificationPresetEval:
    """Evaluation-time preprocessing pipeline for image classification.

    The composed pipeline is, in order:
      1. ``PILToTensor`` (only when ``backend == "tensor"``),
      2. ``Resize(resize_size)`` followed by ``CenterCrop(crop_size)``,
      3. ``PILToTensor`` (only when ``backend == "pil"``),
      4. ``ConvertImageDtype(torch.float)`` and ``Normalize``.

    NOTE(review): as with the training preset, ``PILToTensor`` is applied once
    on either path, so inputs appear to be expected as PIL images regardless of
    ``backend``. Confirm against the calling dataset.
    """

    def __init__(
        self,
        *,
        crop_size,
        resize_size=256,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
        interpolation=InterpolationMode.BILINEAR,
        backend="pil",
        use_v2=False,
    ):
        """Build the evaluation transform pipeline.

        Args:
            crop_size: output size passed to ``CenterCrop``.
            resize_size: size passed to ``Resize`` before the center crop.
            mean: per-channel mean passed to ``Normalize``.
            std: per-channel std passed to ``Normalize``.
            interpolation: interpolation mode used by ``Resize``.
            backend: ``"pil"`` or ``"tensor"`` (case-insensitive). This selects
                whether resize/crop run before or after ``PILToTensor``.
            use_v2: build the pipeline from ``torchvision.transforms.v2`` instead
                of ``torchvision.transforms`` (see ``get_module``).

        Raises:
            ValueError: if ``backend`` is neither ``"pil"`` nor ``"tensor"``.
        """
        module = get_module(use_v2)
        transforms = []

        backend = backend.lower()
        if backend == "tensor":
            # Tensor backend: convert up front so resize/crop operate on tensors.
            transforms.append(module.PILToTensor())
        elif backend != "pil":
            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

        transforms.extend(
            [
                module.Resize(resize_size, interpolation=interpolation, antialias=True),
                module.CenterCrop(crop_size),
            ]
        )

        if backend == "pil":
            # PIL backend: resize/crop ran on the PIL image; convert now.
            transforms.append(module.PILToTensor())

        transforms.extend(
            [
                # Normalize operates on floating-point tensors, so convert first.
                module.ConvertImageDtype(torch.float),
                module.Normalize(mean=mean, std=std),
            ]
        )

        self.transforms = module.Compose(transforms)

    def __call__(self, img):
        """Apply the composed evaluation pipeline to ``img`` and return the result."""
        return self.transforms(img)