# transforms.py
import random

3
import numpy as np
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F


def pad_if_smaller(img, size, fill=0):
    """Pad ``img`` on the right/bottom so neither side is shorter than ``size``.

    Assumes ``img.size`` is a ``(width, height)`` pair (PIL convention) —
    confirm for non-PIL inputs.

    Args:
        img: image whose ``.size`` gives its dimensions.
        size: minimum side length required.
        fill: pixel value used for the padded area.

    Returns:
        ``img`` unchanged when its shorter side is already ``>= size``;
        otherwise the result of ``F.pad`` with only right/bottom padding.
    """
    # Guard clause: nothing to do when both sides already meet the minimum.
    if min(img.size) >= size:
        return img

    width, height = img.size
    pad_right = size - width if width < size else 0
    pad_bottom = size - height if height < size else 0
    # torchvision padding order is (left, top, right, bottom).
    return F.pad(img, (0, 0, pad_right, pad_bottom), fill=fill)


class Compose:
    """Apply a sequence of paired ``(image, target)`` transforms in order.

    Each element of ``transforms`` is called as ``t(image, target)`` and must
    return a new ``(image, target)`` pair. That pair is fed to the next
    transform.
    """

    def __init__(self, transforms):
        # Iterated on every call.
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


class RandomResize:
    """Resize image and target to one random size drawn from ``[min_size, max_size]``.

    The same ``size`` is passed to ``F.resize`` for both inputs so they stay
    aligned. With an int size, torchvision rescales the shorter edge to
    ``size``. The target uses nearest-neighbour interpolation, so its
    discrete values are copied rather than interpolated.
    """

    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        # Omitting max_size makes the resize deterministic (always min_size).
        self.max_size = min_size if max_size is None else max_size

    def __call__(self, image, target):
        # random.randint is inclusive on both ends.
        size = random.randint(self.min_size, self.max_size)
        image = F.resize(image, size, antialias=True)
        target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
        return image, target


class RandomHorizontalFlip:
    """Horizontally flip image and target together with probability ``flip_prob``."""

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        # A single random draw decides for both, keeping them aligned.
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            target = F.hflip(target)
        return image, target


class RandomCrop:
    """Take the same random ``size`` x ``size`` crop from image and target.

    Inputs smaller than ``size`` are first padded on the right/bottom with
    ``pad_if_smaller``. The image is padded with 0 and the target with 255.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = pad_if_smaller(image, self.size)
        # NOTE(review): 255 looks like an "ignore" label for the target — confirm.
        target = pad_if_smaller(target, self.size, fill=255)
        # Sample the crop window once and apply it to both inputs for alignment.
        crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
        image = F.crop(image, *crop_params)
        target = F.crop(target, *crop_params)
        return image, target


class CenterCrop:
    """Center-crop image and target to ``size`` using ``F.center_crop``."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = F.center_crop(image, self.size)
        target = F.center_crop(target, self.size)
        return image, target


class PILToTensor:
    """Convert the image with ``F.pil_to_tensor`` and the target to an int64 tensor."""

    def __call__(self, image, target):
        image = F.pil_to_tensor(image)
        # Assumes target is a PIL image or array-like; np.array extracts its
        # values before the int64 tensor conversion.
        target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target


class ToDtype:
    """Cast the image to ``dtype``; the target passes through unchanged.

    With ``scale=False`` the image is cast as-is via ``Tensor.to``. With
    ``scale=True`` ``F.convert_image_dtype`` is used instead, which also
    rescales values to the range expected for the new dtype.
    """

    def __init__(self, dtype, scale=False):
        self.dtype = dtype
        self.scale = scale

    def __call__(self, image, target):
        if not self.scale:
            return image.to(dtype=self.dtype), target
        image = F.convert_image_dtype(image, self.dtype)
        return image, target


class Normalize:
    """Normalize the image with ``F.normalize`` using ``mean``/``std``; target unchanged."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target