Commit d490f024 authored by ZCMax, committed by ChaimZhu

[Refactor] Refactor monoflex head and unittest

parent 98cc28e2
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple, Union
+
 import torch
 from mmcv.cnn import xavier_init
+from mmcv.runner import force_fp32
+from mmengine.config import ConfigDict
+from mmengine.data import InstanceData
+from torch import Tensor
 from torch import nn as nn

+from mmdet3d.core import Det3DDataSample
 from mmdet3d.core.bbox.builder import build_bbox_coder
 from mmdet3d.core.utils import get_ellip_gaussian_2D
-from mmdet3d.models.builder import build_loss
 from mmdet3d.models.model_utils import EdgeFusionModule
 from mmdet3d.models.utils import (filter_outside_objs, get_edge_indices,
                                   get_keypoints, handle_proj_objs)
+from mmdet3d.registry import MODELS
@@ -63,7 +69,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
             Default: dict(type='L1Loss', loss_weight=0.1).
         loss_dims: (dict, optional): Config of dimensions loss.
             Default: dict(type='L1Loss', loss_weight=0.1).
-        loss_offsets2d: (dict, optional): Config of offsets2d loss.
+        loss_offsets_2d: (dict, optional): Config of offsets_2d loss.
             Default: dict(type='L1Loss', loss_weight=0.1).
         loss_direct_depth: (dict, optional): Config of direct depth regression loss.
             Default: dict(type='L1Loss', loss_weight=0.1).
@@ -81,27 +87,33 @@ class MonoFlexHead(AnchorFreeMono3DHead):
     """  # noqa: E501

     def __init__(self,
-                 num_classes,
-                 in_channels,
-                 use_edge_fusion,
-                 edge_fusion_inds,
-                 edge_heatmap_ratio,
-                 filter_outside_objs=True,
-                 loss_cls=dict(type='GaussianFocalLoss', loss_weight=1.0),
-                 loss_bbox=dict(type='IoULoss', loss_weight=0.1),
-                 loss_dir=dict(type='MultiBinLoss', loss_weight=0.1),
-                 loss_keypoints=dict(type='L1Loss', loss_weight=0.1),
-                 loss_dims=dict(type='L1Loss', loss_weight=0.1),
-                 loss_offsets2d=dict(type='L1Loss', loss_weight=0.1),
-                 loss_direct_depth=dict(type='L1Loss', loss_weight=0.1),
-                 loss_keypoints_depth=dict(type='L1Loss', loss_weight=0.1),
-                 loss_combined_depth=dict(type='L1Loss', loss_weight=0.1),
-                 loss_attr=None,
-                 bbox_coder=dict(type='MonoFlexCoder', code_size=7),
-                 norm_cfg=dict(type='BN'),
-                 init_cfg=None,
-                 init_bias=-2.19,
-                 **kwargs):
+                 num_classes: int,
+                 in_channels: int,
+                 use_edge_fusion: bool,
+                 edge_fusion_inds: List[Tuple],
+                 edge_heatmap_ratio: float,
+                 filter_outside_objs: bool = True,
+                 loss_cls: dict = dict(
+                     type='mmdet.GaussianFocalLoss', loss_weight=1.0),
+                 loss_bbox: dict = dict(type='mmdet.IoULoss', loss_weight=0.1),
+                 loss_dir: dict = dict(type='MultiBinLoss', loss_weight=0.1),
+                 loss_keypoints: dict = dict(
+                     type='mmdet.L1Loss', loss_weight=0.1),
+                 loss_dims: dict = dict(type='mmdet.L1Loss', loss_weight=0.1),
+                 loss_offsets_2d: dict = dict(
+                     type='mmdet.L1Loss', loss_weight=0.1),
+                 loss_direct_depth: dict = dict(
+                     type='mmdet.L1Loss', loss_weight=0.1),
+                 loss_keypoints_depth: dict = dict(
+                     type='mmdet.L1Loss', loss_weight=0.1),
+                 loss_combined_depth: dict = dict(
+                     type='mmdet.L1Loss', loss_weight=0.1),
+                 loss_attr: Optional[dict] = None,
+                 bbox_coder: dict = dict(type='MonoFlexCoder', code_size=7),
+                 norm_cfg: Union[ConfigDict, dict] = dict(type='BN'),
+                 init_cfg: Optional[Union[ConfigDict, dict]] = None,
+                 init_bias: float = -2.19,
+                 **kwargs) -> None:
         self.use_edge_fusion = use_edge_fusion
         self.edge_fusion_inds = edge_fusion_inds
         super().__init__(
@@ -117,13 +129,13 @@ class MonoFlexHead(AnchorFreeMono3DHead):
         self.filter_outside_objs = filter_outside_objs
         self.edge_heatmap_ratio = edge_heatmap_ratio
         self.init_bias = init_bias
-        self.loss_dir = build_loss(loss_dir)
-        self.loss_keypoints = build_loss(loss_keypoints)
-        self.loss_dims = build_loss(loss_dims)
-        self.loss_offsets2d = build_loss(loss_offsets2d)
-        self.loss_direct_depth = build_loss(loss_direct_depth)
-        self.loss_keypoints_depth = build_loss(loss_keypoints_depth)
-        self.loss_combined_depth = build_loss(loss_combined_depth)
+        self.loss_dir = MODELS.build(loss_dir)
+        self.loss_keypoints = MODELS.build(loss_keypoints)
+        self.loss_dims = MODELS.build(loss_dims)
+        self.loss_offsets_2d = MODELS.build(loss_offsets_2d)
+        self.loss_direct_depth = MODELS.build(loss_direct_depth)
+        self.loss_keypoints_depth = MODELS.build(loss_keypoints_depth)
+        self.loss_combined_depth = MODELS.build(loss_combined_depth)
         self.bbox_coder = build_bbox_coder(bbox_coder)

     def _init_edge_module(self):
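
The change from `build_loss` to `MODELS.build` relies on mmengine's scoped registries: a `type` string such as `'mmdet.L1Loss'` tells the registry to resolve the class from mmdet's model registry rather than mmdet3d's own. A minimal standalone sketch of the same lookup (assuming mmdet3d's registry module is importable, which sets up the default scope):

    from mmdet3d.registry import MODELS

    # The 'mmdet.' prefix routes the lookup to mmdet's MODELS registry;
    # without it, the name is resolved against mmdet3d's registry first.
    loss_dims = MODELS.build(dict(type='mmdet.L1Loss', loss_weight=0.1))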
@@ -185,13 +197,15 @@ class MonoFlexHead(AnchorFreeMono3DHead):
         if self.use_edge_fusion:
             self._init_edge_module()

-    def forward_train(self, x, input_metas, gt_bboxes, gt_labels, gt_bboxes_3d,
-                      gt_labels_3d, centers2d, depths, attr_labels,
-                      gt_bboxes_ignore, proposal_cfg, **kwargs):
+    def forward_train(self,
+                      x: List[Tensor],
+                      batch_data_samples: List[Det3DDataSample],
+                      proposal_cfg: Optional[ConfigDict] = None,
+                      **kwargs):
         """
         Args:
             x (list[Tensor]): Features from FPN.
-            input_metas (list[dict]): Meta information of each image, e.g.,
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                 image size, scaling factor, etc.
             gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
                 shape (num_gts, 4).
@@ -201,7 +215,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
                 shape (num_gts, self.bbox_code_size).
             gt_labels_3d (list[Tensor]): 3D ground truth labels of each box,
                 shape (num_gts,).
-            centers2d (list[Tensor]): Projected 3D center of each box,
+            centers_2d (list[Tensor]): Projected 3D center of each box,
                 shape (num_gts, 2).
             depths (list[Tensor]): Depth of projected 3D center of each box,
                 shape (num_gts,).
@@ -216,29 +230,75 @@ class MonoFlexHead(AnchorFreeMono3DHead):
             losses: (dict[str, Tensor]): A dictionary of loss components.
             proposal_list (list[Tensor]): Proposals of each image.
         """
-        outs = self(x, input_metas)
-        if gt_labels is None:
-            loss_inputs = outs + (gt_bboxes, gt_bboxes_3d, centers2d, depths,
-                                  attr_labels, input_metas)
-        else:
-            loss_inputs = outs + (gt_bboxes, gt_labels, gt_bboxes_3d,
-                                  gt_labels_3d, centers2d, depths, attr_labels,
-                                  input_metas)
-        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+        """
+        Args:
+            x (list[Tensor]): Features from FPN.
+            batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
+                contains the meta information of each image and corresponding
+                annotations.
+            proposal_cfg (mmengine.Config, optional): Test / postprocessing
+                configuration. If None, test_cfg is used. Defaults to None.
+
+        Returns:
+            tuple or Tensor: When `proposal_cfg` is None, the detector is a
+            normal one-stage detector and the return value is the losses.
+
+            - losses: (dict[str, Tensor]): A dictionary of loss components.
+
+            When `proposal_cfg` is not None, the head is used as an
+            `rpn_head` and the return value is a tuple containing:
+
+            - losses: (dict[str, Tensor]): A dictionary of loss components.
+            - results_list (list[:obj:`InstanceData`]): Detection results of
+              each image after the post process. Each item usually contains
+              the following keys.

+              - scores (Tensor): Classification scores with shape
+                (num_instances, ).
+              - labels (Tensor): Labels of bboxes with shape
+                (num_instances, ).
+              - bboxes (:obj:`BaseInstance3DBoxes`): A tensor with shape
+                (num_instances, C), where the last dimension C of a 3D box
+                is (x, y, z, x_size, y_size, z_size, yaw, ...) and C >= 7.
+                C = 7 for KITTI and C = 9 for nuScenes, with 2 extra dims
+                of velocity.
+        """
+        batch_gt_instances_3d = []
+        batch_gt_instances_ignore = []
+        batch_img_metas = []
+        for data_sample in batch_data_samples:
+            batch_img_metas.append(data_sample.metainfo)
+            batch_gt_instances_3d.append(data_sample.gt_instances_3d)
+            if 'ignored_instances' in data_sample:
+                batch_gt_instances_ignore.append(
+                    data_sample.ignored_instances)
+            else:
+                batch_gt_instances_ignore.append(None)
+
+        # monoflex head needs img_metas for feature extraction
+        outs = self(x, batch_img_metas)
+        loss_inputs = outs + (batch_gt_instances_3d, batch_img_metas,
+                              batch_gt_instances_ignore)
+        losses = self.loss(*loss_inputs)

         if proposal_cfg is None:
             return losses
         else:
-            proposal_list = self.get_bboxes(
-                *outs, input_metas, cfg=proposal_cfg)
-            return losses, proposal_list
+            batch_img_metas = [
+                data_sample.metainfo for data_sample in batch_data_samples
+            ]
+            results_list = self.get_results(
                *outs, batch_img_metas=batch_img_metas, cfg=proposal_cfg)
+            return losses, results_list
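
For callers, the new entry point takes `Det3DDataSample` objects in place of the old loose GT lists. A minimal sketch of assembling one sample, assuming mmengine's `InstanceData` API as imported above; `head` and `feats` stand in for a constructed MonoFlexHead and its FPN features, and the 3D boxes (a `CameraInstance3DBoxes`) are omitted for brevity:

    import torch
    from mmengine.data import InstanceData

    from mmdet3d.core import Det3DDataSample

    data_sample = Det3DDataSample()
    data_sample.set_metainfo(dict(img_shape=(110, 110), pad_shape=(128, 128)))

    gt_instances_3d = InstanceData()
    gt_instances_3d.bboxes = torch.rand(2, 4)             # 2D boxes, (num_gts, 4)
    gt_instances_3d.labels = torch.randint(0, 3, (2, ))   # 2D labels
    gt_instances_3d.labels_3d = torch.randint(0, 3, (2, ))
    gt_instances_3d.centers_2d = torch.rand(2, 2) * 110   # projected 3D centers
    gt_instances_3d.depths = torch.rand(2)                # depths of those centers
    # gt_instances_3d.bboxes_3d would hold a CameraInstance3DBoxes in a full sample
    data_sample.gt_instances_3d = gt_instances_3d

    # losses = head.forward_train(feats, [data_sample])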

-    def forward(self, feats, input_metas):
+    def forward(self, feats: List[Tensor], batch_img_metas: List[dict]):
         """Forward features from the upstream network.

         Args:
             feats (list[Tensor]): Features from the upstream network, each is
                 a 4D-tensor.
-            input_metas (list[dict]): Meta information of each image, e.g.,
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                 image size, scaling factor, etc.

         Returns:
@@ -250,21 +310,21 @@ class MonoFlexHead(AnchorFreeMono3DHead):
                 level, each is a 4D-tensor, the channel number is
                 num_points * bbox_code_size.
         """
-        mlvl_input_metas = [input_metas for i in range(len(feats))]
-        return multi_apply(self.forward_single, feats, mlvl_input_metas)
+        mlvl_batch_img_metas = [batch_img_metas for i in range(len(feats))]
+        return multi_apply(self.forward_single, feats, mlvl_batch_img_metas)

-    def forward_single(self, x, input_metas):
+    def forward_single(self, x: Tensor, batch_img_metas: List[dict]):
         """Forward features of a single scale level.

         Args:
             x (Tensor): Feature maps from a specific FPN feature level.
-            input_metas (list[dict]): Meta information of each image, e.g.,
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                 image size, scaling factor, etc.

         Returns:
             tuple: Scores for each class, bbox predictions.
         """
-        img_h, img_w = input_metas[0]['pad_shape'][:2]
+        img_h, img_w = batch_img_metas[0]['pad_shape'][:2]
         batch_size, _, feat_h, feat_w = x.shape
         downsample_ratio = img_h / feat_h
@@ -275,7 +335,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
         if self.use_edge_fusion:
             # calculate the edge indices for the batch data
             edge_indices_list = get_edge_indices(
-                input_metas, downsample_ratio, device=x.device)
+                batch_img_metas, downsample_ratio, device=x.device)
             edge_lens = [
                 edge_indices.shape[0] for edge_indices in edge_indices_list
             ]
@@ -313,13 +373,15 @@ class MonoFlexHead(AnchorFreeMono3DHead):
         return cls_score, bbox_pred

-    def get_bboxes(self, cls_scores, bbox_preds, input_metas):
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+    def get_results(self, cls_scores: List[Tensor], bbox_preds: List[Tensor],
+                    batch_img_metas: List[dict]):
         """Generate bboxes from bbox head predictions.

         Args:
             cls_scores (list[Tensor]): Box scores for each scale level.
             bbox_preds (list[Tensor]): Box regression for each scale.
-            input_metas (list[dict]): Meta information of each image, e.g.,
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                 image size, scaling factor, etc.
             rescale (bool): If True, return boxes in original image space.

         Returns:
@@ -329,18 +391,18 @@ class MonoFlexHead(AnchorFreeMono3DHead):
         assert len(cls_scores) == len(bbox_preds) == 1
         cam2imgs = torch.stack([
             cls_scores[0].new_tensor(input_meta['cam2img'])
-            for input_meta in input_metas
+            for input_meta in batch_img_metas
         ])
         batch_bboxes, batch_scores, batch_topk_labels = self.decode_heatmap(
             cls_scores[0],
             bbox_preds[0],
-            input_metas,
+            batch_img_metas,
             cam2imgs=cam2imgs,
             topk=100,
             kernel=3)

         result_list = []
-        for img_id in range(len(input_metas)):
+        for img_id in range(len(batch_img_metas)):

             bboxes = batch_bboxes[img_id]
             scores = batch_scores[img_id]
@@ -351,20 +413,29 @@ class MonoFlexHead(AnchorFreeMono3DHead):
             scores = scores[keep_idx]
             labels = labels[keep_idx]

-            bboxes = input_metas[img_id]['box_type_3d'](
+            bboxes = batch_img_metas[img_id]['box_type_3d'](
                 bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5))
             attrs = None
-            result_list.append((bboxes, scores, labels, attrs))
+
+            results = InstanceData()
+            results.bboxes_3d = bboxes
+            results.scores_3d = scores
+            results.labels_3d = labels
+
+            if attrs is not None:
+                results.attr_labels = attrs
+
+            result_list.append(results)

         return result_list
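
Consumers of `get_results` now read named fields from an `InstanceData` instead of unpacking the old 4-tuple. Roughly, as a sketch (`head`, `cls_scores`, `bbox_preds` and `batch_img_metas` stand in for real objects):

    results = head.get_results(cls_scores, bbox_preds, batch_img_metas)[0]
    bboxes = results.bboxes_3d   # box_type_3d boxes, e.g. CameraInstance3DBoxes
    scores = results.scores_3d   # (num_instances, )
    labels = results.labels_3d   # (num_instances, )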

     def decode_heatmap(self,
-                       cls_score,
-                       reg_pred,
-                       input_metas,
-                       cam2imgs,
-                       topk=100,
-                       kernel=3):
+                       cls_score: Tensor,
+                       reg_pred: Tensor,
+                       batch_img_metas: List[dict],
+                       cam2imgs: Tensor,
+                       topk: int = 100,
+                       kernel: int = 3):
         """Transform outputs into detections raw bbox predictions.

         Args:
@@ -372,7 +443,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
                 shape (B, num_classes, H, W).
             reg_pred (Tensor): Box regression map.
                 shape (B, channel, H, W).
-            input_metas (List[dict]): Meta information of each image, e.g.,
+            batch_img_metas (List[dict]): Meta information of each image, e.g.,
                 image size, scaling factor, etc.
             cam2imgs (Tensor): Camera intrinsic matrix.
                 shape (N, 4, 4)
@@ -391,7 +462,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
             - batch_topk_labels (Tensor): Categories of each 3D box.
                 shape (B, k)
         """
-        img_h, img_w = input_metas[0]['pad_shape'][:2]
+        img_h, img_w = batch_img_metas[0]['pad_shape'][:2]
         batch_size, _, feat_h, feat_w = cls_score.shape
         downsample_ratio = img_h / feat_h
@@ -404,13 +475,13 @@ class MonoFlexHead(AnchorFreeMono3DHead):
         regression = transpose_and_gather_feat(reg_pred, batch_index)
         regression = regression.view(-1, 8)

-        pred_base_centers2d = torch.cat(
+        pred_base_centers_2d = torch.cat(
             [topk_xs.view(-1, 1),
              topk_ys.view(-1, 1).float()], dim=1)
         preds = self.bbox_coder.decode(regression, batch_topk_labels,
                                        downsample_ratio, cam2imgs)
         pred_locations = self.bbox_coder.decode_location(
-            pred_base_centers2d, preds['offsets2d'], preds['combined_depth'],
+            pred_base_centers_2d, preds['offsets_2d'], preds['combined_depth'],
             cam2imgs, downsample_ratio)
         pred_yaws = self.bbox_coder.decode_orientation(
             preds['orientations']).unsqueeze(-1)
@@ -419,8 +490,8 @@ class MonoFlexHead(AnchorFreeMono3DHead):
         batch_bboxes = batch_bboxes.view(batch_size, -1, self.bbox_code_size)
         return batch_bboxes, batch_scores, batch_topk_labels
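
Both `decode_heatmap` and `get_predictions` lean on `transpose_and_gather_feat` to pull per-object regression vectors out of a dense (B, C, H, W) map using flat `y * W + x` indices. A minimal re-implementation sketch of what that helper does (the real one lives in mmdet3d's utilities; shapes are assumptions noted in the comments):

    import torch

    def transpose_and_gather_feat_sketch(feat, ind):
        # feat: (B, C, H, W) dense prediction map
        # ind:  (B, K) flat spatial indices, i.e. y * W + x, dtype long
        B, C = feat.shape[:2]
        feat = feat.permute(0, 2, 3, 1).reshape(B, -1, C)    # (B, H*W, C)
        ind = ind.unsqueeze(-1).expand(B, ind.shape[1], C)   # (B, K, C)
        return feat.gather(1, ind)                           # (B, K, C)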

-    def get_predictions(self, pred_reg, labels3d, centers2d, reg_mask,
-                        batch_indices, input_metas, downsample_ratio):
+    def get_predictions(self, pred_reg, labels3d, centers_2d, reg_mask,
+                        batch_indices, batch_img_metas, downsample_ratio):
         """Prepare predictions for computing loss.

         Args:
@@ -428,14 +499,14 @@ class MonoFlexHead(AnchorFreeMono3DHead):
                 shape (B, channel, H, W).
             labels3d (Tensor): Labels of each 3D box.
                 shape (B * max_objs, )
-            centers2d (Tensor): Coords of each projected 3D box
+            centers_2d (Tensor): Coords of each projected 3D box
                 center on image. shape (N, 2)
             reg_mask (Tensor): Indexes of the existence of the 3D box.
                 shape (B * max_objs, )
             batch_indices (Tensor): Batch indices of the 3D box.
                 shape (N, 3)
-            input_metas (list[dict]): Meta information of each image,
-                e.g., image size, scaling factor, etc.
+            batch_img_metas (list[dict]): Meta information of each image,
+                e.g., image size, scaling factor, etc.
             downsample_ratio (int): The stride of feature map.

         Returns:
@@ -444,50 +515,41 @@ class MonoFlexHead(AnchorFreeMono3DHead):
         batch, channel = pred_reg.shape[0], pred_reg.shape[1]
         w = pred_reg.shape[3]
         cam2imgs = torch.stack([
-            centers2d.new_tensor(input_meta['cam2img'])
-            for input_meta in input_metas
+            centers_2d.new_tensor(img_meta['cam2img'])
+            for img_meta in batch_img_metas
         ])
         # (batch_size, 4, 4) -> (N, 4, 4)
         cam2imgs = cam2imgs[batch_indices, :, :]
-        centers2d_inds = centers2d[:, 1] * w + centers2d[:, 0]
-        centers2d_inds = centers2d_inds.view(batch, -1)
-        pred_regression = transpose_and_gather_feat(pred_reg, centers2d_inds)
+        centers_2d_inds = centers_2d[:, 1] * w + centers_2d[:, 0]
+        centers_2d_inds = centers_2d_inds.view(batch, -1)
+        pred_regression = transpose_and_gather_feat(pred_reg, centers_2d_inds)
         pred_regression_pois = pred_regression.view(-1, channel)[reg_mask]
         preds = self.bbox_coder.decode(pred_regression_pois, labels3d,
                                        downsample_ratio, cam2imgs)

         return preds

-    def get_targets(self, gt_bboxes_list, gt_labels_list, gt_bboxes_3d_list,
-                    gt_labels_3d_list, centers2d_list, depths_list, feat_shape,
-                    img_shape, input_metas):
+    def get_targets(self, batch_gt_instances_3d: List[InstanceData],
+                    feat_shape: Tuple[int], batch_img_metas: List[dict]):
         """Get training targets for batch images.

         Args:
-            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each
-                image, shape (num_gt, 4).
-            gt_labels_list (list[Tensor]): Ground truth labels of each
-                box, shape (num_gt,).
-            gt_bboxes_3d_list (list[:obj:`CameraInstance3DBoxes`]): 3D
-                Ground truth bboxes of each image,
-                shape (num_gt, bbox_code_size).
-            gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of
-                each box, shape (num_gt,).
-            centers2d_list (list[Tensor]): Projected 3D centers onto 2D
-                image, shape (num_gt, 2).
-            depths_list (list[Tensor]): Depth of projected 3D centers onto 2D
-                image, each has shape (num_gt, 1).
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instance_3d. It usually includes ``bboxes``, ``labels``,
+                ``bboxes_3d``, ``labels_3d``, ``depths``, ``centers_2d`` and
+                attributes.
             feat_shape (tuple[int]): Feature map shape with value,
                 shape (B, _, H, W).
-            img_shape (tuple[int]): Image shape in [h, w] format.
-            input_metas (list[dict]): Meta information of each image, e.g.,
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                 image size, scaling factor, etc.

         Returns:
             tuple[Tensor, dict]: The Tensor value is the targets of
                 center heatmap, the dict has components below:

-            - base_centers2d_target (Tensor): Coords of each projected 3D box
-                center on image. shape (B * max_objs, 2), [dtype: int]
+            - base_centers_2d_target (Tensor): Coords of each projected
+                3D box center on image. shape (B * max_objs, 2),
+                [dtype: int]
             - labels3d (Tensor): Labels of each 3D box.
                 shape (N, )
             - reg_mask (Tensor): Mask of the existence of the 3D box.
@@ -504,14 +566,36 @@ class MonoFlexHead(AnchorFreeMono3DHead):
                 of each 3D box. shape (N, 3)
             - orientations_target (Tensor): Orientation (encoded local yaw)
                 target of each 3D box. shape (N, )
-            - offsets2d_target (Tensor): Offsets target of each projected
-                3D box. shape (N, 2)
+            - offsets_2d_target (Tensor): Offsets target of each projected
+                3D box. shape (N, 2)
             - dimensions_target (Tensor): Dimensions target of each 3D box.
                 shape (N, 3)
             - downsample_ratio (int): The stride of feature map.
         """
-        img_h, img_w = img_shape[:2]
+        gt_bboxes_list = [
+            gt_instances_3d.bboxes for gt_instances_3d in batch_gt_instances_3d
+        ]
+        gt_labels_list = [
+            gt_instances_3d.labels for gt_instances_3d in batch_gt_instances_3d
+        ]
+        gt_bboxes_3d_list = [
+            gt_instances_3d.bboxes_3d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
+        gt_labels_3d_list = [
+            gt_instances_3d.labels_3d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
+        centers_2d_list = [
+            gt_instances_3d.centers_2d
+            for gt_instances_3d in batch_gt_instances_3d
+        ]
+        depths_list = [
+            gt_instances_3d.depths for gt_instances_3d in batch_gt_instances_3d
+        ]
+
+        img_h, img_w = batch_img_metas[0]['pad_shape'][:2]
         batch_size, _, feat_h, feat_w = feat_shape

         width_ratio = float(feat_w / img_w)  # 1/4
@@ -523,16 +607,16 @@ class MonoFlexHead(AnchorFreeMono3DHead):
         if self.filter_outside_objs:
             filter_outside_objs(gt_bboxes_list, gt_labels_list,
                                 gt_bboxes_3d_list, gt_labels_3d_list,
-                                centers2d_list, input_metas)
+                                centers_2d_list, batch_img_metas)

-        # transform centers2d to base centers2d for regression and
+        # transform centers_2d to base centers_2d for regression and
         # heatmap generation.
-        # centers2d = int(base_centers2d) + offsets2d
-        base_centers2d_list, offsets2d_list, trunc_mask_list = \
-            handle_proj_objs(centers2d_list, gt_bboxes_list, input_metas)
+        # centers_2d = int(base_centers_2d) + offsets_2d
+        base_centers_2d_list, offsets_2d_list, trunc_mask_list = \
+            handle_proj_objs(centers_2d_list, gt_bboxes_list, batch_img_metas)

         keypoints2d_list, keypoints_mask_list, keypoints_depth_mask_list = \
-            get_keypoints(gt_bboxes_3d_list, centers2d_list, input_metas)
+            get_keypoints(gt_bboxes_3d_list, centers_2d_list, batch_img_metas)

         center_heatmap_target = gt_bboxes_list[-1].new_zeros(
             [batch_size, self.num_classes, feat_h, feat_w])
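
The base-center/offset decomposition used for the regression targets, `centers_2d = int(base_centers_2d) + offsets_2d` in the comment above, is easy to make concrete with a tiny sketch:

    import torch

    centers_2d = torch.tensor([[123.6, 45.2]])   # projected 3D center
    base_centers_2d = centers_2d.int()           # integer heatmap peak location
    offsets_2d = centers_2d - base_centers_2d    # sub-pixel regression target
    # centers_2d == base_centers_2d + offsets_2d holds by construction; for
    # truncated objects (trunc_mask above) handle_proj_objs picks the base
    # center differently, so offsets_2d can be large for them.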
@@ -542,11 +626,11 @@ class MonoFlexHead(AnchorFreeMono3DHead):
             gt_bboxes = gt_bboxes_list[batch_id] * width_ratio
             gt_labels = gt_labels_list[batch_id]

-            # project base centers2d from input image to feat map
-            gt_base_centers2d = base_centers2d_list[batch_id] * width_ratio
+            # project base centers_2d from input image to feat map
+            gt_base_centers_2d = base_centers_2d_list[batch_id] * width_ratio
             trunc_masks = trunc_mask_list[batch_id]

-            for j, base_center2d in enumerate(gt_base_centers2d):
+            for j, base_center2d in enumerate(gt_base_centers_2d):
                 if trunc_masks[j]:
                     # for outside objects, generate ellipse heatmap
                     base_center2d_x_int, base_center2d_y_int = \
@@ -579,40 +663,40 @@ class MonoFlexHead(AnchorFreeMono3DHead):
                     [base_center2d_x_int, base_center2d_y_int], radius)

         avg_factor = max(1, center_heatmap_target.eq(1).sum())
-        num_ctrs = [centers2d.shape[0] for centers2d in centers2d_list]
+        num_ctrs = [centers_2d.shape[0] for centers_2d in centers_2d_list]
         max_objs = max(num_ctrs)

         batch_indices = [
-            centers2d_list[0].new_full((num_ctrs[i], ), i)
+            centers_2d_list[0].new_full((num_ctrs[i], ), i)
             for i in range(batch_size)
         ]
         batch_indices = torch.cat(batch_indices, dim=0)
         reg_mask = torch.zeros(
             (batch_size, max_objs),
-            dtype=torch.bool).to(base_centers2d_list[0].device)
-        gt_bboxes_3d = input_metas['box_type_3d'].cat(gt_bboxes_3d_list)
-        gt_bboxes_3d = gt_bboxes_3d.to(base_centers2d_list[0].device)
+            dtype=torch.bool).to(base_centers_2d_list[0].device)
+        gt_bboxes_3d = batch_img_metas[0]['box_type_3d'].cat(gt_bboxes_3d_list)
+        gt_bboxes_3d = gt_bboxes_3d.to(base_centers_2d_list[0].device)

         # encode original local yaw to multibin format
         orienations_target = self.bbox_coder.encode(gt_bboxes_3d)

-        batch_base_centers2d = base_centers2d_list[0].new_zeros(
+        batch_base_centers_2d = base_centers_2d_list[0].new_zeros(
             (batch_size, max_objs, 2))

         for i in range(batch_size):
             reg_mask[i, :num_ctrs[i]] = 1
-            batch_base_centers2d[i, :num_ctrs[i]] = base_centers2d_list[i]
+            batch_base_centers_2d[i, :num_ctrs[i]] = base_centers_2d_list[i]

         flatten_reg_mask = reg_mask.flatten()

-        # transform base centers2d from input scale to output scale
-        batch_base_centers2d = batch_base_centers2d.view(-1, 2) * width_ratio
+        # transform base centers_2d from input scale to output scale
+        batch_base_centers_2d = batch_base_centers_2d.view(-1, 2) * width_ratio

         dimensions_target = gt_bboxes_3d.tensor[:, 3:6]
         labels_3d = torch.cat(gt_labels_3d_list)
         keypoints2d_target = torch.cat(keypoints2d_list)
         keypoints_mask = torch.cat(keypoints_mask_list)
         keypoints_depth_mask = torch.cat(keypoints_depth_mask_list)
-        offsets2d_target = torch.cat(offsets2d_list)
+        offsets_2d_target = torch.cat(offsets_2d_list)
         bboxes2d = torch.cat(gt_bboxes_list)

         # transform FCOS style bbox into [x1, y1, x2, y2] format.
@@ -621,7 +705,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
         depths = torch.cat(depths_list)

         target_labels = dict(
-            base_centers2d_target=batch_base_centers2d.int(),
+            base_centers_2d_target=batch_base_centers_2d.int(),
             labels3d=labels_3d,
             reg_mask=flatten_reg_mask,
             batch_indices=batch_indices,
@@ -631,24 +715,18 @@ class MonoFlexHead(AnchorFreeMono3DHead):
             keypoints_mask=keypoints_mask,
             keypoints_depth_mask=keypoints_depth_mask,
             orienations_target=orienations_target,
-            offsets2d_target=offsets2d_target,
+            offsets_2d_target=offsets_2d_target,
             dimensions_target=dimensions_target,
             downsample_ratio=1 / width_ratio)

         return center_heatmap_target, avg_factor, target_labels

     def loss(self,
-             cls_scores,
-             bbox_preds,
-             gt_bboxes,
-             gt_labels,
-             gt_bboxes_3d,
-             gt_labels_3d,
-             centers2d,
-             depths,
-             attr_labels,
-             input_metas,
-             gt_bboxes_ignore=None):
+             cls_scores: List[Tensor],
+             bbox_preds: List[Tensor],
+             batch_gt_instances_3d: List[InstanceData],
+             batch_img_metas: List[dict],
+             batch_gt_instances_ignore: Optional[List[InstanceData]] = None):
         """Compute loss of the head.

         Args:
@@ -657,48 +735,37 @@ class MonoFlexHead(AnchorFreeMono3DHead):
             bbox_preds (list[Tensor]): Box dims is a 4D-tensor, the channel
                 number is bbox_code_size.
                 shape (B, 7, H, W).
-            gt_bboxes (list[Tensor]): Ground truth bboxes for each image.
-                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
-            gt_labels (list[Tensor]): Class indices corresponding to each box.
-                shape (num_gts, ).
-            gt_bboxes_3d (list[:obj:`CameraInstance3DBoxes`]): 3D ground
-                truth bboxes. It is the flipped gt_bboxes.
-            gt_labels_3d (list[Tensor]): Same as gt_labels.
-            centers2d (list[Tensor]): 2D centers on the image.
-                shape (num_gts, 2).
-            depths (list[Tensor]): Depth ground truth.
-                shape (num_gts, ).
-            attr_labels (list[Tensor]): Attribute indices of each box.
-                In KITTI it is None.
-            input_metas (list[dict]): Meta information of each image, e.g.,
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instance_3d. It usually includes ``bboxes``, ``labels``,
+                ``bboxes_3d``, ``labels_3d``, ``depths``, ``centers_2d`` and
+                attributes.
+            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                 image size, scaling factor, etc.
-            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
-                boxes can be ignored when computing the loss.
-                Default: None.
+            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
+                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
+                data that is ignored during training and testing.
+                Defaults to None.

         Returns:
             dict[str, Tensor]: A dictionary of loss components.
         """
         assert len(cls_scores) == len(bbox_preds) == 1
-        assert attr_labels is None
-        assert gt_bboxes_ignore is None
+        assert batch_gt_instances_ignore is None
         center2d_heatmap = cls_scores[0]
         pred_reg = bbox_preds[0]

         center2d_heatmap_target, avg_factor, target_labels = \
-            self.get_targets(gt_bboxes, gt_labels, gt_bboxes_3d,
-                             gt_labels_3d, centers2d, depths,
+            self.get_targets(batch_gt_instances_3d,
                              center2d_heatmap.shape,
-                             input_metas[0]['pad_shape'],
-                             input_metas)
+                             batch_img_metas)

         preds = self.get_predictions(
             pred_reg=pred_reg,
             labels3d=target_labels['labels3d'],
-            centers2d=target_labels['base_centers2d_target'],
+            centers_2d=target_labels['base_centers_2d_target'],
             reg_mask=target_labels['reg_mask'],
             batch_indices=target_labels['batch_indices'],
-            input_metas=input_metas,
+            batch_img_metas=batch_img_metas,
             downsample_ratio=target_labels['downsample_ratio'])

         # heatmap loss
@@ -726,8 +793,8 @@ class MonoFlexHead(AnchorFreeMono3DHead):
                                   target_labels['dimensions_target'])

         # offsets for center heatmap
-        loss_offsets2d = self.loss_offsets2d(preds['offsets2d'],
-                                             target_labels['offsets2d_target'])
+        loss_offsets_2d = self.loss_offsets_2d(
+            preds['offsets_2d'], target_labels['offsets_2d_target'])

         # directly regressed depth loss with direct depth uncertainty loss
         direct_depth_weights = torch.exp(-preds['direct_depth_uncertainty'])
@@ -764,7 +831,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
             loss_keypoints=loss_keypoints,
             loss_dir=loss_dir,
             loss_dims=loss_dims,
-            loss_offsets2d=loss_offsets2d,
+            loss_offsets_2d=loss_offsets_2d,
             loss_direct_depth=loss_direct_depth,
             loss_keypoints_depth=loss_keypoints_depth,
             loss_combined_depth=loss_combined_depth)
...
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch

from mmdet3d.models.dense_heads import MonoFlexHead


class TestMonoFlexHead(TestCase):

    def test_monoflex_head_loss(self):
        """Tests MonoFlex head loss and inference."""

        input_metas = [dict(img_shape=(110, 110), pad_shape=(128, 128))]

        monoflex_head = MonoFlexHead(
            num_classes=3,
            in_channels=64,
            use_edge_fusion=True,
            edge_fusion_inds=[(1, 0)],
            edge_heatmap_ratio=1 / 8,
            stacked_convs=0,
            feat_channels=64,
            use_direction_classifier=False,
            diff_rad_by_sin=False,
            pred_attrs=False,
            pred_velo=False,
            dir_offset=0,
            strides=None,
            group_reg_dims=((4, ), (2, ), (20, ), (3, ), (3, ), (8, 8), (1, ),
                            (1, )),
            cls_branch=(256, ),
            reg_branch=((256, ), (256, ), (256, ), (256, ), (256, ), (256, ),
                        (256, ), (256, )),
            num_attrs=0,
            bbox_code_size=7,
            dir_branch=(),
            attr_branch=(),
            bbox_coder=dict(
                type='MonoFlexCoder',
                depth_mode='exp',
                base_depth=(26.494627, 16.05988),
                depth_range=[0.1, 100],
                combine_depth=True,
                uncertainty_range=[-10, 10],
                base_dims=((3.8840, 1.5261, 1.6286, 0.4259, 0.1367, 0.1022),
                           (0.8423, 1.7607, 0.6602, 0.2349, 0.1133, 0.1427),
                           (1.7635, 1.7372, 0.5968, 0.1766, 0.0948, 0.1242)),
                dims_mode='linear',
                multibin=True,
                num_dir_bins=4,
                bin_centers=[0, np.pi / 2, np.pi, -np.pi / 2],
                bin_margin=np.pi / 6,
                code_size=7),
            conv_bias=True,
            dcn_on_last_conv=False)

        # Monoflex head expects a single level of features per image
        feats = [torch.rand([1, 64, 32, 32], dtype=torch.float32)]

        # Test forward
        cls_score, out_reg = monoflex_head.forward(feats, input_metas)

        self.assertEqual(cls_score[0].shape, torch.Size([1, 3, 32, 32]),
                         'the shape of cls_score should be [1, 3, 32, 32]')
        self.assertEqual(out_reg[0].shape, torch.Size([1, 50, 32, 32]),
                         'the shape of out_reg should be [1, 50, 32, 32]')
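
The test exercises only the forward pass; inference via `get_results` additionally needs `cam2img` and `box_type_3d` in the metas. A sketch of how that call might be wired up (not part of the commit; the identity intrinsic matrix is a placeholder and `CameraInstance3DBoxes` is assumed as the box type):

    from mmdet3d.core.bbox import CameraInstance3DBoxes

    input_metas[0]['cam2img'] = torch.eye(4).tolist()  # placeholder intrinsics
    input_metas[0]['box_type_3d'] = CameraInstance3DBoxes
    results = monoflex_head.get_results(cls_score, out_reg, input_metas)
    # each entry is an InstanceData with bboxes_3d, scores_3d and labels_3d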