Commit be7951cf authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

[continue] support CenterHead / CenterPoint (1stage), add its WOD config

parent 52310ad9
import copy
import numpy as np
import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal_
from ..model_utils import model_nms_utils
from ..model_utils import centernet_utils
from ...utils import loss_utils
class SeparateHead(nn.Module):
def __init__(self, input_channels, sep_head_dict, init_bias=-2.19, use_bias=False):
super().__init__()
self.sep_head_dict = sep_head_dict
for cur_name in self.sep_head_dict:
output_channels = self.sep_head_dict[cur_name]['out_channels']
num_conv = self.sep_head_dict[cur_name]['num_conv']
fc_list = []
for k in range(num_conv - 1):
fc_list.append(nn.Sequential(
nn.Conv2d(input_channels, input_channels, kernel_size=3, stride=1, padding=1, bias=use_bias),
nn.BatchNorm2d(input_channels),
nn.ReLU()
))
fc_list.append(nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1, bias=True))
fc = nn.Sequential(*fc_list)
if 'hm' in cur_name:
fc[-1].bias.data.fill_(init_bias)
else:
for m in fc.modules():
if isinstance(m, nn.Conv2d):
kaiming_normal_(m.weight.data)
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias, 0)
self.__setattr__(cur_name, fc)
def forward(self, x):
ret_dict = {}
for cur_name in self.sep_head_dict:
ret_dict[cur_name] = self.__getattr__(cur_name)(x)
return ret_dict
class CenterHead(nn.Module):
def __init__(self, model_cfg, input_channels, num_class, class_names, grid_size, point_cloud_range, voxel_size,
predict_boxes_when_training=True):
super().__init__()
self.model_cfg = model_cfg
self.num_class = num_class
self.grid_size = grid_size
self.point_cloud_range = point_cloud_range
self.voxel_size = voxel_size
self.feature_map_stride = self.model_cfg.TARGET_ASSIGNER_CONFIG.get('FEATURE_MAP_STRIDE', None)
self.class_names = class_names
self.class_names_each_head = []
self.class_id_mapping_each_head = []
for cur_class_names in self.model_cfg.CLASS_NAMES_EACH_HEAD:
self.class_names_each_head.append([x for x in cur_class_names if x in class_names])
cur_class_id_mapping = torch.from_numpy(np.array(
[self.class_names.index(x) for x in cur_class_names if x in class_names]
)).cuda()
self.class_id_mapping_each_head.append(cur_class_id_mapping)
total_classes = sum([len(x) for x in self.class_names_each_head])
assert total_classes == len(self.class_names), f'class_names_each_head={self.class_names_each_head}'
self.shared_conv = nn.Sequential(
nn.Conv2d(
input_channels, self.model_cfg.SHARED_CONV_CHANNEL, 3, stride=1, padding=1,
bias=self.model_cfg.get('USE_BIAS_BEFORE_NORM', False)
),
nn.BatchNorm2d(self.model_cfg.SHARED_CONV_CHANNEL),
nn.ReLU(),
)
self.heads_list = nn.ModuleList()
self.separate_head_cfg = self.model_cfg.SEPARATE_HEAD_CFG
for idx, cur_class_names in enumerate(self.class_names_each_head):
cur_head_dict = copy.deepcopy(self.separate_head_cfg.HEAD_DICT)
cur_head_dict['hm'] = dict(out_channels=len(cur_class_names), num_conv=self.model_cfg.NUM_HM_CONV)
self.heads_list.append(
SeparateHead(
input_channels=self.model_cfg.SHARED_CONV_CHANNEL,
sep_head_dict=cur_head_dict,
init_bias=-2.19,
use_bias=self.model_cfg.get('USE_BIAS_BEFORE_NORM', False)
)
)
self.predict_boxes_when_training = predict_boxes_when_training
self.forward_ret_dict = {}
self.build_losses()
def build_losses(self):
self.add_module('hm_loss_func', loss_utils.FocalLossCenterNet())
self.add_module('reg_loss_func', loss_utils.RegLossCenterNet())
def assign_target_of_single_head(
self, num_classes, gt_boxes, feature_map_size, feature_map_stride, num_max_objs=500,
gaussian_overlap=0.1, min_radius=2
):
"""
Args:
gt_boxes: (N, 8)
feature_map_size: (2), [x, y]
Returns:
"""
heatmap = gt_boxes.new_zeros(num_classes, feature_map_size[1], feature_map_size[0])
ret_boxes = gt_boxes.new_zeros((num_max_objs, gt_boxes.shape[-1] - 1 + 1))
inds = gt_boxes.new_zeros(num_max_objs).long()
mask = gt_boxes.new_zeros(num_max_objs).long()
x, y, z = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2]
coord_x = (x - self.point_cloud_range[0]) / self.voxel_size[0] / feature_map_stride
coord_y = (y - self.point_cloud_range[1]) / self.voxel_size[1] / feature_map_stride
coord_x = torch.clamp(coord_x, min=0, max=feature_map_size[0] - 0.5) # bugfixed: 1e-6 does not work for center.int()
coord_y = torch.clamp(coord_y, min=0, max=feature_map_size[1] - 0.5) #
center = torch.cat((coord_x[:, None], coord_y[:, None]), dim=-1)
center_int = center.int()
center_int_float = center_int.float()
dx, dy, dz = gt_boxes[:, 3], gt_boxes[:, 4], gt_boxes[:, 5]
dx = dx / self.voxel_size[0] / feature_map_stride
dy = dy / self.voxel_size[1] / feature_map_stride
radius = centernet_utils.gaussian_radius(dx, dy, min_overlap=gaussian_overlap)
radius = torch.clamp_min(radius.int(), min=min_radius)
for k in range(min(num_max_objs, gt_boxes.shape[0])):
if dx[k] <= 0 or dy[k] <= 0:
continue
if not (0 <= center_int[k][0] <= feature_map_size[0] and 0 <= center_int[k][1] <= feature_map_size[1]):
continue
cur_class_id = (gt_boxes[k, -1] - 1).long()
centernet_utils.draw_gaussian_to_heatmap(heatmap[cur_class_id], center[k], radius[k].item())
inds[k] = center_int[k, 1] * feature_map_size[0] + center_int[k, 0]
mask[k] = 1
ret_boxes[k, 0:2] = center[k] - center_int_float[k].float()
ret_boxes[k, 2] = z[k]
ret_boxes[k, 3:6] = gt_boxes[k, 3:6].log()
ret_boxes[k, 6] = torch.cos(gt_boxes[k, 6])
ret_boxes[k, 7] = torch.sin(gt_boxes[k, 6])
if gt_boxes.shape[1] > 8:
ret_boxes[k, 8:] = gt_boxes[k, 7:-1]
return heatmap, ret_boxes, inds, mask
def assign_targets(self, gt_boxes, feature_map_size=None, **kwargs):
"""
Args:
gt_boxes: (B, M, 8)
range_image_polar: (B, 3, H, W)
feature_map_size: (2) [H, W]
spatial_cartesian: (B, 4, H, W)
Returns:
"""
feature_map_size = feature_map_size[::-1] # [H, W] ==> [x, y]
target_assigner_cfg = self.model_cfg.TARGET_ASSIGNER_CONFIG
# feature_map_size = self.grid_size[:2] // target_assigner_cfg.FEATURE_MAP_STRIDE
batch_size = gt_boxes.shape[0]
ret_dict = {
'heatmaps': [],
'target_boxes': [],
'inds': [],
'masks': [],
'heatmap_masks': []
}
all_names = np.array(['bg', *self.class_names])
for idx, cur_class_names in enumerate(self.class_names_each_head):
heatmap_list, target_boxes_list, inds_list, masks_list = [], [], [], []
for bs_idx in range(batch_size):
cur_gt_boxes = gt_boxes[bs_idx]
gt_class_names = all_names[cur_gt_boxes[:, -1].cpu().long().numpy()]
gt_boxes_single_head = []
for idx, name in enumerate(gt_class_names):
if name not in cur_class_names:
continue
temp_box = cur_gt_boxes[idx]
temp_box[-1] = cur_class_names.index(name) + 1
gt_boxes_single_head.append(temp_box[None, :])
if len(gt_boxes_single_head) == 0:
gt_boxes_single_head = cur_gt_boxes[:0, :]
else:
gt_boxes_single_head = torch.cat(gt_boxes_single_head, dim=0)
heatmap, ret_boxes, inds, mask = self.assign_target_of_single_head(
num_classes=len(cur_class_names), gt_boxes=gt_boxes_single_head,
feature_map_size=feature_map_size, feature_map_stride=target_assigner_cfg.FEATURE_MAP_STRIDE,
num_max_objs=target_assigner_cfg.NUM_MAX_OBJS,
gaussian_overlap=target_assigner_cfg.GAUSSIAN_OVERLAP,
min_radius=target_assigner_cfg.MIN_RADIUS,
)
heatmap_list.append(heatmap)
target_boxes_list.append(ret_boxes)
inds_list.append(inds)
masks_list.append(mask)
ret_dict['heatmaps'].append(torch.stack(heatmap_list, dim=0))
ret_dict['target_boxes'].append(torch.stack(target_boxes_list, dim=0))
ret_dict['inds'].append(torch.stack(inds_list, dim=0))
ret_dict['masks'].append(torch.stack(masks_list, dim=0))
return ret_dict
def sigmoid(self, x):
y = torch.clamp(x.sigmoid(), min=1e-4, max=1 - 1e-4)
return y
def get_loss(self):
pred_dicts = self.forward_ret_dict['pred_dicts']
target_dicts = self.forward_ret_dict['target_dicts']
tb_dict = {}
loss = 0
for idx, pred_dict in enumerate(pred_dicts):
pred_dict['hm'] = self.sigmoid(pred_dict['hm'])
hm_loss = self.hm_loss_func(pred_dict['hm'], target_dicts['heatmaps'][idx])
target_boxes = target_dicts['target_boxes'][idx]
pred_boxes = torch.cat([pred_dict[head_name] for head_name in self.separate_head_cfg.HEAD_ORDER], dim=1)
reg_loss = self.reg_loss_func(
pred_boxes, target_dicts['masks'][idx], target_dicts['inds'][idx], target_boxes
)
loc_loss = (reg_loss * reg_loss.new_tensor(self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['code_weights'])).sum()
loc_loss = loc_loss * self.model_cfg.LOSS_CONFIG.LOSS_WEIGHTS['loc_weight']
loss += hm_loss + loc_loss
tb_dict['hm_loss_head_%d' % idx] = hm_loss.item()
tb_dict['loc_loss_head_%d' % idx] = loc_loss.item()
tb_dict['rpn_loss'] = loss.item()
return loss, tb_dict
def generate_predicted_boxes(self, batch_size, pred_dicts):
post_process_cfg = self.model_cfg.POST_PROCESSING
post_center_limit_range = torch.tensor(post_process_cfg.POST_CENTER_LIMIT_RANGE).cuda().float()
ret_dict = [{
'pred_boxes': [],
'pred_scores': [],
'pred_labels': [],
} for k in range(batch_size)]
for idx, pred_dict in enumerate(pred_dicts):
batch_hm = pred_dict['hm'].sigmoid()
batch_center = pred_dict['center']
batch_center_z = pred_dict['center_z']
batch_dim = pred_dict['dim'].exp()
batch_rot_cos = pred_dict['rot'][:, 0].unsqueeze(dim=1)
batch_rot_sin = pred_dict['rot'][:, 1].unsqueeze(dim=1)
batch_vel = pred_dict['vel'] if 'vel' in self.separate_head_cfg.HEAD_ORDER else None
final_pred_dicts = centernet_utils.decode_bbox_from_heatmap(
heatmap=batch_hm, rot_cos=batch_rot_cos, rot_sin=batch_rot_sin,
center=batch_center, center_z=batch_center_z, dim=batch_dim, vel=batch_vel,
point_cloud_range=self.point_cloud_range, voxel_size=self.voxel_size,
feature_map_stride=self.feature_map_stride,
K=post_process_cfg.MAX_OBJ_PER_SAMPLE,
circle_nms=(post_process_cfg.NMS_CONFIG.NMS_TYPE == 'circle_nms'),
score_thresh=post_process_cfg.SCORE_THRESH,
post_center_limit_range=post_center_limit_range
)
for k, final_dict in enumerate(final_pred_dicts):
final_dict['pred_labels'] = self.class_id_mapping_each_head[idx][final_dict['pred_labels'].long()]
if post_process_cfg.NMS_CONFIG.NMS_TYPE != 'circle_nms':
selected, selected_scores = model_nms_utils.class_agnostic_nms(
box_scores=final_dict['pred_scores'], box_preds=final_dict['pred_boxes'],
nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=None
)
final_dict['pred_boxes'] = final_dict['pred_boxes'][selected]
final_dict['pred_scores'] = selected_scores
final_dict['pred_labels'] = final_dict['pred_labels'][selected]
ret_dict[k]['pred_boxes'].append(final_dict['pred_boxes'])
ret_dict[k]['pred_scores'].append(final_dict['pred_scores'])
ret_dict[k]['pred_labels'].append(final_dict['pred_labels'])
for k in range(batch_size):
ret_dict[k]['pred_boxes'] = torch.cat(ret_dict[k]['pred_boxes'], dim=0)
ret_dict[k]['pred_scores'] = torch.cat(ret_dict[k]['pred_scores'], dim=0)
ret_dict[k]['pred_labels'] = torch.cat(ret_dict[k]['pred_labels'], dim=0) + 1
return ret_dict
def forward(self, data_dict):
spatial_features_2d = data_dict['spatial_features_2d']
x = self.shared_conv(spatial_features_2d)
pred_dicts = []
for head in self.heads_list:
pred_dicts.append(head(x))
if self.training:
target_dict = self.assign_targets(
data_dict['gt_boxes'], feature_map_size=spatial_features_2d.size()[2:],
feature_map_stride=data_dict.get('spatial_features_2d_strides', None)
)
self.forward_ret_dict['target_dicts'] = target_dict
self.forward_ret_dict['pred_dicts'] = pred_dicts
if not self.training or self.predict_boxes_when_training:
final_box_dicts = self.generate_predicted_boxes(
data_dict['batch_size'], pred_dicts
)
data_dict['final_box_dicts'] = final_box_dicts
return data_dict
from .detector3d_template import Detector3DTemplate
class CenterPoint(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self):
disp_dict = {}
loss_rpn, tb_dict = self.dense_head.get_loss()
tb_dict = {
'loss_rpn': loss_rpn.item(),
**tb_dict
}
loss = loss_rpn
return loss, tb_dict, disp_dict
def post_processing(self, batch_dict):
post_process_cfg = self.model_cfg.POST_PROCESSING
batch_size = batch_dict['batch_size']
final_pred_dict = batch_dict['final_box_dicts']
recall_dict = {}
for index in range(batch_size):
pred_boxes = final_pred_dict[index]['pred_boxes']
recall_dict = self.generate_recall_record(
box_preds=pred_boxes,
recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
thresh_list=post_process_cfg.RECALL_THRESH_LIST
)
return final_pred_dict, recall_dict
# This file is modified from https://github.com/tianweiy/CenterPoint
import torch
import torch.nn.functional as F
import numpy as np
import numba
def gaussian_radius(height, width, min_overlap=0.5):
"""
Args:
height: (N)
width: (N)
min_overlap:
Returns:
"""
a1 = 1
b1 = (height + width)
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
sq1 = (b1 ** 2 - 4 * a1 * c1).sqrt()
r1 = (b1 + sq1) / 2
a2 = 4
b2 = 2 * (height + width)
c2 = (1 - min_overlap) * width * height
sq2 = (b2 ** 2 - 4 * a2 * c2).sqrt()
r2 = (b2 + sq2) / 2
a3 = 4 * min_overlap
b3 = -2 * min_overlap * (height + width)
c3 = (min_overlap - 1) * width * height
sq3 = (b3 ** 2 - 4 * a3 * c3).sqrt()
r3 = (b3 + sq3) / 2
ret = torch.min(torch.min(r1, r2), r3)
return ret
def gaussian2D(shape, sigma=1):
m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h
def draw_gaussian_to_heatmap(heatmap, center, radius, k=1, valid_mask=None):
diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = torch.from_numpy(
gaussian[radius - top:radius + bottom, radius - left:radius + right]
).to(heatmap.device).float()
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug
if valid_mask is not None:
cur_valid_mask = valid_mask[y - top:y + bottom, x - left:x + right]
masked_gaussian = masked_gaussian * cur_valid_mask.float()
torch.max(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
return heatmap
def _nms(heat, kernel=3):
pad = (kernel - 1) // 2
hmax = F.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
keep = (hmax == heat).float()
return heat * keep
@numba.jit(nopython=True)
def circle_nms(dets, thresh):
x1 = dets[:, 0]
y1 = dets[:, 1]
scores = dets[:, 2]
order = scores.argsort()[::-1].astype(np.int32) # highest->lowest
ndets = dets.shape[0]
suppressed = np.zeros((ndets), dtype=np.int32)
keep = []
for _i in range(ndets):
i = order[_i] # start with highest score box
if suppressed[i] == 1: # if any box have enough iou with this, remove it
continue
keep.append(i)
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
# calculate center distance between i and j box
dist = (x1[i] - x1[j]) ** 2 + (y1[i] - y1[j]) ** 2
# ovr = inter / areas[j]
if dist <= thresh:
suppressed[j] = 1
return keep
def _circle_nms(boxes, min_radius, post_max_size=83):
"""
NMS according to center distance
"""
keep = np.array(circle_nms(boxes.cpu().numpy(), thresh=min_radius))[:post_max_size]
keep = torch.from_numpy(keep).long().to(boxes.device)
return keep
def _gather_feat(feat, ind, mask=None):
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def _transpose_and_gather_feat(feat, ind):
feat = feat.permute(0, 2, 3, 1).contiguous()
feat = feat.view(feat.size(0), -1, feat.size(3))
feat = _gather_feat(feat, ind)
return feat
def _topk(scores, K=40):
batch, num_class, height, width = scores.size()
topk_scores, topk_inds = torch.topk(scores.flatten(2, 3), K)
topk_inds = topk_inds % (height * width)
topk_ys = (topk_inds // width).float()
topk_xs = (topk_inds % width).int().float()
topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
topk_classes = (topk_ind // K).int()
topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)
return topk_score, topk_inds, topk_classes, topk_ys, topk_xs
def decode_bbox_from_heatmap(heatmap, rot_cos, rot_sin, center, center_z, dim,
point_cloud_range=None, voxel_size=None, feature_map_stride=None, vel=None, K=100,
circle_nms=False, score_thresh=None, post_center_limit_range=None):
batch_size, num_class, _, _ = heatmap.size()
if circle_nms:
# TODO: not checked yet
assert False, 'not checked yet'
heatmap = _nms(heatmap)
scores, inds, class_ids, ys, xs = _topk(heatmap, K=K)
center = _transpose_and_gather_feat(center, inds).view(batch_size, K, 2)
rot_sin = _transpose_and_gather_feat(rot_sin, inds).view(batch_size, K, 1)
rot_cos = _transpose_and_gather_feat(rot_cos, inds).view(batch_size, K, 1)
center_z = _transpose_and_gather_feat(center_z, inds).view(batch_size, K, 1)
dim = _transpose_and_gather_feat(dim, inds).view(batch_size, K, 3)
angle = torch.atan2(rot_sin, rot_cos)
xs = xs.view(batch_size, K, 1) + center[:, :, 0:1]
ys = ys.view(batch_size, K, 1) + center[:, :, 1:2]
xs = xs * feature_map_stride * voxel_size[0] + point_cloud_range[0]
ys = ys * feature_map_stride * voxel_size[1] + point_cloud_range[1]
box_part_list = [xs, ys, center_z, dim, angle]
if vel is not None:
vel = _transpose_and_gather_feat(vel, inds).view(batch_size, K, 2)
box_part_list.append(vel)
final_box_preds = torch.cat((box_part_list), dim=-1)
final_scores = scores.view(batch_size, K)
final_class_ids = class_ids.view(batch_size, K)
assert post_center_limit_range is not None
mask = (final_box_preds[..., :3] >= post_center_limit_range[:3]).all(2)
mask &= (final_box_preds[..., :3] <= post_center_limit_range[3:]).all(2)
if score_thresh is not None:
mask &= (final_scores > score_thresh)
ret_pred_dicts = []
for k in range(batch_size):
cur_mask = mask[k]
cur_boxes = final_box_preds[k, cur_mask]
cur_scores = final_scores[k, cur_mask]
cur_labels = final_class_ids[k, cur_mask]
if circle_nms:
assert False, 'not checked yet'
centers = cur_boxes[:, [0, 1]]
boxes = torch.cat((centers, scores.view(-1, 1)), dim=1)
keep = _circle_nms(boxes, min_radius=min_radius, post_max_size=nms_post_max_size)
cur_boxes = cur_boxes[keep]
cur_scores = cur_scores[keep]
cur_labels = cur_labels[keep]
ret_pred_dicts.append({
'pred_boxes': cur_boxes,
'pred_scores': cur_scores,
'pred_labels': cur_labels
})
return ret_pred_dicts
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment