dg_head.py 9.32 KB
Newer Older
zhe chen's avatar
zhe chen committed
1
import numpy as np
yeshenglong1's avatar
yeshenglong1 committed
2
import torch
zhe chen's avatar
zhe chen committed
3
from mmdet.models import HEADS, build_head
yeshenglong1's avatar
yeshenglong1 committed
4
5
6
7
from mmdet.models.utils import build_transformer
from mmdet.models.utils.transformer import inverse_sigmoid

from ..augmentation.sythesis_det import NoiseSythesis
zhe chen's avatar
zhe chen committed
8
9
from .base_map_head import BaseMapHead

yeshenglong1's avatar
yeshenglong1 committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

@HEADS.register_module(force=True)
class DGHead(BaseMapHead):
    '''Head that chains a detection sub-head with a polyline generation sub-head.

    ``det_net`` is run on the shared ``context`` first; ``gen_net`` is then
    run on ground-truth inputs (training) or on the post-processed detections
    (inference).
    '''

    # Fraction of matched instances whose ground-truth entries are appended
    # again to the generator input when ``joint_training`` is enabled.
    _JOINT_GT_EXTRA_RATIO = 0.2

    def __init__(self,
                 det_net_cfg=dict(),
                 gen_net_cfg=dict(),
                 loss_vert=dict(),
                 loss_face=dict(),
                 max_num_vertices=90,
                 top_p_gen_model=0.9,
                 sync_cls_avg_factor=True,
                 augmentation=False,
                 augmentation_kwargs=None,
                 joint_training=False,
                 **kwargs):
        '''
            Args:
                det_net_cfg: config passed to ``build_head`` for the
                    detection sub-head.
                gen_net_cfg: config passed to ``build_head`` for the
                    generation sub-head; its ``canvas_size`` attribute is
                    read when ``augmentation`` is enabled.
                loss_vert, loss_face: accepted but not used by this class.
                max_num_vertices: stored on the instance.
                top_p_gen_model: forwarded as ``top_p`` to ``gen_net`` at
                    inference time.
                sync_cls_avg_factor: stored on the instance.
                augmentation: when truthy, build a ``NoiseSythesis`` module
                    that is applied to the generator input during training.
                augmentation_kwargs: keyword arguments for ``NoiseSythesis``;
                    ``None`` is treated as an empty mapping.
                joint_training: when truthy, matched detection outputs are
                    fed to the generator during training.
        '''
        super().__init__()

        # Heads
        self.det_net = build_head(det_net_cfg)
        self.gen_net = build_head(gen_net_cfg)

        self.coord_dim = self.gen_net.coord_dim

        # Loss params
        self.bg_cls_weight = 1.0
        self.sync_cls_avg_factor = sync_cls_avg_factor
        self.max_num_vertices = max_num_vertices
        self.top_p_gen_model = top_p_gen_model

        self.fp16_enabled = False

        self.augmentation = None
        if augmentation:
            # Work on a copy: the default is None, and the caller's config
            # should not silently gain a 'canvas_size' entry.
            aug_kwargs = dict(augmentation_kwargs or {})
            aug_kwargs['canvas_size'] = gen_net_cfg.canvas_size
            self.augmentation = NoiseSythesis(**aug_kwargs)

        self.joint_training = joint_training

    def forward(self, batch, img_metas=None, **kwargs):
        '''Dispatch to ``forward_train`` or ``inference`` based on ``self.training``.

            Args:
                batch: forwarded unchanged to the selected method.
                img_metas: accepted but not used here.
                **kwargs: forwarded unchanged (e.g. ``context``).
            Returns:
                Whatever the selected method returns:
                ``(outs, losses_dict)`` in training mode, ``outs`` (a dict)
                or ``None`` in inference mode.
        '''

        if self.training:
            return self.forward_train(batch, **kwargs)
        return self.inference(batch, **kwargs)

    def forward_train(self, batch: dict, context: dict, only_det=False, **kwargs):
        '''Training pass; we use teacher force strategy.

            Args:
                batch: dict with a ``'det'`` entry (consumed by the detection
                    loss) and a ``'gen'`` entry (generator ground truth).
                    When ``joint_training`` is enabled, the values of
                    ``batch['gen']`` are extended in place.
                context: passed to both sub-heads. If
                    ``context['bev_embeddings']`` is a tuple it is replaced
                    in place by its first element before the generator runs.
                only_det: when truthy, return right after the detection loss.
            Returns:
                ``(outs, losses_dict)`` where ``outs`` holds ``'bbox'`` and,
                unless ``only_det``, ``'polylines'``.
        '''

        bbox_dict = self.det_net(context=context)
        outs = dict(
            bbox=bbox_dict,
        )

        losses_dict, det_match_idxs, det_match_gt_idxs = \
            self.loss_det(batch, outs)

        if only_det:
            return outs, losses_dict

        if self.augmentation is not None:
            polylines, bbox_flat = \
                self.augmentation(batch['gen'], simple_aug=True)

            # The augmentation may return no boxes; fall back to the GT ones.
            if bbox_flat is None:
                bbox_flat = batch['gen']['bbox_flat']

            gen_input = dict(
                lines_bs_idx=batch['gen']['lines_bs_idx'],
                lines_cls=batch['gen']['lines_cls'],
                bbox_flat=bbox_flat,
                polylines=polylines,
                polyline_masks=batch['gen']['polyline_masks']
            )
        else:
            gen_input = batch['gen']

        remain_idx = None
        if self.joint_training:
            _bboxs = self._matched_pred_bboxes(
                bbox_dict, det_match_idxs, det_match_gt_idxs)

            num_extra = int(_bboxs.shape[0] * self._JOINT_GT_EXTRA_RATIO)
            remain_idx = torch.randperm(_bboxs.shape[0])[:num_extra]

            # for data efficient
            # Build a NEW dict instead of assigning into ``gen_input``: without
            # augmentation ``gen_input`` is ``batch['gen']`` itself, and
            # writing into it would corrupt the ground truth that is extended
            # again (with the same ``remain_idx``) below.
            gen_input = {
                k: torch.cat(
                    (_bboxs if k == 'bbox_flat' else v, v[remain_idx]), dim=0)
                for k, v in gen_input.items()
            }

        if isinstance(context['bev_embeddings'], tuple):
            context['bev_embeddings'] = context['bev_embeddings'][0]

        poly_dict = self.gen_net(gen_input, context=context)

        outs.update(dict(
            polylines=poly_dict,
        ))

        if self.joint_training:
            # Extend the ground truth with the same rows that were appended
            # to the generator input so the two stay aligned for the loss.
            for k in batch['gen'].keys():
                batch['gen'][k] = \
                    torch.cat((batch['gen'][k], batch['gen'][k][remain_idx]), dim=0)

        gen_losses_dict = \
            self.loss_gen(batch, outs)

        losses_dict.update(gen_losses_dict)

        return outs, losses_dict

    @staticmethod
    def _matched_pred_bboxes(bbox_dict, det_match_idxs, det_match_gt_idxs):
        '''Gather the matched predictions of the last ``bbox_dict`` entry.

        For each sample the matched predictions are selected, reordered to
        the original ground-truth order, concatenated over the batch, rounded
        and cast to ``torch.int32``.

            Raises:
                NotImplementedError: the last entry has neither a ``'lines'``
                    nor a ``'bboxs'`` key.
        '''
        last_pred = bbox_dict[-1]

        # for down stream polyline
        if 'lines' in last_pred:
            # for fix anchor
            pred_bbox = last_pred['lines'].detach()
        elif 'bboxs' in last_pred:
            # for rpv
            pred_bbox = last_pred['bboxs'].detach()
        else:
            raise NotImplementedError

        det_match_idx = det_match_idxs[-1]
        det_match_gt_idx = det_match_gt_idxs[-1]

        # changed to original gt order.
        per_sample = [
            bbox[match_idx][torch.argsort(det_match_gt_idx[i])]
            for i, (match_idx, bbox) in enumerate(zip(det_match_idx, pred_bbox))
        ]

        # quantize the data
        return torch.round(torch.cat(per_sample, dim=0)).type(torch.int32)

    def loss_det(self, gt: dict, pred: dict):
        '''Compute the detection loss.

            Args:
                gt: dict whose ``'det'`` entry is passed to ``det_net.loss``.
                pred: dict whose ``'bbox'`` entry is passed to ``det_net.loss``.
            Returns:
                ``(loss_dict, det_match_idx, det_match_gt_idx)``; the loss
                keys are those of ``det_net.loss`` prefixed with ``'det_'``.
        '''

        # det
        det_loss_dict, det_match_idx, det_match_gt_idx = \
            self.det_net.loss(gt['det'], pred['bbox'])

        loss_dict = {'det_' + k: v for k, v in det_loss_dict.items()}

        return loss_dict, det_match_idx, det_match_gt_idx

    def loss_gen(self, gt: dict, pred: dict):
        '''Compute the generation loss.

            Args:
                gt: dict whose ``'gen'`` entry is passed to ``gen_net.loss``.
                pred: dict whose ``'polylines'`` entry is passed to
                    ``gen_net.loss``.
            Returns:
                dict with the keys of ``gen_net.loss`` prefixed with ``'gen_'``.
        '''

        # gen
        gen_loss_dict = self.gen_net.loss(gt['gen'], pred['polylines'])

        return {'gen_' + k: v for k, v in gen_loss_dict.items()}

    def loss(self, gt: dict, pred: dict):
        '''No-op; the losses are computed in ``loss_det`` and ``loss_gen``.'''
        pass

    @torch.no_grad()
    def inference(self, batch: dict = None, context: dict = None, gt_condition=False, **kwargs):
        '''Run detection, then generate polylines for the detections.

            Args:
                batch: accepted but not used here.
                context: passed to both sub-heads. If
                    ``context['bev_embeddings']`` is a tuple it is replaced
                    in place by its first element before the generator runs.
                gt_condition: forwarded to ``gen_net``.
            Returns:
                ``None`` when the post-processed detections contain no lines,
                otherwise a dict merging the post-processed detection outputs
                with the generator outputs.
        '''
        # ``None`` replaces the previous mutable ``{}`` defaults.
        if batch is None:
            batch = {}
        if context is None:
            context = {}

        outs = {}
        bbox_dict = self.det_net(context=context)
        bbox_dict = self.det_net.post_process(bbox_dict)

        outs.update(bbox_dict)

        if len(outs['lines_bs_idx']) == 0:
            return None

        if isinstance(context['bev_embeddings'], tuple):
            context['bev_embeddings'] = context['bev_embeddings'][0]

        # NOTE: the sample length is fixed to 64 here;
        # ``self.max_num_vertices`` is not used.
        poly_dict = self.gen_net(outs,
                                 context=context,
                                 max_sample_length=64,
                                 top_p=self.top_p_gen_model,
                                 gt_condition=gt_condition)
        outs.update(poly_dict)

        return outs

    def post_process(self, preds: dict, tokens, gts: dict = None, **kwargs):
        '''Convert batched predictions into one result dict per sample.

            Args:
                preds: dict with ``'scores'`` and ``'labels'`` (indexed by
                    sample) and ``'lines_bs_idx'``, ``'polylines'``,
                    ``'polyline_masks'`` (indexed by line).
                tokens: one identifier per sample.
                gts: optional ground truth; when given it is packed with
                    ``pack_groundtruth``, but the packed result is currently
                    not attached to the output.
            Outs:
                list of dicts with keys ``'token'``, ``'scores'``,
                ``'labels'``, ``'nline'`` and ``'vectors'`` (a list of arrays
                reshaped to ``(-1, coord_dim)``).
        '''
        range_size = self.gen_net.canvas_size.cpu().numpy()
        coord_dim = self.gen_net.coord_dim

        gen_net_name = self.gen_net.name if hasattr(self.gen_net, 'name') else 'gen'

        # Loop-invariant: move the line-to-sample index to host memory once.
        lines_bs_idx = preds['lines_bs_idx'].cpu().numpy()

        ret_list = []
        for batch_idx, token in enumerate(tokens):

            ret_dict_single = {}

            # bbox
            det_gt = None
            if gts is not None:
                det_gt, rec_groundtruth = pack_groundtruth(
                    batch_idx, gts, tokens, range_size, gen_net_name, coord_dim=coord_dim)

            bbox_res = {
                'token': token,
                'scores': preds['scores'][batch_idx].detach().cpu().numpy(),
                'labels': preds['labels'][batch_idx].detach().cpu().numpy(),
            }
            ret_dict_single.update(bbox_res)

            # for gen results.
            batch2seq = np.nonzero(lines_bs_idx == batch_idx)[0]

            ret_dict_single.update({
                'nline': len(batch2seq),
                'vectors': []
            })

            for i in batch2seq:
                pre = preds['polylines'][i].detach().cpu().numpy()
                pre_msk = preds['polyline_masks'][i].detach().cpu().numpy()
                # The last masked-in position is dropped.
                valid_idx = np.nonzero(pre_msk)[0][:-1]

                # From [200,1] to [199,0] to (1,0)
                line = (pre[valid_idx].reshape(-1, coord_dim) - 1) / (range_size - 1)

                ret_dict_single['vectors'].append(line)

            ret_list.append(ret_dict_single)

        return ret_list


