"launch/dynamo-run/src/input/http.rs" did not exist on "ecf53ce2b38971dc9b5dde6f70d74cfc5b870c35"
Commit 635f1e94 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

all model codes refactoring for training/testing, support multiple 3D...

all model codes refactoring for training/testing, support multiple 3D detectors (SECOND, PartA2-Net (official release), PV-RCNN (official release))
parent 3fdecc87
import torch
from ....utils import common_utils
from ....ops.iou3d_nms import iou3d_nms_utils
class ATSSTargetAssigner(object):
"""
Reference: https://arxiv.org/abs/1912.02424
"""
def __init__(self, topk, box_coder, match_height=False):
self.topk = topk
self.box_coder = box_coder
self.match_height = match_height
def assign_targets(self, anchors_list, gt_boxes_with_classes, use_multihead=False):
"""
Args:
anchors: [(N, 7), ...]
gt_boxes: (B, M, 8)
Returns:
"""
if not isinstance(anchors_list, list):
anchors_list = [anchors_list]
single_set_of_anchor = True
else:
single_set_of_anchor = len(anchors_list) == 1
cls_labels_list, reg_targets_list, reg_weights_list = [], [], []
for anchors in anchors_list:
batch_size = gt_boxes_with_classes.shape[0]
gt_classes = gt_boxes_with_classes[:, :, 7]
gt_boxes = gt_boxes_with_classes[:, :, :7]
if use_multihead:
anchors = anchors.permute(3, 4, 0, 1, 2, 5).contiguous().view(-1, anchors.shape[-1])
else:
anchors = anchors.view(-1, anchors.shape[-1])
cls_labels, reg_targets, reg_weights = [], [], []
for k in range(batch_size):
cur_gt = gt_boxes[k]
cnt = cur_gt.__len__() - 1
while cnt > 0 and cur_gt[cnt].sum() == 0:
cnt -= 1
cur_gt = cur_gt[:cnt + 1]
cur_gt_classes = gt_classes[k][:cnt + 1]
cur_cls_labels, cur_reg_targets, cur_reg_weights = self.assign_targets_single(
anchors, cur_gt, cur_gt_classes
)
cls_labels.append(cur_cls_labels)
reg_targets.append(cur_reg_targets)
reg_weights.append(cur_reg_weights)
cls_labels = torch.stack(cls_labels, dim=0)
reg_targets = torch.stack(reg_targets, dim=0)
reg_weights = torch.stack(reg_weights, dim=0)
cls_labels_list.append(cls_labels)
reg_targets_list.append(reg_targets)
reg_weights_list.append(reg_weights)
if single_set_of_anchor:
ret_dict = {
'box_cls_labels': cls_labels_list[0],
'box_reg_targets': reg_targets_list[0],
'reg_weights': reg_weights_list[0]
}
else:
ret_dict = {
'box_cls_labels': torch.cat(cls_labels_list, dim=1),
'box_reg_targets': torch.cat(reg_targets_list, dim=1),
'reg_weights': torch.cat(reg_weights_list, dim=1)
}
return ret_dict
def assign_targets_single(self, anchors, gt_boxes, gt_classes):
"""
Args:
anchors: (N, 7) [x, y, z, dx, dy, dz, heading]
gt_boxes: (M, 7) [x, y, z, dx, dy, dz, heading]
gt_classes: (M)
Returns:
"""
num_anchor = anchors.shape[0]
num_gt = gt_boxes.shape[0]
# select topk anchors for each gt_boxes
if self.match_height:
ious = iou3d_nms_utils.boxes_iou3d_gpu(anchors[:, 0:7], gt_boxes[:, 0:7]) # (N, M)
else:
ious = iou3d_nms_utils.boxes_iou_bev(anchors[:, 0:7], gt_boxes[:, 0:7])
distance = (anchors[:, None, 0:3] - gt_boxes[None, :, 0:3]).norm(dim=-1) # (N, M)
_, topk_idxs = distance.topk(self.topk, dim=0, largest=False) # (K, M)
candidate_ious = ious[topk_idxs, torch.arange(num_gt)] # (K, M)
iou_mean_per_gt = candidate_ious.mean(dim=0)
iou_std_per_gt = candidate_ious.std(dim=0)
iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt + 1e-6
is_pos = candidate_ious >= iou_thresh_per_gt[None, :] # (K, M)
# check whether anchor_center in gt_boxes, only check BEV x-y axes
candidate_anchors = anchors[topk_idxs.view(-1)] # (KxM, 7)
gt_boxes_of_each_anchor = gt_boxes[:, :].repeat(self.topk, 1) # (KxM, 7)
xyz_local = candidate_anchors[:, 0:3] - gt_boxes_of_each_anchor[:, 0:3]
xyz_local = common_utils.rotate_points_along_z(
xyz_local[:, None, :], -gt_boxes_of_each_anchor[:, 6]
).squeeze(dim=1)
xy_local = xyz_local[:, 0:2]
lw = gt_boxes_of_each_anchor[:, 3:5][:, [1, 0]] # bugfixed: w ==> y, l ==> x in local coords
is_in_gt = ((xy_local <= lw / 2) & (xy_local >= -lw / 2)).all(dim=-1).view(-1, num_gt) # (K, M)
is_pos = is_pos & is_in_gt # (K, M)
for ng in range(num_gt):
topk_idxs[:, ng] += ng * num_anchor
# select the highest IoU if an anchor box is assigned with multiple gt_boxes
INF = -0x7FFFFFFF
ious_inf = torch.full_like(ious, INF).t().contiguous().view(-1) # (MxN)
index = topk_idxs.view(-1)[is_pos.view(-1)]
ious_inf[index] = ious.t().contiguous().view(-1)[index]
ious_inf = ious_inf.view(num_gt, -1).t() # (N, M)
anchors_to_gt_values, anchors_to_gt_indexs = ious_inf.max(dim=1)
# match the gt_boxes to the anchors which have maximum iou with them
max_iou_of_each_gt, argmax_iou_of_each_gt = ious.max(dim=0)
anchors_to_gt_indexs[argmax_iou_of_each_gt] = torch.arange(0, num_gt, device=ious.device)
anchors_to_gt_values[argmax_iou_of_each_gt] = max_iou_of_each_gt
cls_labels = gt_classes[anchors_to_gt_indexs]
cls_labels[anchors_to_gt_values == INF] = 0
matched_gts = gt_boxes[anchors_to_gt_indexs]
pos_mask = cls_labels > 0
reg_targets = matched_gts.new_zeros((num_anchor, self.box_coder.code_size))
reg_weights = matched_gts.new_zeros(num_anchor)
if pos_mask.sum() > 0:
reg_targets[pos_mask > 0] = self.box_coder.encode_torch(matched_gts[pos_mask > 0], anchors[pos_mask > 0])
reg_weights[pos_mask] = 1.0
return cls_labels, reg_targets, reg_weights
from .detector3d_template import Detector3DTemplate
class PartA2Net(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self):
disp_dict = {}
loss_rpn, tb_dict = self.dense_head.get_loss()
loss_point, tb_dict = self.point_head.get_loss(tb_dict)
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)
loss = loss_rpn + loss_point + loss_rcnn
return loss, tb_dict, disp_dict
from .detector3d_template import Detector3DTemplate
from .second_net import SECONDNet
from .PartA2_net import PartA2Net
from .pv_rcnn import PVRCNN
__all__ = {
'Detector3DTemplate': Detector3DTemplate,
'SECONDNet': SECONDNet,
'PartA2Net': PartA2Net,
'PVRCNN': PVRCNN
}
def build_detector(model_cfg, num_class, dataset):
model = __all__[model_cfg.NAME](
model_cfg=model_cfg, num_class=num_class, dataset=dataset
)
return model
import torch
import os
import torch.nn as nn
from .. import backbones_3d, backbones_2d, dense_heads, roi_heads
from ..backbones_3d import vfe, pfe
from ..backbones_2d import map_to_bev
from ..model_utils.model_nms_utils import class_agnostic_nms
from ...ops.iou3d_nms import iou3d_nms_utils
class Detector3DTemplate(nn.Module):
def __init__(self, model_cfg, num_class, dataset):
super().__init__()
self.model_cfg = model_cfg
self.num_class = num_class
self.dataset = dataset
self.register_buffer('global_step', torch.LongTensor(1).zero_())
self.module_topology = [
'vfe', 'backbone_3d', 'map_to_bev_module', 'pfe',
'backbone_2d', 'dense_head', 'point_head', 'roi_head'
]
@property
def mode(self):
return 'TRAIN' if self.training else 'TEST'
def update_global_step(self):
self.global_step += 1
def build_networks(self):
model_info_dict = {
'module_list': [],
'num_rawpoint_features': self.dataset.point_feature_encoder.num_point_features,
'grid_size': self.dataset.grid_size,
'point_cloud_range': self.dataset.point_cloud_range,
'voxel_size': self.dataset.voxel_size
}
for module_name in self.module_topology:
module, model_info_dict = getattr(self, 'build_%s' % module_name)(
model_info_dict=model_info_dict
)
self.add_module(module_name, module)
return model_info_dict['module_list']
def build_vfe(self, model_info_dict):
if self.model_cfg.get('VFE', None) is None:
return None, model_info_dict
vfe_module = vfe.__all__[self.model_cfg.VFE.NAME](
model_cfg=self.model_cfg.VFE,
num_point_features=model_info_dict['num_rawpoint_features']
)
model_info_dict['num_point_features'] = vfe_module.get_output_feature_dim()
model_info_dict['module_list'].append(vfe_module)
return vfe_module, model_info_dict
def build_backbone_3d(self, model_info_dict):
if self.model_cfg.get('BACKBONE_3D', None) is None:
return None, model_info_dict
backbone_3d_module = backbones_3d.__all__[self.model_cfg.BACKBONE_3D.NAME](
model_cfg=self.model_cfg.BACKBONE_3D,
input_channels=model_info_dict['num_point_features'],
grid_size=model_info_dict['grid_size'],
voxel_size=model_info_dict['voxel_size'],
point_cloud_range=model_info_dict['point_cloud_range']
)
model_info_dict['module_list'].append(backbone_3d_module)
model_info_dict['num_point_features'] = backbone_3d_module.num_point_features
return backbone_3d_module, model_info_dict
def build_map_to_bev_module(self, model_info_dict):
if self.model_cfg.get('MAP_TO_BEV', None) is None:
return None, model_info_dict
map_to_bev_module = map_to_bev.__all__[self.model_cfg.MAP_TO_BEV.NAME](
model_cfg=self.model_cfg.MAP_TO_BEV
)
model_info_dict['module_list'].append(map_to_bev_module)
model_info_dict['num_bev_features'] = map_to_bev_module.num_bev_features
return map_to_bev_module, model_info_dict
def build_backbone_2d(self, model_info_dict):
if self.model_cfg.get('BACKBONE_2D', None) is None:
return None, model_info_dict
backbone_2d_module = backbones_2d.__all__[self.model_cfg.BACKBONE_2D.NAME](
model_cfg=self.model_cfg.BACKBONE_2D,
input_channels=model_info_dict['num_bev_features']
)
model_info_dict['module_list'].append(backbone_2d_module)
model_info_dict['num_bev_features'] = backbone_2d_module.num_bev_features
return backbone_2d_module, model_info_dict
def build_pfe(self, model_info_dict):
if self.model_cfg.get('PFE', None) is None:
return None, model_info_dict
pfe_module = pfe.__all__[self.model_cfg.PFE.NAME](
model_cfg=self.model_cfg.PFE,
voxel_size=model_info_dict['voxel_size'],
point_cloud_range=model_info_dict['point_cloud_range'],
num_bev_features=model_info_dict['num_bev_features'],
num_rawpoint_features=model_info_dict['num_rawpoint_features']
)
model_info_dict['module_list'].append(pfe_module)
model_info_dict['num_point_features'] = pfe_module.num_point_features
model_info_dict['num_point_features_before_fusion'] = pfe_module.num_point_features_before_fusion
return pfe_module, model_info_dict
def build_dense_head(self, model_info_dict):
if self.model_cfg.get('DENSE_HEAD', None) is None:
return None, model_info_dict
dense_head_module = dense_heads.__all__[self.model_cfg.DENSE_HEAD.NAME](
model_cfg=self.model_cfg.DENSE_HEAD,
input_channels=model_info_dict['num_bev_features'],
num_class=self.num_class if not self.model_cfg.DENSE_HEAD.CLASS_AGNOSTIC else 1,
grid_size=model_info_dict['grid_size'],
point_cloud_range=model_info_dict['point_cloud_range'],
predict_boxes_when_training=self.model_cfg.get('ROI_HEAD', False)
)
model_info_dict['module_list'].append(dense_head_module)
return dense_head_module, model_info_dict
def build_point_head(self, model_info_dict):
if self.model_cfg.get('POINT_HEAD', None) is None:
return None, model_info_dict
if self.model_cfg.POINT_HEAD.get('USE_POINT_FEATURES_BEFORE_FUSION', False):
num_point_features = model_info_dict['num_point_features_before_fusion']
else:
num_point_features = model_info_dict['num_point_features']
point_head_module = dense_heads.__all__[self.model_cfg.POINT_HEAD.NAME](
model_cfg=self.model_cfg.POINT_HEAD,
input_channels=num_point_features,
num_class=self.num_class if not self.model_cfg.POINT_HEAD.CLASS_AGNOSTIC else 1,
)
model_info_dict['module_list'].append(point_head_module)
return point_head_module, model_info_dict
def build_roi_head(self, model_info_dict):
if self.model_cfg.get('ROI_HEAD', None) is None:
return None, model_info_dict
point_head_module = roi_heads.__all__[self.model_cfg.ROI_HEAD.NAME](
model_cfg=self.model_cfg.ROI_HEAD,
input_channels=model_info_dict['num_point_features'],
num_class=self.num_class if not self.model_cfg.POINT_HEAD.CLASS_AGNOSTIC else 1,
)
model_info_dict['module_list'].append(point_head_module)
return point_head_module, model_info_dict
def forward(self, **kwargs):
raise NotImplementedError
def post_processing(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1)
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
cls_preds_normalized: indicate whether batch_cls_preds is normalized
batch_index: optional (N1+N2+...)
roi_labels: (B, num_rois) 1 .. num_classes
Returns:
"""
post_process_cfg = self.model_cfg.POST_PROCESSING
batch_size = batch_dict['batch_size']
recall_dict = {}
pred_dicts = []
for index in range(batch_size):
if batch_dict.get('batch_index', None) is not None:
assert batch_dict['batch_cls_preds'].shape.__len__() == 2
batch_mask = (batch_dict['batch_index'] == index)
else:
assert batch_dict['batch_cls_preds'].shape.__len__() == 3
batch_mask = index
box_preds = batch_dict['batch_box_preds'][batch_mask]
cls_preds = batch_dict['batch_cls_preds'][batch_mask]
src_cls_preds = cls_preds
src_box_preds = box_preds
assert cls_preds.shape[1] in [1, self.num_class]
if not batch_dict['cls_preds_normalized']:
cls_preds = torch.sigmoid(cls_preds)
if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS:
raise NotImplementedError
else:
cls_preds, label_preds = torch.max(cls_preds, dim=-1)
label_preds = batch_dict['roi_labels'][index] if batch_dict.get('has_class_labels', False) else label_preds + 1
selected, selected_scores = class_agnostic_nms(
box_scores=cls_preds, box_preds=box_preds,
nms_config=post_process_cfg.NMS_CONFIG,
score_thresh=post_process_cfg.SCORE_THRESH
)
if post_process_cfg.OUTPUT_RAW_SCORE:
max_cls_preds, _ = torch.max(src_cls_preds, dim=-1)
selected_scores = max_cls_preds[selected]
final_scores = selected_scores
final_labels = label_preds[selected]
final_boxes = box_preds[selected]
recall_dict = self.generate_recall_record(
box_preds=final_boxes if 'rois' not in batch_dict else src_box_preds,
recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
thresh_list=post_process_cfg.RECALL_THRESH_LIST
)
record_dict = {
'pred_boxes': final_boxes,
'pred_scores': final_scores,
'pred_labels': final_labels
}
pred_dicts.append(record_dict)
return pred_dicts, recall_dict
@staticmethod
def generate_recall_record(box_preds, recall_dict, batch_index, data_dict=None, thresh_list=None):
if 'gt_boxes' not in data_dict:
return recall_dict
rois = data_dict['rois'][batch_index] if 'rois' in data_dict else None
gt_boxes = data_dict['gt_boxes'][batch_index]
if recall_dict.__len__() == 0:
recall_dict = {'gt': 0}
for cur_thresh in thresh_list:
recall_dict['roi_%s' % (str(cur_thresh))] = 0
recall_dict['rcnn_%s' % (str(cur_thresh))] = 0
cur_gt = gt_boxes
k = cur_gt.__len__() - 1
while k > 0 and cur_gt[k].sum() == 0:
k -= 1
cur_gt = cur_gt[:k + 1]
if cur_gt.sum() > 0:
if box_preds.shape[0] > 0:
iou3d_rcnn = iou3d_nms_utils.boxes_iou3d_gpu(box_preds, cur_gt[:, 0:7])
else:
iou3d_rcnn = torch.zeros((0, cur_gt.shape[0]))
if rois is not None:
iou3d_roi = iou3d_nms_utils.boxes_iou3d_gpu(rois, cur_gt[:, 0:7])
for cur_thresh in thresh_list:
if iou3d_rcnn.shape[0] == 0:
recall_dict['rcnn_%s' % str(cur_thresh)] += 0
else:
rcnn_recalled = (iou3d_rcnn.max(dim=0)[0] > cur_thresh).sum().item()
recall_dict['rcnn_%s' % str(cur_thresh)] += rcnn_recalled
if rois is not None:
roi_recalled = (iou3d_roi.max(dim=0)[0] > cur_thresh).sum().item()
recall_dict['roi_%s' % str(cur_thresh)] += roi_recalled
recall_dict['gt'] += cur_gt.shape[0]
else:
gt_iou = box_preds.new_zeros(box_preds.shape[0])
return recall_dict
def load_params_from_file(self, filename, logger, to_cpu=False):
if not os.path.isfile(filename):
raise FileNotFoundError
logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
loc_type = torch.device('cpu') if to_cpu else None
checkpoint = torch.load(filename, map_location=loc_type)
model_state_disk = checkpoint['model_state']
if 'version' in checkpoint:
logger.info('==> Checkpoint trained from version: %s' % checkpoint['version'])
update_model_state = {}
for key, val in model_state_disk.items():
if key in self.state_dict() and self.state_dict()[key].shape == model_state_disk[key].shape:
update_model_state[key] = val
# logger.info('Update weight %s: %s' % (key, str(val.shape)))
state_dict = self.state_dict()
state_dict.update(update_model_state)
self.load_state_dict(state_dict)
for key in state_dict:
if key not in update_model_state:
logger.info('Not updated weight %s: %s' % (key, str(state_dict[key].shape)))
logger.info('==> Done (loaded %d/%d)' % (len(update_model_state), len(self.state_dict())))
def load_params_with_optimizer(self, filename, to_cpu=False, optimizer=None, logger=None):
if not os.path.isfile(filename):
raise FileNotFoundError
logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
loc_type = torch.device('cpu') if to_cpu else None
checkpoint = torch.load(filename, map_location=loc_type)
epoch = checkpoint.get('epoch', -1)
it = checkpoint.get('it', 0.0)
self.load_state_dict(checkpoint['model_state'])
if optimizer is not None:
if 'optimizer_state' in checkpoint and checkpoint['optimizer_state'] is not None:
logger.info('==> Loading optimizer parameters from checkpoint %s to %s'
% (filename, 'CPU' if to_cpu else 'GPU'))
optimizer.load_state_dict(checkpoint['optimizer_state'])
else:
assert filename[-4] == '.', filename
src_file, ext = filename[:-4], filename[-3:]
optimizer_filename = '%s_optim.%s' % (src_file, ext)
if os.path.exists(optimizer_filename):
optimizer_ckpt = torch.load(optimizer_filename, map_location=loc_type)
optimizer.load_state_dict(optimizer_ckpt['optimizer_state'])
if 'version' in checkpoint:
print('==> Checkpoint trained from version: %s' % checkpoint['version'])
logger.info('==> Done')
return it, epoch
from .detector3d_template import Detector3DTemplate
class PVRCNN(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self):
disp_dict = {}
loss_rpn, tb_dict = self.dense_head.get_loss()
loss_point, tb_dict = self.point_head.get_loss(tb_dict)
loss_rcnn, tb_dict = self.roi_head.get_loss(tb_dict)
loss = loss_rpn + loss_point + loss_rcnn
return loss, tb_dict, disp_dict
from .detector3d_template import Detector3DTemplate
class SECONDNet(Detector3DTemplate):
def __init__(self, model_cfg, num_class, dataset):
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
self.module_list = self.build_networks()
def forward(self, batch_dict):
for cur_module in self.module_list:
batch_dict = cur_module(batch_dict)
if self.training:
loss, tb_dict, disp_dict = self.get_training_loss()
ret_dict = {
'loss': loss
}
return ret_dict, tb_dict, disp_dict
else:
pred_dicts, recall_dicts = self.post_processing(batch_dict)
return pred_dicts, recall_dicts
def get_training_loss(self):
disp_dict = {}
loss_rpn, tb_dict = self.dense_head.get_loss()
tb_dict = {
'loss_rpn': loss_rpn.item(),
**tb_dict
}
loss = loss_rpn
return loss, tb_dict, disp_dict
import torch
from ...ops.iou3d_nms import iou3d_nms_utils
def class_agnostic_nms(box_scores, box_preds, nms_config, score_thresh=None):
src_box_scores = box_scores
if score_thresh is not None:
scores_mask = (box_scores >= score_thresh)
box_scores = box_scores[scores_mask]
box_preds = box_preds[scores_mask]
selected = []
if box_scores.shape[0] > 0:
box_scores_nms, indices = torch.topk(box_scores, k=min(nms_config.NMS_PRE_MAXSIZE, box_scores.shape[0]))
boxes_for_nms = box_preds[indices]
keep_idx, selected_scores = getattr(iou3d_nms_utils, nms_config.NMS_TYPE)(
boxes_for_nms, box_scores_nms, nms_config.NMS_THRESH, **nms_config
)
selected = indices[keep_idx[:nms_config.NMS_POST_MAXSIZE]]
if score_thresh is not None:
original_idxs = scores_mask.nonzero().view(-1)
selected = original_idxs[selected]
return selected, src_box_scores[selected]
from .roi_head_template import RoIHeadTemplate
from .partA2_head import PartA2FCHead
from .pvrcnn_head import PVRCNNHead
__all__ = {
'RoIHeadTemplate': RoIHeadTemplate,
'PartA2FCHead': PartA2FCHead,
'PVRCNNHead': PVRCNNHead
}
import torch
import torch.nn as nn
import spconv
import numpy as np
from .roi_head_template import RoIHeadTemplate
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
class PartA2FCHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg
self.SA_modules = nn.ModuleList()
block = self.post_act_block
c0 = self.model_cfg.ROI_AWARE_POOL.NUM_FEATURES // 2
self.conv_part = spconv.SparseSequential(
block(4, 64, 3, padding=1, indice_key='rcnn_subm1'),
block(64, c0, 3, padding=1, indice_key='rcnn_subm1_1'),
)
self.conv_rpn = spconv.SparseSequential(
block(input_channels, 64, 3, padding=1, indice_key='rcnn_subm2'),
block(64, c0, 3, padding=1, indice_key='rcnn_subm1_2'),
)
shared_fc_list = []
pool_size = self.model_cfg.ROI_AWARE_POOL.POOL_SIZE
pre_channel = self.model_cfg.ROI_AWARE_POOL.NUM_FEATURES * pool_size * pool_size * pool_size
for k in range(0, self.model_cfg.SHARED_FC.__len__()):
shared_fc_list.extend([
nn.Conv1d(pre_channel, self.model_cfg.SHARED_FC[k], kernel_size=1, bias=False),
nn.BatchNorm1d(self.model_cfg.SHARED_FC[k]),
nn.ReLU()
])
pre_channel = self.model_cfg.SHARED_FC[k]
if k != self.model_cfg.SHARED_FC.__len__() - 1 and self.model_cfg.DP_RATIO > 0:
shared_fc_list.append(nn.Dropout(self.model_cfg.DP_RATIO))
self.shared_fc_layer = nn.Sequential(*shared_fc_list)
self.cls_layers = self.make_fc_layers(
input_channels=pre_channel, output_channels=self.num_class, fc_list=self.model_cfg.CLS_FC
)
self.reg_layers = self.make_fc_layers(
input_channels=pre_channel,
output_channels=self.box_coder.code_size * self.num_class,
fc_list=self.model_cfg.REG_FC
)
self.roiaware_pool3d_layer = roiaware_pool3d_utils.RoIAwarePool3d(
out_size=self.model_cfg.ROI_AWARE_POOL.POOL_SIZE,
max_pts_each_voxel=self.model_cfg.ROI_AWARE_POOL.MAX_POINTS_PER_VOXEL
)
self.init_weights(weight_init='xavier')
def init_weights(self, weight_init='xavier'):
if weight_init == 'kaiming':
init_func = nn.init.kaiming_normal_
elif weight_init == 'xavier':
init_func = nn.init.xavier_normal_
elif weight_init == 'normal':
init_func = nn.init.normal_
else:
raise NotImplementedError
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
if weight_init == 'normal':
init_func(m.weight, mean=0, std=0.001)
else:
init_func(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.reg_layers[-1].weight, mean=0, std=0.001)
def post_act_block(self, in_channels, out_channels, kernel_size, indice_key, stride=1, padding=0, conv_type='subm'):
if conv_type == 'subm':
m = spconv.SparseSequential(
spconv.SubMConv3d(in_channels, out_channels, kernel_size, bias=False, indice_key=indice_key),
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
nn.ReLU(),
)
elif conv_type == 'spconv':
m = spconv.SparseSequential(
spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
bias=False, indice_key=indice_key),
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
nn.ReLU(),
)
elif conv_type == 'inverseconv':
m = spconv.SparseSequential(
spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size,
indice_key=indice_key, bias=False),
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
nn.ReLU(),
)
else:
raise NotImplementedError
return m
def roiaware_pool(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
rois: (B, num_rois, 7 + C)
point_coords: (num_points, 4) [bs_idx, x, y, z]
point_features: (num_points, C)
point_cls_scores: (N1 + N2 + N3 + ..., 1)
point_part_offset: (N1 + N2 + N3 + ..., 3)
Returns:
"""
batch_size = batch_dict['batch_size']
batch_idx = batch_dict['point_coords'][:, 0]
point_coords = batch_dict['point_coords'][:, 1:4]
point_features = batch_dict['point_features']
part_features = torch.cat((
batch_dict['point_part_offset'], batch_dict['point_cls_scores'].view(-1, 1).detach()
), dim=1)
part_features[part_features[:, -1] < self.model_cfg.SEG_MASK_SCORE_THRESH, 0:3] = 0
rois = batch_dict['rois']
pooled_part_features_list, pooled_rpn_features_list = [], []
for bs_idx in range(batch_size):
bs_mask = (batch_idx == bs_idx)
cur_point_coords = point_coords[bs_mask]
cur_part_features = part_features[bs_mask]
cur_rpn_features = point_features[bs_mask]
cur_roi = rois[bs_idx][:, 0:7].contiguous() # (N, 7)
pooled_part_features = self.roiaware_pool3d_layer.forward(
cur_roi, cur_point_coords, cur_part_features, pool_method='avg'
) # (N, out_x, out_y, out_z, 4)
pooled_rpn_features = self.roiaware_pool3d_layer.forward(
cur_roi, cur_point_coords, cur_rpn_features, pool_method='max'
) # (N, out_x, out_y, out_z, C)
pooled_part_features_list.append(pooled_part_features)
pooled_rpn_features_list.append(pooled_rpn_features)
pooled_part_features = torch.cat(pooled_part_features_list, dim=0) # (B * N, out_x, out_y, out_z, 4)
pooled_rpn_features = torch.cat(pooled_rpn_features_list, dim=0) # (B * N, out_x, out_y, out_z, C)
return pooled_part_features, pooled_rpn_features
@staticmethod
def fake_sparse_idx(sparse_idx, batch_size_rcnn):
print('Warning: Sparse_Idx_Shape(%s) \r' % (str(sparse_idx.shape)), end='', flush=True)
# at most one sample is non-empty, then fake the first voxels of each sample(BN needs at least
# two values each channel) as non-empty for the below calculation
sparse_idx = sparse_idx.new_zeros((batch_size_rcnn, 3))
bs_idxs = torch.arange(batch_size_rcnn).type_as(sparse_idx).view(-1, 1)
sparse_idx = torch.cat((bs_idxs, sparse_idx), dim=1)
return sparse_idx
def forward(self, batch_dict):
"""
Args:
batch_dict:
Returns:
"""
targets_dict = self.proposal_layer(
batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
)
if self.training:
targets_dict = self.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']
# RoI aware pooling
pooled_part_features, pooled_rpn_features = self.roiaware_pool(batch_dict)
batch_size_rcnn = pooled_part_features.shape[0] # (B * N, out_x, out_y, out_z, 4)
# transform to sparse tensors
sparse_shape = np.array(pooled_part_features.shape[1:4], dtype=np.int32)
sparse_idx = pooled_part_features.sum(dim=-1).nonzero() # (non_empty_num, 4) ==> [bs_idx, x_idx, y_idx, z_idx]
if sparse_idx.shape[0] < 3:
sparse_idx = self.fake_sparse_idx(sparse_idx, batch_size_rcnn)
if self.training:
# these are invalid samples
targets_dict['rcnn_cls_labels'].fill_(-1)
targets_dict['reg_valid_mask'].fill_(-1)
part_features = pooled_part_features[sparse_idx[:, 0], sparse_idx[:, 1], sparse_idx[:, 2], sparse_idx[:, 3]]
rpn_features = pooled_rpn_features[sparse_idx[:, 0], sparse_idx[:, 1], sparse_idx[:, 2], sparse_idx[:, 3]]
coords = sparse_idx.int()
part_features = spconv.SparseConvTensor(part_features, coords, sparse_shape, batch_size_rcnn)
rpn_features = spconv.SparseConvTensor(rpn_features, coords, sparse_shape, batch_size_rcnn)
# forward rcnn network
x_part = self.conv_part(part_features)
x_rpn = self.conv_rpn(rpn_features)
merged_feature = torch.cat((x_rpn.features, x_part.features), dim=1) # (N, C)
shared_feature = spconv.SparseConvTensor(merged_feature, coords, sparse_shape, batch_size_rcnn)
shared_feature = shared_feature.dense().view(batch_size_rcnn, -1, 1)
shared_feature = self.shared_fc_layer(shared_feature)
rcnn_cls = self.cls_layers(shared_feature).transpose(1, 2).contiguous().squeeze(dim=1) # (B, 1 or 2)
rcnn_reg = self.reg_layers(shared_feature).transpose(1, 2).contiguous().squeeze(dim=1) # (B, C)
if not self.training:
batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
batch_size=batch_dict['batch_size'], rois=batch_dict['rois'], cls_preds=rcnn_cls, box_preds=rcnn_reg
)
batch_dict['batch_cls_preds'] = batch_cls_preds
batch_dict['batch_box_preds'] = batch_box_preds
batch_dict['cls_preds_normalized'] = False
else:
targets_dict['rcnn_cls'] = rcnn_cls
targets_dict['rcnn_reg'] = rcnn_reg
self.forward_ret_dict = targets_dict
return batch_dict
import torch.nn as nn
from .roi_head_template import RoIHeadTemplate
from ...utils import common_utils
from ...ops.pointnet2.pointnet2_stack import pointnet2_modules as pointnet2_stack_modules
class PVRCNNHead(RoIHeadTemplate):
def __init__(self, input_channels, model_cfg, num_class=1):
super().__init__(num_class=num_class, model_cfg=model_cfg)
self.model_cfg = model_cfg
mlps = self.model_cfg.ROI_GRID_POOL.MLPS
for k in range(len(mlps)):
mlps[k] = [input_channels] + mlps[k]
self.roi_grid_pool_layer = pointnet2_stack_modules.StackSAModuleMSG(
radii=self.model_cfg.ROI_GRID_POOL.POOL_RADIUS,
nsamples=self.model_cfg.ROI_GRID_POOL.NSAMPLE,
mlps=mlps,
use_xyz=True,
pool_method=self.model_cfg.ROI_GRID_POOL.POOL_METHOD,
)
GRID_SIZE = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
c_out = sum([x[-1] for x in mlps])
pre_channel = GRID_SIZE * GRID_SIZE * GRID_SIZE * c_out
shared_fc_list = []
for k in range(0, self.model_cfg.SHARED_FC.__len__()):
shared_fc_list.extend([
nn.Conv1d(pre_channel, self.model_cfg.SHARED_FC[k], kernel_size=1, bias=False),
nn.BatchNorm1d(self.model_cfg.SHARED_FC[k]),
nn.ReLU()
])
pre_channel = self.model_cfg.SHARED_FC[k]
if k != self.model_cfg.SHARED_FC.__len__() - 1 and self.model_cfg.DP_RATIO > 0:
shared_fc_list.append(nn.Dropout(self.model_cfg.DP_RATIO))
self.shared_fc_layer = nn.Sequential(*shared_fc_list)
self.cls_layers = self.make_fc_layers(
input_channels=pre_channel, output_channels=self.num_class, fc_list=self.model_cfg.CLS_FC
)
self.reg_layers = self.make_fc_layers(
input_channels=pre_channel,
output_channels=self.box_coder.code_size * self.num_class,
fc_list=self.model_cfg.REG_FC
)
self.init_weights(weight_init='xavier')
def init_weights(self, weight_init='xavier'):
if weight_init == 'kaiming':
init_func = nn.init.kaiming_normal_
elif weight_init == 'xavier':
init_func = nn.init.xavier_normal_
elif weight_init == 'normal':
init_func = nn.init.normal_
else:
raise NotImplementedError
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
if weight_init == 'normal':
init_func(m.weight, mean=0, std=0.001)
else:
init_func(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.reg_layers[-1].weight, mean=0, std=0.001)
def roi_grid_pool(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
rois: (B, num_rois, 7 + C)
point_coords: (num_points, 4) [bs_idx, x, y, z]
point_features: (num_points, C)
point_cls_scores: (N1 + N2 + N3 + ..., 1)
point_part_offset: (N1 + N2 + N3 + ..., 3)
Returns:
"""
batch_size = batch_dict['batch_size']
rois = batch_dict['rois']
point_coords = batch_dict['point_coords']
point_features = batch_dict['point_features']
point_features = point_features * batch_dict['point_cls_scores'].view(-1, 1)
global_roi_grid_points, local_roi_grid_points = self.get_global_grid_points_of_roi(
rois, grid_size=self.model_cfg.ROI_GRID_POOL.GRID_SIZE
) # (BxN, 6x6x6, 3)
global_roi_grid_points = global_roi_grid_points.view(batch_size, -1, 3) # (B, Nx6x6x6, 3)
xyz = point_coords[:, 1:4]
xyz_batch_cnt = xyz.new_zeros(batch_size).int()
batch_idx = point_coords[:, 0]
for k in range(batch_size):
xyz_batch_cnt[k] = (batch_idx == k).sum()
new_xyz = global_roi_grid_points.view(-1, 3)
new_xyz_batch_cnt = xyz.new_zeros(batch_size).int().fill_(global_roi_grid_points.shape[1])
pooled_points, pooled_features = self.roi_grid_pool_layer(
xyz=xyz.contiguous(),
xyz_batch_cnt=xyz_batch_cnt,
new_xyz=new_xyz,
new_xyz_batch_cnt=new_xyz_batch_cnt,
features=point_features.contiguous(),
) # (M1 + M2 ..., C)
pooled_features = pooled_features.view(
-1, self.model_cfg.ROI_GRID_POOL.GRID_SIZE ** 3,
pooled_features.shape[-1]
) # (BxN, 6x6x6, C)
return pooled_features
def get_global_grid_points_of_roi(self, rois, grid_size):
rois = rois.view(-1, rois.shape[-1])
batch_size_rcnn = rois.shape[0]
local_roi_grid_points = self.get_dense_grid_points(rois, batch_size_rcnn, grid_size) # (B, 6x6x6, 3)
global_roi_grid_points = common_utils.rotate_points_along_z(
local_roi_grid_points.clone(), rois[:, 6]
).squeeze(dim=1)
global_center = rois[:, 0:3].clone()
global_roi_grid_points += global_center.unsqueeze(dim=1)
return global_roi_grid_points, local_roi_grid_points
@staticmethod
def get_dense_grid_points(rois, batch_size_rcnn, grid_size):
faked_features = rois.new_ones((grid_size, grid_size, grid_size))
dense_idx = faked_features.nonzero() # (N, 3) [x_idx, y_idx, z_idx]
dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float() # (B, 6x6x6, 3)
local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6]
roi_grid_points = (dense_idx + 0.5) / grid_size * local_roi_size.unsqueeze(dim=1) \
- (local_roi_size.unsqueeze(dim=1) / 2) # (B, 6x6x6, 3)
return roi_grid_points
def forward(self, batch_dict):
"""
:param input_data: input dict
:return:
"""
targets_dict = self.proposal_layer(
batch_dict, nms_config=self.model_cfg.NMS_CONFIG['TRAIN' if self.training else 'TEST']
)
if self.training:
targets_dict = self.assign_targets(batch_dict)
batch_dict['rois'] = targets_dict['rois']
batch_dict['roi_labels'] = targets_dict['roi_labels']
# RoI aware pooling
pooled_features = self.roi_grid_pool(batch_dict) # (BxN, 6x6x6, C)
grid_size = self.model_cfg.ROI_GRID_POOL.GRID_SIZE
batch_size_rcnn = pooled_features.shape[0]
pooled_features = pooled_features.permute(0, 2, 1).\
contiguous().view(batch_size_rcnn, -1, grid_size, grid_size, grid_size) # (BxN, C, 6, 6, 6)
shared_features = self.shared_fc_layer(pooled_features.view(batch_size_rcnn, -1, 1))
rcnn_cls = self.cls_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B, 1 or 2)
rcnn_reg = self.reg_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B, C)
if not self.training:
batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
batch_size=batch_dict['batch_size'], rois=batch_dict['rois'], cls_preds=rcnn_cls, box_preds=rcnn_reg
)
batch_dict['batch_cls_preds'] = batch_cls_preds
batch_dict['batch_box_preds'] = batch_box_preds
batch_dict['cls_preds_normalized'] = False
else:
targets_dict['rcnn_cls'] = rcnn_cls
targets_dict['rcnn_reg'] = rcnn_reg
self.forward_ret_dict = targets_dict
return batch_dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .target_assigner.proposal_target_layer import ProposalTargetLayer
from ..model_utils.model_nms_utils import class_agnostic_nms
from ...utils import common_utils, loss_utils, box_coder_utils
class RoIHeadTemplate(nn.Module):
def __init__(self, num_class, model_cfg):
super().__init__()
self.model_cfg = model_cfg
self.num_class = num_class
self.box_coder = getattr(box_coder_utils, self.model_cfg.TARGET_CONFIG.BOX_CODER)()
self.proposal_target_layer = ProposalTargetLayer(roi_sampler_cfg=self.model_cfg.TARGET_CONFIG)
self.build_losses(self.model_cfg.LOSS_CONFIG)
self.forward_ret_dict = None
def build_losses(self, losses_cfg):
self.add_module(
'reg_loss_func',
loss_utils.WeightedSmoothL1Loss(code_weights=losses_cfg.LOSS_WEIGHTS['code_weights'])
)
def make_fc_layers(self, input_channels, output_channels, fc_list):
fc_layers = []
pre_channel = input_channels
for k in range(0, fc_list.__len__()):
fc_layers.extend([
nn.Conv1d(pre_channel, fc_list[k], kernel_size=1, bias=False),
nn.BatchNorm1d(fc_list[k]),
nn.ReLU()
])
pre_channel = fc_list[k]
if self.model_cfg.DP_RATIO >= 0 and k == 0:
fc_layers.append(nn.Dropout(self.model_cfg.DP_RATIO))
fc_layers.append(nn.Conv1d(pre_channel, output_channels, kernel_size=1, bias=True))
fc_layers = nn.Sequential(*fc_layers)
return fc_layers
def proposal_layer(self, batch_dict, nms_config):
"""
Args:
batch_dict:
batch_size:
batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1)
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
cls_preds_normalized: indicate whether batch_cls_preds is normalized
batch_index: optional (N1+N2+...)
nms_config:
Returns:
batch_dict:
rois: (B, num_rois, 7+C)
roi_scores: (B, num_rois)
roi_labels: (B, num_rois)
"""
batch_size = batch_dict['batch_size']
batch_box_preds = batch_dict['batch_box_preds']
batch_cls_preds = batch_dict['batch_cls_preds']
rois = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE, batch_box_preds.shape[-1]))
roi_scores = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE))
roi_labels = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE), dtype=torch.long)
for index in range(batch_size):
if batch_dict.get('batch_index', None) is not None:
assert batch_cls_preds.shape.__len__() == 2
batch_mask = (batch_dict['batch_index'] == index)
else:
assert batch_dict['batch_cls_preds'].shape.__len__() == 3
batch_mask = index
box_preds = batch_box_preds[batch_mask]
cls_preds = batch_cls_preds[batch_mask]
cur_roi_scores, cur_roi_labels = torch.max(cls_preds, dim=1)
if nms_config.MULTI_CLASSES_NMS:
raise NotImplementedError
else:
selected, selected_scores = class_agnostic_nms(
box_scores=cur_roi_scores, box_preds=box_preds, nms_config=nms_config
)
rois[index, :len(selected), :] = box_preds[selected]
roi_scores[index, :len(selected)] = cur_roi_scores[selected]
roi_labels[index, :len(selected)] = cur_roi_labels[selected]
batch_dict['rois'] = rois
batch_dict['roi_scores'] = roi_scores
batch_dict['roi_labels'] = roi_labels + 1
batch_dict['has_class_labels'] = True if batch_cls_preds.shape[-1] > 1 else False
return batch_dict
def assign_targets(self, batch_dict):
batch_size = batch_dict['batch_size']
with torch.no_grad():
targets_dict = self.proposal_target_layer.forward(batch_dict)
rois = targets_dict['rois'] # (B, N, 7 + C)
gt_of_rois = targets_dict['gt_of_rois'] # (B, N, 7 + C + 1)
targets_dict['gt_of_rois_src'] = gt_of_rois.clone().detach()
# canonical transformation
roi_center = rois[:, :, 0:3]
roi_ry = rois[:, :, 6] % (2 * np.pi)
gt_of_rois[:, :, 0:3] = gt_of_rois[:, :, 0:3] - roi_center
gt_of_rois[:, :, 6] = gt_of_rois[:, :, 6] - roi_ry
# transfer LiDAR coords to local coords
gt_of_rois = common_utils.rotate_points_along_z(
points=gt_of_rois.view(-1, 1, gt_of_rois.shape[-1]), angle=-roi_ry.view(-1)
).view(batch_size, -1, gt_of_rois.shape[-1])
# flip orientation if rois have opposite orientation
heading_label = gt_of_rois[:, :, 6] % (2 * np.pi) # 0 ~ 2pi
opposite_flag = (heading_label > np.pi * 0.5) & (heading_label < np.pi * 1.5)
heading_label[opposite_flag] = (heading_label[opposite_flag] + np.pi) % (2 * np.pi) # (0 ~ pi/2, 3pi/2 ~ 2pi)
flag = heading_label > np.pi
heading_label[flag] = heading_label[flag] - np.pi * 2 # (-pi/2, pi/2)
heading_label = torch.clamp(heading_label, min=-np.pi / 2, max=np.pi / 2)
gt_of_rois[:, :, 6] = heading_label
targets_dict['gt_of_rois'] = gt_of_rois
return targets_dict
def get_box_reg_layer_loss(self, forward_ret_dict):
loss_cfgs = self.model_cfg.LOSS_CONFIG
code_size = self.box_coder.code_size
reg_valid_mask = forward_ret_dict['reg_valid_mask'].view(-1)
gt_boxes3d_ct = forward_ret_dict['gt_of_rois'][..., 0:code_size]
gt_of_rois_src = forward_ret_dict['gt_of_rois_src'][..., 0:code_size].view(-1, code_size)
rcnn_reg = forward_ret_dict['rcnn_reg'] # (rcnn_batch_size, C)
roi_boxes3d = forward_ret_dict['rois']
rcnn_batch_size = gt_boxes3d_ct.view(-1, code_size).shape[0]
fg_mask = (reg_valid_mask > 0)
fg_sum = fg_mask.long().sum().item()
tb_dict = {}
if loss_cfgs.REG_LOSS == 'smooth-l1':
rois_anchor = roi_boxes3d.clone().detach().view(-1, code_size)
rois_anchor[:, 0:3] = 0
rois_anchor[:, 6] = 0
reg_targets = self.box_coder.encode_torch(
gt_boxes3d_ct.view(rcnn_batch_size, code_size), rois_anchor
)
rcnn_loss_reg = self.reg_loss_func(
rcnn_reg.view(rcnn_batch_size, -1).unsqueeze(dim=0),
reg_targets.unsqueeze(dim=0),
) # [B, M, 7]
rcnn_loss_reg = (rcnn_loss_reg.view(rcnn_batch_size, -1) * fg_mask.unsqueeze(dim=-1).float()).sum() / max(fg_sum, 1)
rcnn_loss_reg = rcnn_loss_reg * loss_cfgs.LOSS_WEIGHTS['rcnn_reg_weight']
tb_dict['rcnn_loss_reg'] = rcnn_loss_reg.item()
if loss_cfgs.CORNER_LOSS_REGULARIZATION and fg_sum > 0:
# TODO: NEED to BE CHECK
fg_rcnn_reg = rcnn_reg.view(rcnn_batch_size, -1)[fg_mask]
fg_roi_boxes3d = roi_boxes3d.view(-1, code_size)[fg_mask]
fg_roi_boxes3d = fg_roi_boxes3d.view(1, -1, code_size)
batch_anchors = fg_roi_boxes3d.clone().detach()
roi_ry = fg_roi_boxes3d[:, :, 6].view(-1)
roi_xyz = fg_roi_boxes3d[:, :, 0:3].view(-1, 3)
batch_anchors[:, :, 0:3] = 0
rcnn_boxes3d = self.box_coder.decode_torch(
fg_rcnn_reg.view(batch_anchors.shape[0], -1, code_size), batch_anchors
).view(-1, code_size)
rcnn_boxes3d = common_utils.rotate_points_along_z(
rcnn_boxes3d.unsqueeze(dim=1), roi_ry
).squeeze(dim=1)
rcnn_boxes3d[:, 0:3] += roi_xyz
loss_corner = loss_utils.get_corner_loss_lidar(
rcnn_boxes3d[:, 0:7],
gt_of_rois_src[fg_mask][:, 0:7]
)
loss_corner = loss_corner.mean()
loss_corner = loss_corner * loss_cfgs.LOSS_WEIGHTS['rcnn_corner_weight']
rcnn_loss_reg += loss_corner
tb_dict['rcnn_loss_corner'] = loss_corner.item()
else:
raise NotImplementedError
return rcnn_loss_reg, tb_dict
def get_box_cls_layer_loss(self, forward_ret_dict):
loss_cfgs = self.model_cfg.LOSS_CONFIG
rcnn_cls = forward_ret_dict['rcnn_cls']
rcnn_cls_labels = forward_ret_dict['rcnn_cls_labels'].view(-1)
if loss_cfgs.CLS_LOSS == 'BinaryCrossEntropy':
rcnn_cls_flat = rcnn_cls.view(-1)
batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat), rcnn_cls_labels.float(), reduction='none')
cls_valid_mask = (rcnn_cls_labels >= 0).float()
rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0)
elif loss_cfgs.CLS_LOSS == 'CrossEntropy':
batch_loss_cls = F.cross_entropy(rcnn_cls, rcnn_cls_labels, reduction='none', ignore_index=-1)
cls_valid_mask = (rcnn_cls_labels >= 0).float()
rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0)
else:
raise NotImplementedError
rcnn_loss_cls = rcnn_loss_cls * loss_cfgs.LOSS_WEIGHTS['rcnn_cls_weight']
tb_dict = {'rcnn_loss_cls': rcnn_loss_cls.item()}
return rcnn_loss_cls, tb_dict
def get_loss(self, tb_dict=None):
tb_dict = {} if tb_dict is None else tb_dict
rcnn_loss = 0
rcnn_loss_cls, cls_tb_dict = self.get_box_cls_layer_loss(self.forward_ret_dict)
rcnn_loss += rcnn_loss_cls
tb_dict.update(cls_tb_dict)
rcnn_loss_reg, reg_tb_dict = self.get_box_reg_layer_loss(self.forward_ret_dict)
rcnn_loss += rcnn_loss_reg
tb_dict.update(reg_tb_dict)
tb_dict['rcnn_loss'] = rcnn_loss.item()
return rcnn_loss, tb_dict
def generate_predicted_boxes(self, batch_size, rois, cls_preds, box_preds):
"""
Args:
batch_size:
rois: (B, N, 7)
cls_preds: (BN, num_class)
box_preds: (BN, code_size)
Returns:
"""
code_size = self.box_coder.code_size
# batch_cls_preds: (B, N, num_class or 1)
batch_cls_preds = cls_preds.view(batch_size, -1, cls_preds.shape[-1])
batch_box_preds = box_preds.view(batch_size, -1, code_size)
roi_ry = rois[:, :, 6].view(-1)
roi_xyz = rois[:, :, 0:3].view(-1, 3)
local_rois = rois.clone().detach()
local_rois[:, :, 0:3] = 0
batch_box_preds = self.box_coder.decode_torch(batch_box_preds, local_rois).view(-1, code_size)
batch_box_preds = common_utils.rotate_points_along_z(
batch_box_preds.unsqueeze(dim=1), roi_ry
).squeeze(dim=1)
batch_box_preds[:, 0:3] += roi_xyz
batch_box_preds = batch_box_preds.view(batch_size, -1, code_size)
return batch_cls_preds, batch_box_preds
import numpy as np
import torch
import torch.nn as nn
from ....ops.iou3d_nms import iou3d_nms_utils
class ProposalTargetLayer(nn.Module):
def __init__(self, roi_sampler_cfg):
super().__init__()
self.roi_sampler_cfg = roi_sampler_cfg
def forward(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
rois: (B, num_rois, 7 + C)
roi_scores: (B, num_rois)
gt_boxes: (B, N, 7 + C + 1)
roi_labels: (B, num_rois)
Returns:
batch_dict:
rois: (B, M, 7 + C)
gt_of_rois: (B, M, 7 + C)
gt_iou_of_rois: (B, M)
roi_scores: (B, M)
roi_labels: (B, M)
reg_valid_mask: (B, M)
rcnn_cls_labels: (B, M)
"""
batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels = self.sample_rois_for_rcnn(
batch_dict=batch_dict
)
# regression valid mask
reg_valid_mask = (batch_roi_ious > self.roi_sampler_cfg.REG_FG_THRESH).long()
# classification label
if self.roi_sampler_cfg.CLS_SCORE_TYPE == 'cls':
batch_cls_labels = (batch_roi_ious > self.roi_sampler_cfg.CLS_FG_THRESH).long()
ignore_mask = (batch_roi_ious > self.roi_sampler_cfg.CLS_BG_THRESH) & \
(batch_roi_ious < self.roi_sampler_cfg.CLS_FG_THRESH)
batch_cls_labels[ignore_mask > 0] = -1
elif self.roi_sampler_cfg.CLS_SCORE_TYPE == 'roi_iou':
iou_bg_thresh = self.roi_sampler_cfg.CLS_BG_THRESH
iou_fg_thresh = self.roi_sampler_cfg.CLS_FG_THRESH
fg_mask = batch_roi_ious > iou_fg_thresh
bg_mask = batch_roi_ious < iou_bg_thresh
interval_mask = (fg_mask == 0) & (bg_mask == 0)
batch_cls_labels = (fg_mask > 0).float()
batch_cls_labels[interval_mask] = \
(batch_roi_ious[interval_mask] - iou_bg_thresh) / (iou_fg_thresh - iou_bg_thresh)
else:
raise NotImplementedError
targets_dict = {'rois': batch_rois, 'gt_of_rois': batch_gt_of_rois, 'gt_iou_of_rois': batch_roi_ious,
'roi_scores': batch_roi_scores, 'roi_labels': batch_roi_labels,
'reg_valid_mask': reg_valid_mask,
'rcnn_cls_labels': batch_cls_labels}
return targets_dict
def sample_rois_for_rcnn(self, batch_dict):
"""
Args:
batch_dict:
batch_size:
rois: (B, num_rois, 7 + C)
roi_scores: (B, num_rois)
gt_boxes: (B, N, 7 + C + 1)
roi_labels: (B, num_rois)
Returns:
"""
batch_size = batch_dict['batch_size']
rois = batch_dict['rois']
roi_scores = batch_dict['roi_scores']
roi_labels = batch_dict['roi_labels']
gt_boxes = batch_dict['gt_boxes']
code_size = rois.shape[-1]
batch_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size)
batch_gt_of_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size + 1)
batch_roi_ious = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE)
batch_roi_scores = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE)
batch_roi_labels = rois.new_zeros((batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE), dtype=torch.long)
for index in range(batch_size):
cur_roi, cur_gt, cur_roi_labels, cur_roi_scores = \
rois[index], gt_boxes[index], roi_labels[index], roi_scores[index]
k = cur_gt.__len__() - 1
while k > 0 and cur_gt[k].sum() == 0:
k -= 1
cur_gt = cur_gt[:k + 1]
cur_gt = cur_gt.new_zeros((1, cur_gt.shape[1])) if len(cur_gt) == 0 else cur_gt
if self.roi_sampler_cfg.get('SAMPLE_ROI_BY_EACH_CLASS', False):
max_overlaps, gt_assignment = self.get_max_iou_with_same_class(
rois=cur_roi, roi_labels=cur_roi_labels,
gt_boxes=cur_gt[:, 0:7], gt_labels=cur_gt[:, -1].long()
)
else:
iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi, cur_gt[:, 0:7]) # (M, N)
max_overlaps, gt_assignment = torch.max(iou3d, dim=1)
sampled_inds = self.subsample_rois(max_overlaps=max_overlaps)
batch_rois[index] = cur_roi[sampled_inds]
batch_roi_labels[index] = cur_roi_labels[sampled_inds]
batch_roi_ious[index] = max_overlaps[sampled_inds]
batch_roi_scores[index] = cur_roi_scores[sampled_inds]
batch_gt_of_rois[index] = cur_gt[gt_assignment[sampled_inds]]
return batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels
def subsample_rois(self, max_overlaps):
# sample fg, easy_bg, hard_bg
fg_rois_per_image = int(np.round(self.roi_sampler_cfg.FG_RATIO * self.roi_sampler_cfg.ROI_PER_IMAGE))
fg_thresh = min(self.roi_sampler_cfg.REG_FG_THRESH, self.roi_sampler_cfg.CLS_FG_THRESH)
fg_inds = torch.nonzero((max_overlaps >= fg_thresh)).view(-1)
easy_bg_inds = torch.nonzero((max_overlaps < self.roi_sampler_cfg.CLS_BG_THRESH_LO)).view(-1)
hard_bg_inds = torch.nonzero((max_overlaps < self.roi_sampler_cfg.REG_FG_THRESH) &
(max_overlaps >= self.roi_sampler_cfg.CLS_BG_THRESH_LO)).view(-1)
fg_num_rois = fg_inds.numel()
bg_num_rois = hard_bg_inds.numel() + easy_bg_inds.numel()
if fg_num_rois > 0 and bg_num_rois > 0:
# sampling fg
fg_rois_per_this_image = min(fg_rois_per_image, fg_num_rois)
rand_num = torch.from_numpy(np.random.permutation(fg_num_rois)).type_as(max_overlaps).long()
fg_inds = fg_inds[rand_num[:fg_rois_per_this_image]]
# sampling bg
bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE - fg_rois_per_this_image
bg_inds = self.sample_bg_inds(
hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO
)
elif fg_num_rois > 0 and bg_num_rois == 0:
# sampling fg
rand_num = np.floor(np.random.rand(self.roi_sampler_cfg.ROI_PER_IMAGE) * fg_num_rois)
rand_num = torch.from_numpy(rand_num).type_as(max_overlaps).long()
fg_inds = fg_inds[rand_num]
bg_inds = []
elif bg_num_rois > 0 and fg_num_rois == 0:
# sampling bg
bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE
bg_inds = self.sample_bg_inds(
hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO
)
else:
print('maxoverlaps:(min=%f, max=%f)' % (max_overlaps.min().item(), max_overlaps.max().item()))
print('ERROR: FG=%d, BG=%d' % (fg_num_rois, bg_num_rois))
raise NotImplementedError
sampled_inds = torch.cat((fg_inds, bg_inds), dim=0)
return sampled_inds
@staticmethod
def sample_bg_inds(hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, hard_bg_ratio):
if hard_bg_inds.numel() > 0 and easy_bg_inds.numel() > 0:
hard_bg_rois_num = min(int(bg_rois_per_this_image * hard_bg_ratio), len(hard_bg_inds))
easy_bg_rois_num = bg_rois_per_this_image - hard_bg_rois_num
# sampling hard bg
rand_idx = torch.randint(low=0, high=hard_bg_inds.numel(), size=(hard_bg_rois_num,)).long()
hard_bg_inds = hard_bg_inds[rand_idx]
# sampling easy bg
rand_idx = torch.randint(low=0, high=easy_bg_inds.numel(), size=(easy_bg_rois_num,)).long()
easy_bg_inds = easy_bg_inds[rand_idx]
bg_inds = torch.cat([hard_bg_inds, easy_bg_inds], dim=0)
elif hard_bg_inds.numel() > 0 and easy_bg_inds.numel() == 0:
hard_bg_rois_num = bg_rois_per_this_image
# sampling hard bg
rand_idx = torch.randint(low=0, high=hard_bg_inds.numel(), size=(hard_bg_rois_num,)).long()
bg_inds = hard_bg_inds[rand_idx]
elif hard_bg_inds.numel() == 0 and easy_bg_inds.numel() > 0:
easy_bg_rois_num = bg_rois_per_this_image
# sampling easy bg
rand_idx = torch.randint(low=0, high=easy_bg_inds.numel(), size=(easy_bg_rois_num,)).long()
bg_inds = easy_bg_inds[rand_idx]
else:
raise NotImplementedError
return bg_inds
@staticmethod
def get_max_iou_with_same_class(rois, roi_labels, gt_boxes, gt_labels):
"""
Args:
rois: (N, 7)
roi_labels: (N)
gt_boxes: (N, )
gt_labels:
Returns:
"""
"""
:param rois: (N, 7)
:param roi_labels: (N)
:param gt_boxes: (N, 8)
:return:
"""
max_overlaps = rois.new_zeros(rois.shape[0])
gt_assignment = roi_labels.new_zeros(roi_labels.shape[0])
for k in range(gt_labels.min().item(), gt_labels.max().item() + 1):
roi_mask = (roi_labels == k)
gt_mask = (gt_labels == k)
if roi_mask.sum() > 0 and gt_mask.sum() > 0:
cur_roi = rois[roi_mask]
cur_gt = gt_boxes[gt_mask]
original_gt_assignment = gt_mask.nonzero().view(-1)
iou3d = iou3d_nms_utils.boxes_iou3d_gpu(cur_roi, cur_gt) # (M, N)
cur_max_overlaps, cur_gt_assignment = torch.max(iou3d, dim=1)
max_overlaps[roi_mask] = cur_max_overlaps
gt_assignment[roi_mask] = original_gt_assignment[cur_gt_assignment]
return max_overlaps, gt_assignment
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment