anchor_head_multi.py 6.22 KB
Newer Older
Gus-Guo's avatar
Gus-Guo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import numpy as np
import torch.nn as nn
from .anchor_head_template import AnchorHeadTemplate
import torch

class SingleHead(nn.Module):
    """Single prediction head built from 1x1 convolutions over a 2D feature map.

    The head optionally runs a small stack of 1x1 convolutions
    (``encode_conv``) and then predicts, for every spatial location,
    ``num_anchors_per_location * num_class`` classification channels,
    ``num_anchors_per_location * code_size`` box channels and, when the
    direction classifier is enabled,
    ``num_anchors_per_location * model_cfg.NUM_DIR_BINS`` direction channels.

    Args:
        model_cfg: Config object supporting ``.get(key, default)`` and
            attribute access to ``NUM_DIR_BINS``. The keys read here are
            ``USE_DIRECTION_CLASSIFIER`` and ``USE_MULTI_HEAD``.
        input_channels: Channel count of the incoming feature map.
        num_class: Number of classification outputs per anchor.
        num_anchors_per_location: Anchors predicted at each spatial location.
        code_size: Number of box regression outputs per anchor.
        encode_conv_cfg: Optional mapping with ``'stride'`` and
            ``'layer_num'`` entries. When given, a strided 1x1 conv followed
            by ``layer_num - 1`` conv/BN/ReLU triples is applied first.
    """

    def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None):
        super().__init__()
        if encode_conv_cfg is not None:
            stride = encode_conv_cfg['stride']
            layer_num = encode_conv_cfg['layer_num']
            num_filters = input_channels
            # The leading conv carries the stride and has no BN/ReLU of its
            # own; layer order is kept so state-dict keys stay unchanged.
            encode_conv = [
                nn.Conv2d(num_filters, num_filters, kernel_size=1, stride=stride, bias=False)
            ]
            for _ in range(layer_num - 1):
                encode_conv.append(nn.Conv2d(num_filters, num_filters, 1, bias=False))
                encode_conv.append(nn.BatchNorm2d(num_filters))
                encode_conv.append(nn.ReLU(inplace=True))
            self.encode_conv = nn.Sequential(*encode_conv)
        else:
            self.encode_conv = None

        self.num_anchors_per_location = num_anchors_per_location
        self.num_class = num_class
        self.code_size = code_size
        self.model_cfg = model_cfg

        self.conv_cls = nn.Conv2d(
            input_channels, self.num_anchors_per_location * self.num_class,
            kernel_size=1
        )
        self.conv_box = nn.Conv2d(
            input_channels, self.num_anchors_per_location * self.code_size,
            kernel_size=1
        )

        # Truthiness check rather than ``is not None``: an explicit
        # ``USE_DIRECTION_CLASSIFIER: False`` must disable the branch instead
        # of creating a conv whose output is never consumed.
        if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False):
            self.conv_dir_cls = nn.Conv2d(
                input_channels,
                self.num_anchors_per_location * self.model_cfg.NUM_DIR_BINS,
                kernel_size=1
            )
        else:
            self.conv_dir_cls = None
        self.use_multihead = self.model_cfg.get('USE_MULTI_HEAD', False)
        self.init_weights()

    def init_weights(self):
        """Set the classification bias to ``-log((1 - pi) / pi)`` with pi = 0.01.

        This is the logit whose sigmoid equals ``pi``.
        """
        pi = 0.01
        nn.init.constant_(self.conv_cls.bias, -np.log((1 - pi) / pi))

    def _flatten_anchor_preds(self, preds, channels_per_anchor):
        """Reshape ``(B, A * C, H, W)`` predictions to ``(B, A * H * W, C)``.

        ``A`` is ``self.num_anchors_per_location`` and ``C`` is
        ``channels_per_anchor``.
        """
        batch_size = preds.shape[0]
        height, width = preds.shape[2:]
        preds = preds.view(
            -1, self.num_anchors_per_location, channels_per_anchor, height, width
        ).permute(0, 1, 3, 4, 2).contiguous()
        return preds.view(batch_size, -1, channels_per_anchor)

    def forward(self, spatial_features_2d):
        """Run the head on a 4D feature map.

        Args:
            spatial_features_2d: Tensor of shape ``(B, input_channels, H, W)``.

        Returns:
            dict with keys ``'cls_preds'``, ``'box_preds'`` and
            ``'dir_cls_preds'``. Without ``USE_MULTI_HEAD`` each tensor is
            channels-last, ``(B, H', W', channels)``. With ``USE_MULTI_HEAD``
            the box tensor is ``(B, A * H' * W', code_size)``, the class
            tensor is ``(B, A * H' * W', num_class, 1)`` and the direction
            tensor is ``(B, A * H' * W', NUM_DIR_BINS)``.
            ``'dir_cls_preds'`` is ``None`` when the direction classifier is
            disabled.
        """
        if self.encode_conv is not None:
            spatial_features_2d = self.encode_conv(spatial_features_2d)

        cls_preds = self.conv_cls(spatial_features_2d)
        box_preds = self.conv_box(spatial_features_2d)
        dir_cls_preds = None
        if self.conv_dir_cls is not None:
            dir_cls_preds = self.conv_dir_cls(spatial_features_2d)

        if self.use_multihead:
            box_preds = self._flatten_anchor_preds(box_preds, self.code_size)
            # The class predictions keep a trailing singleton dimension.
            cls_preds = self._flatten_anchor_preds(cls_preds, self.num_class).unsqueeze(-1)
            if dir_cls_preds is not None:
                dir_cls_preds = self._flatten_anchor_preds(
                    dir_cls_preds, self.model_cfg.NUM_DIR_BINS
                )
        else:
            # Channels-first -> channels-last.
            box_preds = box_preds.permute(0, 2, 3, 1).contiguous()
            cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()
            if dir_cls_preds is not None:
                dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()

        return {
            'cls_preds': cls_preds,
            'box_preds': box_preds,
            'dir_cls_preds': dir_cls_preds,
        }

