import mmcv
import numpy as np
import torch

from mmdet.core.mask_ops import segms

__all__ = [
    'ImageTransform', 'BboxTransform', 'PolyMaskTransform', 'Numpy2Tensor'
]


class ImageTransform(object):
    """Preprocess an image for the detection pipeline.

    The transform applies, in order:
    1. rescale the image to the expected size
    2. normalize the image (mean/std, optional BGR->RGB)
    3. flip the image horizontally (if needed)
    4. pad the image to a size divisor (if configured)
    5. transpose the array from (h, w, c) to (c, h, w)
    """

    def __init__(self,
                 mean=(0, 0, 0),
                 std=(1, 1, 1),
                 to_rgb=True,
                 size_divisor=None):
        # Per-channel normalization statistics, stored as float32 arrays.
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb
        # When set, the padded height/width are multiples of this value.
        self.size_divisor = size_divisor

    def __call__(self, img, scale, flip=False):
        img, scale_factor = mmcv.imrescale(img, scale, return_scale=True)
        resized_shape = img.shape
        img = mmcv.imnorm(img, self.mean, self.std, self.to_rgb)
        if flip:
            img = mmcv.imflip(img)
        # Default: no padding, so the padded shape equals the resized shape.
        pad_shape = resized_shape
        if self.size_divisor is not None:
            img = mmcv.impad_to_multiple(img, self.size_divisor)
            pad_shape = img.shape
        # (h, w, c) -> (c, h, w) for downstream tensor conversion.
        img = img.transpose(2, 0, 1)
        return img, resized_shape, pad_shape, scale_factor


def bbox_flip(bboxes, img_shape):
    """Flip bboxes horizontally.

    Args:
        bboxes(ndarray): shape (..., 4*k)
        img_shape(tuple): (height, width)
    """
    assert bboxes.shape[-1] % 4 == 0
    img_w = img_shape[1]
    flipped = bboxes.copy()
    # Under a horizontal flip, x_min and x_max exchange roles; the -1
    # keeps coordinates within the [0, w - 1] pixel-index convention.
    flipped[..., 0::4] = img_w - bboxes[..., 2::4] - 1
    flipped[..., 2::4] = img_w - bboxes[..., 0::4] - 1
    return flipped


class BboxTransform(object):
    """Preprocess gt bboxes.

    1. rescale bboxes according to the image scale factor
    2. flip bboxes (if needed)
    3. clip bboxes to the image boundary
    4. pad the first dimension to `max_num_gts` (if configured)
    """

    def __init__(self, max_num_gts=None):
        # If not None, outputs are zero-padded to exactly this many rows.
        self.max_num_gts = max_num_gts

    def __call__(self, bboxes, img_shape, scale_factor, flip=False):
        gt_bboxes = bboxes * scale_factor
        if flip:
            gt_bboxes = bbox_flip(gt_bboxes, img_shape)
        # Clip x coordinates to [0, w] and y coordinates to [0, h].
        gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1])
        gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0])
        if self.max_num_gts is None:
            return gt_bboxes
        # Fixed-size output: copy the real boxes, leave the rest zeroed.
        num_gts = gt_bboxes.shape[0]
        padded = np.zeros((self.max_num_gts, 4), dtype=np.float32)
        padded[:num_gts, :] = gt_bboxes
        return padded


class PolyMaskTransform(object):
    """Preprocess polygon-encoded instance masks.

    Optionally flips the polygons, then flattens the nested list structure
    into flat arrays so that masks can be batched as tensors.
    """

    def __init__(self):
        pass

    def __call__(self, gt_mask_polys, gt_poly_lens, img_h, img_w, flip=False):
        """
        Args:
            gt_mask_polys(list): a list of masks, each mask is a list of polys,
                each poly is a list of numbers
            gt_poly_lens(list): a list of int, indicating the size of each poly
        """
        if flip:
            gt_mask_polys = segms.flip_segms(gt_mask_polys, img_h, img_w)
        # Record how many polygons make up each mask so the flattening
        # below can be undone by the consumer.
        num_polys_per_mask = np.array(
            [len(polys) for polys in gt_mask_polys], dtype=np.int64)
        gt_poly_lens = np.array(gt_poly_lens, dtype=np.int64)
        # Flatten every polygon of every mask into one float32 vector.
        gt_mask_polys = np.concatenate(
            [np.asarray(poly) for polys in gt_mask_polys for poly in polys]
        ).astype(np.float32)
        return gt_mask_polys, gt_poly_lens, num_polys_per_mask


class Numpy2Tensor(object):
    """Convert one or more numpy arrays (or array-likes) to torch tensors."""

    def __init__(self):
        pass

    def __call__(self, *args):
        # Single argument: convert directly and return the tensor itself.
        if len(args) == 1:
            return torch.from_numpy(args[0])
        # Multiple arguments: coerce each to an ndarray first, since the
        # inputs may be plain lists; return the tensors as a tuple.
        return tuple(torch.from_numpy(np.array(item)) for item in args)