# sythesis_det.py: random noise synthesis (augmentation) for flattened bbox
# and polyline coordinate sequences.
# Authors (from version-control history): yeshenglong1, zhe chen.
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoiseSythesis(nn.Module):
    """Randomly perturb flattened bbox / polyline coordinate sequences.

    Two augmentation paths are exposed through ``forward``:

    * default (``simple_aug=False``): ``random_apply`` runs random shifting
      and random scaling on the bbox points, each with probability ``p``,
      then clamps the result to the canvas;
    * ``simple_aug=True``: element-wise Gaussian noise on bboxes and
      polylines, where each element keeps its noisy value with probability
      ``p``.

    Args:
        p: probability of applying each transform / perturbing each element.
        scale: Gaussian noise std expressed as a fraction of ``canvas_size``.
        shift_scale: per-axis cap on the random shift magnitude.
        scaling_size: per-axis maximum deviation of the scaling factor
            from 1.
        canvas_size: per-axis coordinate upper bound used for clamping.
        bbox_type: one of ``'sce'``, ``'xyxy'``, ``'rxyxy'``,
            ``'convex_hull'`` (the last two are unimplemented stubs).
        poly_coord_dim: number of coordinates per polyline vertex.
        bbox_coord_dim: number of coordinates per bbox point.
        quantify: if True, Gaussian-noised values are rounded back to the
            input dtype and clamped to the canvas.
    """

    def __init__(self,
                 p, scale=0.01, shift_scale=(8, 5),
                 scaling_size=(0.1, 0.1), canvas_size=(200, 100),
                 bbox_type='sce',
                 poly_coord_dim=2,
                 bbox_coord_dim=2,
                 quantify=True):
        super(NoiseSythesis, self).__init__()

        self.p = p
        self.scale = scale
        self.bbox_type = bbox_type
        self.quantify = quantify

        self.poly_coord_dim = poly_coord_dim
        self.bbox_coord_dim = bbox_coord_dim

        # Transforms tried (each independently with probability p), in order,
        # by random_apply.
        self.transforms = [self.random_shifting, self.random_scaling]

        # Registered as buffers so they move with the module on .to(device).
        self.register_buffer('canvas_size', torch.tensor(canvas_size))
        self.register_buffer('shift_scale', torch.tensor(shift_scale).float())
        self.register_buffer('scaling_size', torch.tensor(scaling_size))

    def random_scaling(self, bbox):
        """Scale ``bbox`` about its centroid by a random factor.

        A single uniform draw in [-1, 1] is taken per sample. The factor for
        each axis is ``1 + draw * scaling_size[axis]``.

        Args:
            bbox: tensor of shape (B, parameter_num, coord_dim).

        Returns:
            Rounded tensor with the same shape and dtype as ``bbox``.
        """
        device = bbox.device
        dtype = bbox.dtype
        B = bbox.shape[0]

        noise = (torch.rand(B, device=device) * 2 - 1)[:, None, None]  # [-1,1]
        scale = self.scaling_size.to(device)
        scale = (noise * scale) + 1

        scaled_bbox = bbox * scale

        # Re-center: shift back so the centroid matches the original one.
        coffset = scaled_bbox.mean(-2) - bbox.float().mean(-2)
        scaled_bbox = scaled_bbox - coffset[:, None]

        return scaled_bbox.round().type(dtype)

    def random_shifting(self, bbox):
        """Translate every sample in ``bbox`` by a random integer offset.

        The per-axis offset bound is 10% of the sample's extent along that
        axis, capped at ``shift_scale``.

        Args:
            bbox: tensor of shape (B, parameter_num, coord_dim).

        Returns:
            Tensor with the same shape and dtype as ``bbox``.
        """
        device = bbox.device
        batch_size = bbox.shape[0]

        shift_scale = self.shift_scale
        scale = (bbox.max(1)[0] - bbox.min(1)[0]) * 0.1
        scale = torch.where(scale < shift_scale, scale, shift_scale)

        # One uniform draw in [-1, 1] per sample and per axis of `scale`.
        noise = (torch.rand(batch_size, scale.shape[-1], device=device) * 2 - 1)
        offset = (noise * scale).round().type(bbox.dtype)

        return bbox + offset[:, None]

    def gaussian_noise_bbox(self, bbox):
        """Add per-axis Gaussian noise to ``bbox``.

        The std for each axis is ``canvas_size * scale``, limited to the
        first ``bbox_coord_dim`` axes.

        Args:
            bbox: tensor whose last dimension has ``bbox_coord_dim`` entries.

        Returns:
            If ``quantify``, a tensor rounded to ``bbox.dtype`` with axis
            ``i`` clamped to ``[1, canvas_size[i]]``. Otherwise a float
            tensor.
        """
        dtype = bbox.dtype

        scale = (self.canvas_size * self.scale)[:self.bbox_coord_dim]

        noisy_bbox = torch.normal(bbox.type(torch.float), scale)

        if self.quantify:
            noisy_bbox = noisy_bbox.round().type(dtype)
            # Prevent out-of-bound coordinates. Each axis is clamped using its
            # own values (the previous code clamped axis 0 into every axis).
            # NOTE(review): lower bound is 1 here but 0 in random_apply and
            # gaussian_noise_poly; confirm which is intended.
            for i in range(self.bbox_coord_dim):
                noisy_bbox[..., i] = \
                    torch.clamp(noisy_bbox[..., i], 1, self.canvas_size[i])
        else:
            noisy_bbox = noisy_bbox.type(torch.float)

        return noisy_bbox

    def gaussian_noise_poly(self, polyline, polyline_mask):
        """Add per-axis Gaussian noise to a flattened polyline sequence.

        Args:
            polyline: tensor of shape (B, L) holding flattened coordinates.
                The view below requires ``L % poly_coord_dim == 1`` when
                ``poly_coord_dim > 1``. This presumably means one leading
                non-coordinate token (TODO: confirm against the caller).
            polyline_mask: tensor of shape (B, L). After the shift below, the
                output element ``j`` is multiplied by ``polyline_mask[:, j+1]``
                and the last element by a padding zero.

        Returns:
            Tensor of shape (B, L).
        """
        batchsize, seq_len = polyline.shape[0], polyline.shape[1]
        scale = (self.canvas_size * self.scale)[:self.poly_coord_dim]

        # Right-pad so the sequence can be grouped into coordinate tuples.
        polyline = F.pad(polyline, (0, self.poly_coord_dim - 1))
        polyline = polyline.view(batchsize, -1, self.poly_coord_dim)
        mask = F.pad(polyline_mask[:, 1:], (0, self.poly_coord_dim))

        noisy_polyline = torch.normal(polyline.type(torch.float), scale)

        if self.quantify:
            noisy_polyline = noisy_polyline.round().type(polyline.dtype)

            # Prevent out-of-bound coordinates.
            for i in range(self.poly_coord_dim):
                noisy_polyline[..., i] = \
                    torch.clamp(noisy_polyline[..., i], 0, self.canvas_size[i])
        else:
            noisy_polyline = noisy_polyline.type(torch.float)

        noisy_polyline = noisy_polyline.view(batchsize, -1) * mask
        # Drop the padding by slicing back to the original length. The old
        # `[:, :-(poly_coord_dim - 1)]` returned an empty tensor for dim == 1.
        return noisy_polyline[:, :seq_len]

    def random_apply(self, bbox):
        """Apply each transform with probability ``p``, then clamp to the canvas.

        Args:
            bbox: tensor of shape (B, parameter_num, bbox_coord_dim).

        Returns:
            A new tensor. The input ``bbox`` is never modified.
        """
        for t in self.transforms:
            if self.p < torch.rand(1):
                continue
            bbox = t(bbox)

        # If no transform fired, `bbox` may still alias the caller's tensor.
        # Clone so that the in-place clamp below cannot mutate caller data.
        bbox = bbox.clone()

        # Prevent out-of-bound coordinates.
        for i in range(self.bbox_coord_dim):
            bbox[..., i] = torch.clamp(bbox[..., i], 0, self.canvas_size[i])

        return bbox

    def simple_aug(self, batch):
        """Apply element-wise Gaussian noise to the bboxes and polylines in ``batch``.

        Args:
            batch: dict with ``'polylines'`` and ``'polyline_masks'``, plus
                ``'bbox_flat'`` when ``bbox_type`` is ``'sce'`` or ``'xyxy'``.

        Returns:
            ``(polyline, fbbox)``. ``fbbox`` is None for the unimplemented
            ``'rxyxy'`` / ``'convex_hull'`` types.

        Raises:
            ValueError: if ``bbox_type`` is not recognised.
        """
        # Augment bbox.
        if self.bbox_type in ['sce', 'xyxy']:
            fbbox = batch['bbox_flat']
            seq_len = fbbox.shape[0]
            bbox = fbbox.view(seq_len, -1, self.bbox_coord_dim)
            bbox = self.gaussian_noise_bbox(bbox)
            fbbox_aug = bbox.view(seq_len, -1)

            # Each element independently takes its noisy value with prob. p.
            aug_mask = torch.rand(fbbox.shape, device=fbbox.device)
            fbbox = torch.where(aug_mask < self.p, fbbox_aug, fbbox)
        elif self.bbox_type == 'rxyxy':
            fbbox = self.rbbox_aug(batch)
        elif self.bbox_type == 'convex_hull':
            fbbox = self.convex_hull_aug(batch)
        else:
            raise ValueError(f'Unsupported bbox_type: {self.bbox_type!r}')

        # Augment polylines.
        polyline = batch['polylines']
        polyline_mask = batch['polyline_masks']
        polyline_aug = self.gaussian_noise_poly(polyline, polyline_mask)

        aug_mask = torch.rand(polyline.shape, device=polyline.device)
        polyline = torch.where(aug_mask < self.p, polyline_aug, polyline)

        return polyline, fbbox

    def rbbox_aug(self, batch):
        """Unimplemented stub for ``bbox_type='rxyxy'``; always returns None."""
        return None

    def convex_hull_aug(self, batch):
        """Unimplemented stub for ``bbox_type='convex_hull'``; always returns None."""
        return None

    def forward(self, batch, simple_aug=False):
        """Augment ``batch``.

        This is defined as ``forward`` rather than overriding ``__call__``, so
        ``module(batch, simple_aug=...)`` still works and module hooks run.

        Args:
            batch: dict containing at least ``'bbox_flat'`` of shape
                (seq_len, num_points * bbox_coord_dim).
            simple_aug: select the Gaussian-noise path instead of random
                shift/scale.

        Returns:
            If ``simple_aug``, ``(polyline, fbbox)`` from ``simple_aug``.
            Otherwise the augmented flattened bbox tensor of shape
            (seq_len, -1).
        """
        if simple_aug:
            return self.simple_aug(batch)

        fbbox = batch['bbox_flat']
        seq_len = fbbox.shape[0]
        bbox = fbbox.view(seq_len, -1, self.bbox_coord_dim)

        aug_bbox = self.random_apply(bbox)

        return aug_bbox.view(seq_len, -1)