class AnchorHeadMulti(AnchorHeadTemplate):
    """Anchor head that runs several ``SingleHead`` modules on one feature map.

    One ``SingleHead`` is created per entry of ``model_cfg.RPN_HEAD_CFGS``;
    their outputs are concatenated along dimension 1.
    """

    def __init__(self, model_cfg, input_channels, num_class, grid_size, point_cloud_range, predict_boxes_when_training=True):
        super().__init__(
            model_cfg=model_cfg,
            num_class=num_class,
            grid_size=grid_size,
            point_cloud_range=point_cloud_range,
            predict_boxes_when_training=predict_boxes_when_training,
        )
        self.model_cfg = model_cfg
        self.make_multihead(input_channels)

    def make_multihead(self, input_channels):
        """Build one ``SingleHead`` per head config and store them in ``self.rpn_heads``.

        Each head config is expected to provide ``'head_cls_name'`` and is
        also forwarded to ``SingleHead`` as its ``encode_conv_cfg``.
        """
        head_cfgs = self.model_cfg.RPN_HEAD_CFGS

        # Class names in head-config order; a name's position here is used to
        # index ``self.num_anchors_per_location``.
        ordered_names = [
            cls_name
            for head_cfg in head_cfgs
            for cls_name in head_cfg['head_cls_name']
        ]

        heads = []
        for head_cfg in head_cfgs:
            anchors_in_head = 0
            for cls_name in head_cfg['head_cls_name']:
                anchors_in_head += self.num_anchors_per_location[ordered_names.index(cls_name)]
            heads.append(
                SingleHead(
                    self.model_cfg, input_channels, self.num_class,
                    anchors_in_head, self.box_coder.code_size, head_cfg,
                )
            )
        self.rpn_heads = nn.ModuleList(heads)

    def forward(self, data_dict):
        """Run every head, merge the predictions, then assign targets or decode boxes.

        Reads ``data_dict['spatial_features_2d']``. In training mode it also
        reads ``data_dict['gt_boxes']`` and stores the assigned targets in
        ``self.forward_ret_dict``. Otherwise it reads
        ``data_dict['batch_size']`` and writes ``'batch_cls_preds'``,
        ``'batch_box_preds'`` and ``'cls_preds_normalized'`` into
        ``data_dict``. Returns ``data_dict``.
        """
        features = data_dict['spatial_features_2d']
        head_outputs = [head(features) for head in self.rpn_heads]

        cls_preds = torch.cat([out['cls_preds'] for out in head_outputs], dim=1)
        box_preds = torch.cat([out['box_preds'] for out in head_outputs], dim=1)
        merged = {'cls_preds': cls_preds, 'box_preds': box_preds}

        dir_cls_preds = None
        if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False):
            dir_cls_preds = torch.cat([out['dir_cls_preds'] for out in head_outputs], dim=1)
            merged['dir_cls_preds'] = dir_cls_preds

        self.forward_ret_dict.update(merged)

        if not self.training:
            batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
                batch_size=data_dict['batch_size'],
                cls_preds=cls_preds, box_preds=box_preds, dir_cls_preds=dir_cls_preds
            )
            data_dict['batch_cls_preds'] = batch_cls_preds
            data_dict['batch_box_preds'] = batch_box_preds
            data_dict['cls_preds_normalized'] = False
            return data_dict

        targets_dict = self.assign_targets(gt_boxes=data_dict['gt_boxes'])
        self.forward_ret_dict.update(targets_dict)
        return data_dict