zhe chen's avatar
zhe chen committed
263
def pack_groundtruth(batch_idx, gts, tokens, range_size, gen_net_name='gen', coord_dim=2):
    '''Pack the ground truth of one sample into numpy-based dicts.

        Args:
            batch_idx: index of the sample inside the batch.
            gts: dict with a ``'det'`` entry (``'class_label'`` plus either
                ``'keypoints'`` or ``'bbox'``, indexed by sample) and a
                ``'gen'`` entry (``'lines_bs_idx'``, ``'lines_cls'``,
                ``'polylines'``, ``'polyline_masks'``, indexed by line).
            tokens: one identifier per sample.
            range_size: array the shifted coordinates are divided by
                (after subtracting 1).
            gen_net_name: when equal to ``'gen_gmm'`` every masked-in
                position is kept; otherwise the last one is dropped.
            coord_dim: number of coordinates per vertex.
        Returns:
            ``(det_gt, ret_groundtruth)``: ``det_gt`` has ``'labels'`` and
            ``'bboxes'``; ``ret_groundtruth`` has ``'token'``, ``'nline'``,
            ``'labels'`` and ``'lines'`` (a list of arrays reshaped to
            ``(-1, coord_dim)``).
    '''
    det = gts['det']
    # 'keypoints' takes precedence over 'bbox' when both are present.
    bbox_key = 'keypoints' if 'keypoints' in det else 'bbox'
    det_gt = {
        'labels': det['class_label'][batch_idx].detach().cpu().numpy(),
        'bboxes': det[bbox_key][batch_idx].detach().cpu().numpy(),
    }

    gen = gts['gen']
    batch2seq = np.nonzero(
        gen['lines_bs_idx'].cpu().numpy() == batch_idx)[0]

    ret_groundtruth = {
        'token': tokens[batch_idx],
        'nline': len(batch2seq),
        'labels': gen['lines_cls'][batch2seq].detach().cpu().numpy(),
        'lines': [],
    }

    if batch2seq.size:
        # Convert the full tensors once instead of once per line.
        polylines = gen['polylines'].detach().cpu().numpy()
        masks = gen['polyline_masks'].detach().cpu().numpy()
        keep_last = gen_net_name == 'gen_gmm'

        for i in batch2seq:
            valid_idx = np.nonzero(masks[i])[0]
            if not keep_last:
                valid_idx = valid_idx[:-1]

            # From [200,1] to [199,0] to (1,0)
            line = (polylines[i][valid_idx].reshape(-1, coord_dim) - 1) / (range_size - 1)
            ret_groundtruth['lines'].append(line)

    return det_gt, ret_groundtruth