"...mmdet/git@developer.sourcefind.cn:OpenDAS/mmdeploy.git" did not exist on "686619677693badfa7d39fad13c6ec7bca8414fd"
Commit e0609869 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

merge with dev_multihead

parents 7d251220 0ac2901c
......@@ -7,8 +7,8 @@ class BaseBEVBackbone(nn.Module):
super().__init__()
self.model_cfg = model_cfg
assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == \
len(self.model_cfg.NUM_FILTERS) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
layer_nums = self.model_cfg.LAYER_NUMS
layer_strides = self.model_cfg.LAYER_STRIDES
num_filters = self.model_cfg.NUM_FILTERS
......@@ -36,16 +36,16 @@ class BaseBEVBackbone(nn.Module):
nn.ReLU()
])
self.blocks.append(nn.Sequential(*cur_layers))
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
upsample_strides[idx],
stride=upsample_strides[idx], bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))
if len(upsample_strides) > 0:
self.deblocks.append(nn.Sequential(
nn.ConvTranspose2d(
num_filters[idx], num_upsample_filters[idx],
upsample_strides[idx],
stride=upsample_strides[idx], bias=False
),
nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
nn.ReLU()
))
c_in = sum(num_upsample_filters)
if len(upsample_strides) > num_levels:
......@@ -73,12 +73,16 @@ class BaseBEVBackbone(nn.Module):
stride = int(spatial_features.shape[2] / x.shape[2])
ret_dict['spatial_features_%dx' % stride] = x
ups.append(self.deblocks[i](x))
if len(self.deblocks) > 0:
ups.append(self.deblocks[i](x))
else:
ups.append(x)
if len(ups) > 1:
x = torch.cat(ups, dim=1)
else:
elif len(ups) == 1:
x = ups[0]
if len(self.deblocks) > len(self.blocks):
x = self.deblocks[-1](x)
......
......@@ -22,8 +22,16 @@ class PFNLayer(nn.Module):
else:
self.linear = nn.Linear(in_channels, out_channels, bias=True)
self.part = 50000
def forward(self, inputs):
x = self.linear(inputs)
if inputs.shape[0] > self.part:
# nn.Linear performs randomly when batch size is too large
num_parts = inputs.shape[0] // self.part
part_linear_out = [self.linear(inputs[num_part*self.part:(num_part+1)*self.part]) for num_part in range(num_parts+1)]
x = torch.cat(part_linear_out, dim=0)
else:
x = self.linear(inputs)
total_points, voxel_points, channels = x.shape
x = self.norm(x.view(-1, channels)).view(total_points, voxel_points, channels) if self.use_norm else x
x = F.relu(x)
......
......@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate
from ..backbones_2d import BaseBEVBackbone
import torch
class SingleHead(BaseBEVBackbone):
def __init__(self, model_cfg, input_channels, num_class, num_anchors_per_location, code_size, encode_conv_cfg=None):
super().__init__(encode_conv_cfg, input_channels)
......@@ -30,7 +31,7 @@ class SingleHead(BaseBEVBackbone):
)
else:
self.conv_dir_cls = None
self.use_multihead = self.model_cfg.get('USE_MULTI_HEAD', False)
self.use_multihead = self.model_cfg.get('USE_MULTIHEAD', False)
self.init_weights()
def init_weights(self):
......@@ -55,8 +56,8 @@ class SingleHead(BaseBEVBackbone):
cls_preds = cls_preds.view(-1, self.num_anchors_per_location,
self.num_class, H, W).permute(0, 1, 3, 4, 2).contiguous()
box_preds = box_preds.view(batch_size, -1, self.code_size)
cls_preds = cls_preds.view(batch_size, -1, self.num_class).unsqueeze(-1)
cls_preds = cls_preds.view(batch_size, -1, self.num_class)
if self.conv_dir_cls is not None:
dir_cls_preds = self.conv_dir_cls(spatial_features_2d)
if self.use_multihead:
......@@ -75,14 +76,29 @@ class SingleHead(BaseBEVBackbone):
return ret_dict
class AnchorHeadMulti(AnchorHeadTemplate):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, predict_boxes_when_training=True):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range,
predict_boxes_when_training=True):
super().__init__(
model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size, point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training
model_cfg=model_cfg, num_class=num_class, class_names=class_names, grid_size=grid_size,
point_cloud_range=point_cloud_range, predict_boxes_when_training=predict_boxes_when_training
)
self.model_cfg = model_cfg
self.make_multihead(input_channels)
self.separate_multihead = self.model_cfg.get('SEPARATE_MULTIHEAD', False)
if self.model_cfg.get('SHARED_CONV_NUM_FILTER', None) is not None:
shared_conv_num_filter = self.model_cfg.SHARED_CONV_NUM_FILTER
self.shared_conv = nn.Sequential(
nn.Conv2d(input_channels, shared_conv_num_filter, 3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(shared_conv_num_filter, eps=1e-3, momentum=0.01),
nn.ReLU(),
)
else:
self.shared_conv = None
shared_conv_num_filter = input_channels
self.rpn_heads = None
self.make_multihead(shared_conv_num_filter)
def make_multihead(self, input_channels):
rpn_head_cfgs = self.model_cfg.RPN_HEAD_CFGS
......@@ -90,29 +106,37 @@ class AnchorHeadMulti(AnchorHeadTemplate):
class_names = []
for rpn_head_cfg in rpn_head_cfgs:
class_names.extend(rpn_head_cfg['HEAD_CLS_NAME'])
for rpn_head_cfg in rpn_head_cfgs:
num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)] for head_cls in rpn_head_cfg['HEAD_CLS_NAME']])
rpn_head = SingleHead(self.model_cfg, input_channels, self.num_class, num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg)
num_anchors_per_location = sum([self.num_anchors_per_location[class_names.index(head_cls)]
for head_cls in rpn_head_cfg['HEAD_CLS_NAME']])
rpn_head = SingleHead(
self.model_cfg, input_channels,
len(rpn_head_cfg['HEAD_CLS_NAME']) if self.separate_multihead else self.num_class,
num_anchors_per_location, self.box_coder.code_size, rpn_head_cfg
)
rpn_heads.append(rpn_head)
self.rpn_heads = nn.ModuleList(rpn_heads)
def forward(self, data_dict):
spatial_features_2d = data_dict['spatial_features_2d']
if self.shared_conv is not None:
spatial_features_2d = self.shared_conv(spatial_features_2d)
ret_dicts = []
for rpn_head in self.rpn_heads:
ret_dicts.append(rpn_head(spatial_features_2d))
cls_preds = torch.cat([ret_dict['cls_preds'] for ret_dict in ret_dicts], dim=1)
box_preds = torch.cat([ret_dict['box_preds'] for ret_dict in ret_dicts], dim=1)
cls_preds = [ret_dict['cls_preds'] for ret_dict in ret_dicts]
box_preds = [ret_dict['box_preds'] for ret_dict in ret_dicts]
ret = {
'cls_preds': cls_preds,
'box_preds': box_preds,
'cls_preds': cls_preds if self.separate_multihead else torch.cat(cls_preds, dim=1),
'box_preds': box_preds if self.separate_multihead else torch.cat(box_preds, dim=1),
}
if self.model_cfg.get('USE_DIRECTION_CLASSIFIER', False):
dir_cls_preds = torch.cat([ret_dict['dir_cls_preds'] for ret_dict in ret_dicts], dim=1)
ret['dir_cls_preds'] = dir_cls_preds
dir_cls_preds = [ret_dict['dir_cls_preds'] for ret_dict in ret_dicts]
ret['dir_cls_preds'] = dir_cls_preds if self.separate_multihead else torch.cat(dir_cls_preds, dim=1)
else:
dir_cls_preds = None
......@@ -133,3 +157,118 @@ class AnchorHeadMulti(AnchorHeadTemplate):
data_dict['cls_preds_normalized'] = False
return data_dict
def get_cls_layer_loss(self):
cls_preds = self.forward_ret_dict['cls_preds']
box_cls_labels = self.forward_ret_dict['box_cls_labels']
if not isinstance(cls_preds, list):
cls_preds = [cls_preds]
batch_size = int(cls_preds[0].shape[0])
cared = box_cls_labels >= 0 # [N, num_anchors]
positives = box_cls_labels > 0
negatives = box_cls_labels == 0
negative_cls_weights = negatives * 1.0
cls_weights = (negative_cls_weights + 1.0 * positives).float()
reg_weights = positives.float()
if self.num_class == 1:
# class agnostic
box_cls_labels[positives] = 1
pos_normalizer = positives.sum(1, keepdim=True).float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0)
cls_weights /= torch.clamp(pos_normalizer, min=1.0)
cls_targets = box_cls_labels * cared.type_as(box_cls_labels)
one_hot_targets = torch.zeros(
*list(cls_targets.shape), self.num_class + 1, dtype=cls_preds[0].dtype, device=cls_targets.device
)
one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
one_hot_targets = one_hot_targets[..., 1:]
start_idx = c_idx = 0
cls_losses = 0
for idx, cls_pred in enumerate(cls_preds):
cur_num_class = self.rpn_heads[idx].num_class
cls_pred = cls_pred.view(batch_size, -1, cur_num_class)
if self.separate_multihead:
one_hot_target = one_hot_targets[:, start_idx:start_idx + cls_pred.shape[1], c_idx:c_idx+cur_num_class]
c_idx += cur_num_class
else:
one_hot_target = one_hot_targets[:, start_idx:start_idx+cls_pred.shape[1]]
cls_weight = cls_weights[:, start_idx:start_idx+cls_pred.shape[1]]
cls_loss_src = self.cls_loss_func(cls_pred, one_hot_target, weights=cls_weight) # [N, M]
cls_loss = cls_loss_src.sum() / batch_size
cls_loss = cls_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['cls_weight']
cls_losses += cls_loss
start_idx += cls_pred.shape[1]
assert start_idx == one_hot_targets.shape[1]
tb_dict = {
'rpn_loss_cls': cls_losses.item()
}
return cls_losses, tb_dict
def get_box_reg_layer_loss(self):
box_preds = self.forward_ret_dict['box_preds']
box_dir_cls_preds = self.forward_ret_dict.get('dir_cls_preds', None)
box_reg_targets = self.forward_ret_dict['box_reg_targets']
box_cls_labels = self.forward_ret_dict['box_cls_labels']
positives = box_cls_labels > 0
reg_weights = positives.float()
pos_normalizer = positives.sum(1, keepdim=True).float()
reg_weights /= torch.clamp(pos_normalizer, min=1.0)
if not isinstance(box_preds, list):
box_preds = [box_preds]
batch_size = int(box_preds[0].shape[0])
if isinstance(self.anchors, list):
if self.use_multihead:
anchors = torch.cat(
[anchor.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchor.shape[-1]) for anchor in
self.anchors], dim=0)
else:
anchors = torch.cat(self.anchors, dim=-3)
else:
anchors = self.anchors
anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
start_idx = 0
box_losses = 0
tb_dict = {}
for idx, box_pred in enumerate(box_preds):
box_pred = box_pred.view(batch_size, -1,
box_pred.shape[-1] // self.num_anchors_per_location if not self.use_multihead else
box_pred.shape[-1])
box_reg_target = box_reg_targets[:, start_idx:start_idx+box_pred.shape[1]]
reg_weight = reg_weights[:, start_idx:start_idx+box_pred.shape[1]]
# sin(a - b) = sinacosb-cosasinb
box_pred_sin, reg_target_sin = self.add_sin_difference(box_pred, box_reg_target)
loc_loss_src = self.reg_loss_func(box_pred_sin, reg_target_sin, weights=reg_weight) # [N, M]
loc_loss = loc_loss_src.sum() / batch_size
loc_loss = loc_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['loc_weight']
box_losses += loc_loss
tb_dict['rpn_loss_loc'] = tb_dict.get('rpn_loss_loc', 0) + loc_loss
if box_dir_cls_preds is not None:
if not isinstance(box_dir_cls_preds, list):
box_dir_cls_preds = [box_dir_cls_preds]
dir_targets = self.get_direction_target(
anchors, box_reg_targets,
dir_offset=self.model_cfg.DIR_OFFSET,
num_bins=self.model_cfg.NUM_DIR_BINS
)
box_dir_cls_pred = box_dir_cls_preds[idx]
dir_logit = box_dir_cls_pred.view(batch_size, -1, self.model_cfg.NUM_DIR_BINS)
weights = positives.type_as(dir_logit)
weights /= torch.clamp(weights.sum(-1, keepdim=True), min=1.0)
weight = weights[:, start_idx:start_idx+box_pred.shape[1]]
dir_target = dir_targets[:, start_idx:start_idx+box_pred.shape[1]]
dir_loss = self.dir_loss_func(dir_logit, dir_target, weights=weight)
dir_loss = dir_loss.sum() / batch_size
dir_loss = dir_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['dir_weight']
box_losses += dir_loss
tb_dict['rpn_loss_dir'] = tb_dict.get('rpn_loss_dir', 0) + dir_loss.item()
start_idx += box_pred.shape[1]
return box_losses, tb_dict
......@@ -14,7 +14,7 @@ class AnchorHeadTemplate(nn.Module):
self.num_class = num_class
self.class_names = class_names
self.predict_boxes_when_training = predict_boxes_when_training
self.use_multihead = self.model_cfg.get('USE_MULTI_HEAD', False)
self.use_multihead = self.model_cfg.get('USE_MULTIHEAD', False)
anchor_target_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
self.box_coder = getattr(box_coder_utils, anchor_target_cfg.BOX_CODER)(
......@@ -28,7 +28,7 @@ class AnchorHeadTemplate(nn.Module):
anchor_ndim=self.box_coder.code_size
)
self.anchors = [x.cuda() for x in anchors]
self.target_assigner = self.get_target_assigner(anchor_target_cfg, anchor_generator_cfg)
self.target_assigner = self.get_target_assigner(anchor_target_cfg)
self.forward_ret_dict = {}
self.build_losses(self.model_cfg.LOSS_CONFIG)
......@@ -50,17 +50,17 @@ class AnchorHeadTemplate(nn.Module):
return anchors_list, num_anchors_per_location_list
def get_target_assigner(self, anchor_target_cfg, anchor_generator_cfg):
def get_target_assigner(self, anchor_target_cfg):
if anchor_target_cfg.NAME == 'ATSS':
target_assigner = ATSSTargetAssigner(
topk=anchor_target_cfg.TOPK,
box_coder=self.box_coder,
use_multihead=self.use_multihead,
match_height=anchor_target_cfg.MATCH_HEIGHT
)
elif anchor_target_cfg.NAME == 'AxisAlignedTargetAssigner':
target_assigner = AxisAlignedTargetAssigner(
anchor_target_cfg=anchor_target_cfg,
anchor_generator_cfg=anchor_generator_cfg,
model_cfg=self.model_cfg,
class_names=self.class_names,
box_coder=self.box_coder,
match_height=anchor_target_cfg.MATCH_HEIGHT
......@@ -93,7 +93,7 @@ class AnchorHeadTemplate(nn.Module):
"""
targets_dict = self.target_assigner.assign_targets(
self.anchors, gt_boxes, self.use_multihead
self.anchors, gt_boxes
)
return targets_dict
......@@ -124,8 +124,6 @@ class AnchorHeadTemplate(nn.Module):
one_hot_targets.scatter_(-1, cls_targets.unsqueeze(dim=-1).long(), 1.0)
cls_preds = cls_preds.view(batch_size, -1, self.num_class)
one_hot_targets = one_hot_targets[..., 1:]
# import pdb
# pdb.set_trace()
cls_loss_src = self.cls_loss_func(cls_preds, one_hot_targets, weights=cls_weights) # [N, M]
cls_loss = cls_loss_src.sum() / batch_size
......@@ -246,14 +244,14 @@ class AnchorHeadTemplate(nn.Module):
anchors = self.anchors
num_anchors = anchors.view(-1, anchors.shape[-1]).shape[0]
batch_anchors = anchors.view(1, -1, anchors.shape[-1]).repeat(batch_size, 1, 1)
batch_cls_preds = cls_preds.view(batch_size, num_anchors, -1).float()
batch_box_preds = box_preds.view(batch_size, num_anchors, -1)
batch_cls_preds = cls_preds.view(batch_size, num_anchors, -1).float() if not isinstance(cls_preds, list) else cls_preds
batch_box_preds = box_preds.view(batch_size, num_anchors, -1) if not isinstance(box_preds, list) else torch.cat(box_preds, dim=1).view(batch_size, num_anchors, -1)
batch_box_preds = self.box_coder.decode_torch(batch_box_preds, batch_anchors)
if dir_cls_preds is not None:
dir_offset = self.model_cfg.DIR_OFFSET
dir_limit_offset = self.model_cfg.DIR_LIMIT_OFFSET
dir_cls_preds = dir_cls_preds.view(batch_size, num_anchors, -1)
dir_cls_preds = dir_cls_preds.view(batch_size, num_anchors, -1) if not isinstance(dir_cls_preds, list) else torch.cat(dir_cls_preds, dim=1).view(batch_size, num_anchors, -1)
dir_labels = torch.max(dir_cls_preds, dim=-1)[1]
period = (2 * np.pi / self.model_cfg.NUM_DIR_BINS)
......
......@@ -4,8 +4,11 @@ from ....ops.iou3d_nms import iou3d_nms_utils
class AxisAlignedTargetAssigner(object):
def __init__(self, anchor_target_cfg, anchor_generator_cfg, class_names, box_coder, match_height=False):
def __init__(self, model_cfg, class_names, box_coder, match_height=False):
super().__init__()
anchor_generator_cfg = model_cfg.ANCHOR_GENERATOR_CONFIG
anchor_target_cfg = model_cfg.TARGET_ASSIGNER_CONFIG
self.box_coder = box_coder
self.match_height = match_height
self.class_names = class_names
......@@ -18,8 +21,17 @@ class AxisAlignedTargetAssigner(object):
for config in anchor_generator_cfg:
self.matched_thresholds[config['class_name']] = config['matched_threshold']
self.unmatched_thresholds[config['class_name']] = config['unmatched_threshold']
def assign_targets(self, all_anchors, gt_boxes_with_classes, use_multihead=False):
self.use_multihead = model_cfg.get('USE_MULTIHEAD', False)
self.seperate_multihead = model_cfg.get('SEPERATE_MULTIHEAD', False)
if self.seperate_multihead:
rpn_head_cfgs = model_cfg.RPN_HEAD_CFGS
self.gt_remapping = {}
for rpn_head_cfg in rpn_head_cfgs:
for idx, name in enumerate(rpn_head_cfg['HEAD_CLS_NAME']):
self.gt_remapping[name] = idx + 1
def assign_targets(self, all_anchors, gt_boxes_with_classes):
"""
Args:
all_anchors: [(N, 7), ...]
......@@ -47,21 +59,30 @@ class AxisAlignedTargetAssigner(object):
for anchor_class_name, anchors in zip(self.anchor_class_names, all_anchors):
mask = torch.tensor([self.class_names[c-1] == anchor_class_name for c in cur_gt_classes], dtype=torch.bool)
if use_multihead:
if self.use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
if self.seperate_multihead:
selected_classes = cur_gt_classes[mask].clone()
if len(selected_classes) > 0:
new_cls_id = self.gt_remapping[anchor_class_name]
selected_classes[:] = new_cls_id
else:
selected_classes = cur_gt_classes[mask]
else:
feature_map_size = anchors.shape[:3]
anchors = anchors.view(-1, anchors.shape[-1])
selected_classes = cur_gt_classes[mask]
single_target = self.assign_targets_single(
anchors,
cur_gt[mask],
gt_classes=cur_gt_classes[mask],
gt_classes=selected_classes,
matched_threshold=self.matched_thresholds[anchor_class_name],
unmatched_threshold=self.unmatched_thresholds[anchor_class_name]
)
target_list.append(single_target)
if use_multihead:
if self.use_multihead:
target_dict = {
'box_cls_labels': [t['box_cls_labels'].view(-1) for t in target_list],
'box_reg_targets': [t['box_reg_targets'].view(-1, self.box_coder.code_size) for t in target_list],
......
......@@ -6,7 +6,7 @@ from ..backbones_3d import vfe, pfe
from ..backbones_2d import map_to_bev
from ..model_utils.model_nms_utils import class_agnostic_nms
from ...ops.iou3d_nms import iou3d_nms_utils
import numpy as np
class Detector3DTemplate(nn.Module):
def __init__(self, model_cfg, num_class, dataset):
......@@ -183,24 +183,45 @@ class Detector3DTemplate(nn.Module):
assert batch_dict['batch_cls_preds'].shape.__len__() == 2
batch_mask = (batch_dict['batch_index'] == index)
else:
assert batch_dict['batch_cls_preds'].shape.__len__() == 3
if isinstance(batch_dict['batch_cls_preds'], list):
assert batch_dict['batch_cls_preds'][0].shape.__len__() == 3
else:
assert batch_dict['batch_cls_preds'].shape.__len__() == 3
batch_mask = index
box_preds = batch_dict['batch_box_preds'][batch_mask]
cls_preds = batch_dict['batch_cls_preds'][batch_mask]
cls_preds = batch_dict['batch_cls_preds'][batch_mask] if not isinstance(batch_dict['batch_cls_preds'], list) else [batch_cls_pred[batch_mask] for batch_cls_pred in batch_dict['batch_cls_preds']]
src_cls_preds = cls_preds
src_box_preds = box_preds
assert cls_preds.shape[1] in [1, self.num_class]
if isinstance(cls_preds, list):
assert cls_preds[0].shape[1] in [1, self.num_class]
else:
assert cls_preds.shape[1] in [1, self.num_class]
if not batch_dict['cls_preds_normalized']:
cls_preds = torch.sigmoid(cls_preds)
cls_preds = torch.sigmoid(cls_preds) if not isinstance(cls_preds, list) else [torch.sigmoid(cls_pred) for cls_pred in cls_preds]
if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS:
raise NotImplementedError
else:
cls_preds, label_preds = torch.max(cls_preds, dim=-1)
label_preds = batch_dict['roi_labels'][index] if batch_dict.get('has_class_labels', False) else label_preds + 1
if isinstance(cls_preds, list):
all_cls_preds = []
label_preds = []
rpn_head_cfgs = self.model_cfg.DENSE_HEAD.RPN_HEAD_CFGS
head_cls_names = [np.array(rpn_head_cfg['HEAD_CLS_NAME']) for rpn_head_cfg in rpn_head_cfgs]
for idx, cls_pred in enumerate(cls_preds):
pred_score, pred_head_label = torch.max(cls_pred, dim=-1)
pred_class_names = head_cls_names[idx][pred_head_label.cpu().numpy().astype(int)]
label_pred = [self.class_names.index(cls_name)+1 for cls_name in pred_class_names]
label_pred = torch.from_numpy(np.array(label_pred)).to(cls_pred.device).int()
all_cls_preds.append(pred_score)
label_preds.append(label_pred)
cls_preds = torch.cat(all_cls_preds, dim=0)
label_preds = torch.cat(label_preds, dim=0)
else:
cls_preds, label_preds = torch.max(cls_preds, dim=-1)
label_preds = batch_dict['roi_labels'][index] if batch_dict.get('has_class_labels', False) else label_preds + 1
selected, selected_scores = class_agnostic_nms(
box_scores=cls_preds, box_preds=box_preds,
......
......@@ -34,7 +34,8 @@ MODEL:
DIR_LIMIT_OFFSET: 0.0
NUM_DIR_BINS: 2
USE_MULTI_HEAD: True
USE_MULTIHEAD: True
SEPARATE_MULTIHEAD: True
ANCHOR_GENERATOR_CONFIG: [
{
'class_name': 'Car',
......@@ -52,7 +53,7 @@ MODEL:
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.6],
'align_center': False,
'feature_map_stride': 4,
'feature_map_stride': 8,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
},
......@@ -62,36 +63,38 @@ MODEL:
'anchor_rotations': [0, 1.57],
'anchor_bottom_heights': [-1.6],
'align_center': False,
'feature_map_stride': 4,
'feature_map_stride': 8,
'matched_threshold': 0.5,
'unmatched_threshold': 0.35
}
]
SHARED_CONV_NUM_FILTER: 64
RPN_HEAD_CFGS: [
{
'HEAD_CLS_NAME': ['Car'],
'LAYER_NUMS': [1],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [512],
'UPSAMPLE_STRIDES': [1],
'NUM_UPSAMPLE_FILTERS': [512]
'LAYER_NUMS': [],
'LAYER_STRIDES': [],
'NUM_FILTERS': [],
'UPSAMPLE_STRIDES': [],
'NUM_UPSAMPLE_FILTERS': []
},
{
'HEAD_CLS_NAME': ['Pedestrian'],
'LAYER_NUMS': [1],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [512],
'UPSAMPLE_STRIDES': [2],
'NUM_UPSAMPLE_FILTERS': [512]
'LAYER_NUMS': [],
'LAYER_STRIDES': [],
'NUM_FILTERS': [],
'UPSAMPLE_STRIDES': [],
'NUM_UPSAMPLE_FILTERS': []
},
{
'HEAD_CLS_NAME': ['Cyclist'],
'LAYER_NUMS': [1],
'LAYER_STRIDES': [1],
'NUM_FILTERS': [512],
'UPSAMPLE_STRIDES': [2],
'NUM_UPSAMPLE_FILTERS': [512]
'LAYER_NUMS': [],
'LAYER_STRIDES': [],
'NUM_FILTERS': [],
'UPSAMPLE_STRIDES': [],
'NUM_UPSAMPLE_FILTERS': []
}
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment