import numpy as np
import torch
import torch.nn as nn

from ..backbones_2d import BaseBEVBackbone
from .anchor_head_template import AnchorHeadTemplate


class SingleHead(BaseBEVBackbone):
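    """One detection head on top of a shared BEV feature map.

    Reuses BaseBEVBackbone so each head can carry its own small 2D conv stack
    (configured via rpn_head_cfg) before predicting class, box and direction
    maps for its subset of classes.
    """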
    def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, rpn_head_cfg=None,
                 head_label_indices=None, separate_reg_config=None):
        super().__init__(rpn_head_cfg, input_channels)

        self.num_anchors_per_location = num_anchors_per_location
        self.num_class = num_class
        self.code_size = code_size
        self.model_cfg = model_cfg
        self.separate_reg_config = separate_reg_config
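        # registered as a buffer so the label indices move with the module's device
        # and are saved in checkpoints without becoming learnable parameters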
        self.register_buffer('head_label_indices', head_label_indices)

        if self.separate_reg_config is not None:
            code_size_cnt = 0
            self.conv_box = nn.ModuleDict()
            self.conv_box_names = []
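
            # classification tower: NUM_MIDDLE_CONV blocks of 3x3 conv + BN + ReLU,
            # then a final 3x3 conv emitting num_anchors_per_location * num_class logits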
            num_middle_conv = self.separate_reg_config.NUM_MIDDLE_CONV
            num_middle_filter = self.separate_reg_config.NUM_MIDDLE_FILTER
            conv_cls_list = []
            c_in = input_channels
            for k in range(num_middle_conv):
                conv_cls_list.extend([
                    nn.Conv2d(
                        c_in, num_middle_filter,
                        kernel_size=3, stride=1, padding=1, bias=False
                    ),
                    nn.BatchNorm2d(num_middle_filter),
                    nn.ReLU()
                ])
                c_in = num_middle_filter
            conv_cls_list.append(nn.Conv2d(
                c_in, self.num_anchors_per_location * self.num_class,
                kernel_size=3, stride=1, padding=1
            ))
            self.conv_cls = nn.Sequential(*conv_cls_list)

            for reg_config in self.separate_reg_config.REG_LIST:
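                # each REG_LIST entry is a 'name:channels' string; every regression
                # target gets its own conv tower, concatenated again in forward()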
                reg_name, reg_channel = reg_config.split(':')
                reg_channel = int(reg_channel)
                cur_conv_list = []
                c_in = input_channels
                for k in range(num_middle_conv):
                    cur_conv_list.extend([
                        nn.Conv2d(
                            c_in, num_middle_filter,
                            kernel_size=3, stride=1, padding=1, bias=False
                        ),
                        nn.BatchNorm2d(num_middle_filter),
                        nn.ReLU()
                    ])
                    c_in = num_middle_filter

                cur_conv_list.append(nn.Conv2d(
                    c_in, self.num_anchors_per_location * reg_channel,
                    kernel_size=3, stride=1, padding=1, bias=True
                ))
                code_size_cnt += reg_channel
                self.conv_box[f'conv_{reg_name}'] = nn.Sequential(*cur_conv_list)
                self.conv_box_names.append(f'conv_{reg_name}')

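            # Kaiming initialization for the conv layers in the regression towers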
            for m in self.conv_box.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

            assert code_size_cnt == code_size, f'Code size does not match: {code_size_cnt}:{code_size}'
        else:
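            # no separate regression config: plain 1x1 conv heads for class and box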
            self.conv_cls = nn.Conv2d(
                input_channels, self.num_anchors_per_location * self.num_class,
                kernel_size=1
            )
            self.conv_box = nn.Conv2d(
                input_channels, self.num_anchors_per_location * self.code_size,
                kernel_size=1
            )

        if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False):
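            # direction classifier: one logit per (anchor, direction bin) at each cell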
            self.conv_dir_cls = nn.Conv2d(
                input_channels,
                self.num_anchors_per_location * self.model_cfg.NUM_DIR_BINS,
                kernel_size=1
            )
        else:
            self.conv_dir_cls = None
        self.use_multihead = self.model_cfg.get('USE_MULTIHEAD', False)
        self.init_weights()

    def init_weights(self):
        pi = 0.01
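        # focal-loss style prior: bias set so the initial foreground probability is ~pi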
        if isinstance(self.conv_cls, nn.Conv2d):
            nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi))
        else:
            nn.init.constant_(self.conv_cls[-1].bias, -np.log((1 - pi) / pi))

    def forward(self, spatial_features_2d):
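        """Run this head's 2D backbone, then predict class/box/direction maps.

        Returns a dict with 'cls_preds', 'box_preds' and 'dir_cls_preds'
        ('dir_cls_preds' is None when no direction classifier is configured).
        """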
        ret_dict = {}
        spatial_features_2d = super().forward({'spatial_features': spatial_features_2d})['spatial_features_2d']

        cls_preds = self.conv_cls(spatial_features_2d)

        if self.separate_reg_config is None:
            box_preds = self.conv_box(spatial_features_2d)
        else:
            box_preds_list = []
            for reg_name in self.conv_box_names:
                box_preds_list.append(self.conv_box[reg_name](spatial_features_2d))
            box_preds = torch.cat(box_preds_list, dim=1)

        if not self.use_multihead:
            box_preds = box_preds.permute(0, 2, 3, 1).contiguous()
            cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()
        else:
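            # multi-head layout: split channels into (num_anchors, code_size/num_class),
            # then flatten to (B, total_anchors, C) so per-head outputs can be concatenated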
            H, W = box_preds.shape[2:]
            batch_size = box_preds.shape[0]
            box_preds = box_preds.view(-1, self.num_anchors_per_location,
                                       self.code_size, H, W).permute(0, 1, 3, 4, 2).contiguous()
            cls_preds = cls_preds.view(-1, self.num_anchors_per_location,
                                       self.num_class, H, W).permute(0, 1, 3, 4, 2).contiguous()
            box_preds = box_preds.view(batch_size, -1, self.code_size)
            cls_preds = cls_preds.view(batch_size, -1, self.num_class)

        if self.conv_dir_cls is not None:
            dir_cls_preds = self.conv_dir_cls(spatial_features_2d)
            if self.use_multihead:
                dir_cls_preds = dir_cls_preds.view(
                    -1, self.num_anchors_per_location, self.model_cfg.NUM_DIR_BINS, H, W
                ).permute(0, 1, 3, 4, 2).contiguous()
                dir_cls_preds = dir_cls_preds.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
            else:
                dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
        else:
            dir_cls_preds = None

        ret_dict['cls_preds'] = cls_preds
        ret_dict['box_preds'] = box_preds
        ret_dict['dir_cls_preds'] = dir_cls_preds

        return ret_dict


class AnchorHeadMulti(AnchorHeadTemplate):
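    """Anchor head that splits the classes across several SingleHead branches.

    An optional shared conv first reduces the BEV features; each branch then
    predicts only the classes listed in its RPN_HEAD_CFGS entry.
    """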
    def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
                 predict_boxes_when_training=True):
        super().__init__(
            model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size,
            point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training
        )
        self.model_cfg = model_cfg
        self.separate_multihead = self.model_cfg.get('SEPARATE_MULTIHEAD', False)

        if self.model_cfg.get('SHARED_CONV_NUM_FILTER', None) is not None:
            shared_conv_num_filter = self.model_cfg.SHARED_CONV_NUM_FILTER
            self.shared_conv = nn.Sequential(
                nn.Conv2d(input_channels, shared_conv_num_filter, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(shared_conv_num_filter, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            )
        else:
            self.shared_conv = None
            shared_conv_num_filter = input_channels
        self.rpn_heads = None
        self.make_multihead(shared_conv_num_filter)

    def make_multihead(self, input_channels):
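        """Instantiate one SingleHead per entry in RPN_HEAD_CFGS."""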
        rpn_head_cfgs = self.model_cfg.RPN_HEAD_CFGS
        rpn_heads = []
        class_names = []
        for rpn_head_cfg in rpn_head_cfgs:
            class_names.extend(rpn_head_cfg['HEAD_CLS_NAME'])

        for rpn_head_cfg in rpn_head_cfgs:
            num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)]
                                            for head_cls in rpn_head_cfg['HEAD_CLS_NAME']])
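            # map head-local class positions to global 1-based label ids (0 = background)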
            head_label_indices = torch.from_numpy(np.array([
                self.class_names.index(cur_name) + 1 for cur_name in rpn_head_cfg['HEAD_CLS_NAME']
            ]))

            rpn_head = SingleHead(
                self.model_cfg, input_channels,
                len(rpn_head_cfg['HEAD_CLS_NAME']) if self.separate_multihead else self.num_class,
                num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg,
                head_label_indices=head_label_indices,
                separate_reg_config=self.model_cfg.get('SEPARATE_REG_CONFIG', None)
            )
            rpn_heads.append(rpn_head)
        self.rpn_heads = nn.ModuleList(rpn_heads)

    def forward(self, data_dict):
        spatial_features_2d = data_dict['spatial_features_2d']
        if self.shared_conv is not None:
            spatial_features_2d = self.shared_conv(spatial_features_2d)

        ret_dicts = []
        for rpn_head in self.rpn_heads:
            ret_dicts.append(rpn_head(spatial_features_2d))

        cls_preds = [ret_dict['cls_preds'] for ret_dict in ret_dicts]
        box_preds = [ret_dict['box_preds'] for ret_dict in ret_dicts]
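
        # with SEPARATE_MULTIHEAD the per-head predictions stay as lists; otherwise
        # they are concatenated along the anchor dimension into single tensors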
        ret = {
            'cls_preds': cls_preds if self.separate_multihead else torch.cat(cls_preds, dim=1),
            'box_preds': box_preds if self.separate_multihead else torch.cat(box_preds, dim=1),
        }

        if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False):
            dir_cls_preds = [ret_dict['dir_cls_preds'] for ret_dict in ret_dicts]
            ret['dir_cls_preds'] = dir_cls_preds if self.separate_multihead else torch.cat(dir_cls_preds, dim=1)

        self.forward_ret_dict.update(ret)

        if self.training:
            targets_dict = self.assign_targets(
                gt_boxes=data_dict['gt_boxes']
            )
            self.forward_ret_dict.update(targets_dict)
        else:
            batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
                batch_size=data_dict['batch_size'],
                cls_preds=ret['cls_preds'], box_preds=ret['box_preds'],
                dir_cls_preds=ret.get('dir_cls_preds', None)
            )

            if isinstance(batch_cls_preds, list):
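                # each head scores only its own classes: take the per-head argmax and
                # remap it to global labels through that head's head_label_indices buffer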
                all_pred_labels = []
                all_cls_preds = []
                for idx, cls_pred in enumerate(batch_cls_preds):
                    pred_score, pred_head_label = torch.max(cls_pred, dim=-1)
                    pred_label = self.rpn_heads[idx].head_label_indices[pred_head_label]

                    all_pred_labels.append(pred_label)
                    all_cls_preds.append(pred_score[:, :, None])

                batch_cls_preds = torch.cat(all_cls_preds, dim=1)
                batch_pred_labels = torch.cat(all_pred_labels, dim=1)
                data_dict['batch_pred_labels'] = batch_pred_labels
                data_dict['has_class_labels'] = True

            data_dict['batch_cls_preds'] = batch_cls_preds
            data_dict['batch_box_preds'] = batch_box_preds
            data_dict['cls_preds_normalized'] = False

        return data_dict

    def get_cls_layer_loss(self):
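        """Classification loss summed over heads, slicing per-head one-hot targets."""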
        loss_weights = self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS
        if 'pos_cls_weight' in loss_weights:
            pos_cls_weight = loss_weights['pos_cls_weight']
            neg_cls_weight = loss_weights['neg_cls_weight']
        else:
            pos_cls_weight = neg_cls_weight = 1.0

        cls_preds = self.forward_ret_dict['cls_preds']
        box_cls_labels = self.forward_ret_dict['box_cls_labels']
        if not isinstance(cls_preds, list):
            cls_preds = [cls_preds]
        batch_size = int(cls_preds[0].shape[0])
        cared = box_cls_labels >= 0  # [N, num_anchors]
        positives = box_cls_labels > 0
        negatives = box_cls_labels == 0
        negative_cls_weights = negatives * 1.0 * neg_cls_weight
        cls_weights = (negative_cls_weights + pos_cls_weight * positives).float()
        reg_weights = positives.float()
        if self.num_class == 1:
            # class agnostic
            box_cls_labels[positives] = 1
        pos_normalizer = positives.sum(1, keepdim=True).float()
        reg_weights /= torch.clamp(pos_normalizer, min=1.0)
        cls_weights /= torch.clamp(pos_normalizer, min=1.0)
        cls_targets = box_cls_labels * cared.type_as(box_cls_labels)
        one_hot_targets = torch.zeros(
            *list(cls_targets.shape), self.num_class + 1, dtype=cls_preds[0].dtype, device=cls_targets.device
        )
        one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
        one_hot_targets = one_hot_targets[..., 1:]
        start_idx = c_idx = 0
        cls_losses = 0

        for idx, cls_pred in enumerate(cls_preds):
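            # walk the flattened anchor dimension head by head, slicing out the
            # matching block of one-hot targets and weights for each head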
            cur_num_class = self.rpn_heads[idx].num_class
            cls_pred = cls_pred.view(batch_size, -1, cur_num_class)
            if self.separate_multihead:
                one_hot_target = one_hot_targets[
                    :, start_idx:start_idx + cls_pred.shape[1], c_idx:c_idx + cur_num_class
                ]
                c_idx += cur_num_class
            else:
                one_hot_target = one_hot_targets[:, start_idx:start_idx + cls_pred.shape[1]]
            cls_weight = cls_weights[:, start_idx:start_idx + cls_pred.shape[1]]
            cls_loss_src = self.cls_loss_func(cls_pred, one_hot_target, weights=cls_weight)  # [N, M]
            cls_loss = cls_loss_src.sum() / batch_size
            cls_loss = cls_loss * loss_weights['cls_weight']
            cls_losses += cls_loss
            start_idx += cls_pred.shape[1]
        assert start_idx == one_hot_targets.shape[1]
        tb_dict = {
            'rpn_loss_cls': cls_losses.item()
        }
        return cls_losses, tb_dict

    def get_box_reg_layer_loss(self):
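        """Box regression (+ optional direction) loss, accumulated over heads."""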
        box_preds = self.forward_ret_dict['box_preds']
        box_dir_cls_preds = self.forward_ret_dict.get('dir_cls_preds', None)
        box_reg_targets = self.forward_ret_dict['box_reg_targets']
        box_cls_labels = self.forward_ret_dict['box_cls_labels']

        positives = box_cls_labels > 0
        reg_weights = positives.float()
        pos_normalizer = positives.sum(1, keepdim=True).float()
        reg_weights /= torch.clamp(pos_normalizer, min=1.0)

        if not isinstance(box_preds, list):
            box_preds = [box_preds]
        batch_size = int(box_preds[0].shape[0])

        if isinstance(self.anchors, list):
            if self.use_multihead:
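                # permute each per-class anchor grid so the (size, rotation) dims lead,
                # matching the anchor ordering of the multi-head predictions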
                anchors = torch.cat(
                    [anchor.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchor.shape[-1]) for anchor in
                     self.anchors], dim=0)
            else:
                anchors = torch.cat(self.anchors, dim=-3)
        else:
            anchors = self.anchors
        anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)

        start_idx = 0
        box_losses = 0
        tb_dict = {}
        for idx, box_pred in enumerate(box_preds):
            box_pred = box_pred.view(
                batch_size, -1,
                box_pred.shape[-1] if self.use_multihead else
                box_pred.shape[-1] // self.num_anchors_per_location
            )
            box_reg_target = box_reg_targets[:, start_idx:start_idx + box_pred.shape[1]]
            reg_weight = reg_weights[:, start_idx:start_idx + box_pred.shape[1]]
            # encode the angle residual via sin(a - b) = sin(a)cos(b) - cos(a)sin(b)
            # to avoid the discontinuity at the +/- pi boundary
            box_pred_sin, reg_target_sin = self.add_sin_difference(box_pred, box_reg_target)
            loc_loss_src = self.reg_loss_func(box_pred_sin, reg_target_sin, weights=reg_weight)  # [N, M]
            loc_loss = loc_loss_src.sum() / batch_size

            loc_loss = loc_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['loc_weight']
            box_losses += loc_loss
            tb_dict['rpn_loss_loc'] = tb_dict.get('rpn_loss_loc', 0) + loc_loss.item()

            if box_dir_cls_preds is not None:
                if not isinstance(box_dir_cls_preds, list):
                    box_dir_cls_preds = [box_dir_cls_preds]
                dir_targets = self.get_direction_target(
                    anchors, box_reg_targets,
                    dir_offset=self.model_cfg.DIR_OFFSET,
                    num_bins=self.model_cfg.NUM_DIR_BINS
                )
                box_dir_cls_pred = box_dir_cls_preds[idx]
                dir_logit = box_dir_cls_pred.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
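                # weight the direction loss by positive anchors, normalized per sample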
                weights = positives.type_as(dir_logit)
                weights /= torch.clamp(weights.sum(-1, keepdim=True), min=1.0)

                weight = weights[:, start_idx:start_idx + box_pred.shape[1]]
                dir_target = dir_targets[:, start_idx:start_idx + box_pred.shape[1]]
                dir_loss = self.dir_loss_func(dir_logit, dir_target, weights=weight)
                dir_loss = dir_loss.sum() / batch_size
                dir_loss = dir_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['dir_weight']
                box_losses += dir_loss
                tb_dict['rpn_loss_dir'] = tb_dict.get('rpn_loss_dir', 0) + dir_loss.item()
            start_idx += box_pred.shape[1]
        return box_losses, tb_dict