Commit d490f024 authored by ZCMax, committed by ChaimZhu

[Refactor] Refactor monoflex head and unittest

parent 98cc28e2
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
from mmcv.cnn import xavier_init
from mmcv.runner import force_fp32
from mmengine.config import ConfigDict
from mmengine.data import InstanceData
from torch import Tensor
from torch import nn as nn
from mmdet3d.core import Det3DDataSample
from mmdet3d.core.bbox.builder import build_bbox_coder
from mmdet3d.core.utils import get_ellip_gaussian_2D
from mmdet3d.models.builder import build_loss
from mmdet3d.models.model_utils import EdgeFusionModule
from mmdet3d.models.utils import (filter_outside_objs, get_edge_indices,
get_keypoints, handle_proj_objs)
@@ -63,7 +69,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
Default: dict(type='L1Loss', loss_weight=0.1).
loss_dims (dict, optional): Config of dimensions loss.
Default: dict(type='L1Loss', loss_weight=0.1).
loss_offsets2d (dict, optional): Config of offsets2d loss.
loss_offsets_2d (dict, optional): Config of offsets_2d loss.
Default: dict(type='L1Loss', loss_weight=0.1).
loss_direct_depth (dict, optional): Config of direct depth regression loss.
Default: dict(type='L1Loss', loss_weight=0.1).
@@ -81,27 +87,33 @@ class MonoFlexHead(AnchorFreeMono3DHead):
""" # noqa: E501
def __init__(self,
num_classes,
in_channels,
use_edge_fusion,
edge_fusion_inds,
edge_heatmap_ratio,
filter_outside_objs=True,
loss_cls=dict(type='GaussianFocalLoss', loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=0.1),
loss_dir=dict(type='MultiBinLoss', loss_weight=0.1),
loss_keypoints=dict(type='L1Loss', loss_weight=0.1),
loss_dims=dict(type='L1Loss', loss_weight=0.1),
loss_offsets2d=dict(type='L1Loss', loss_weight=0.1),
loss_direct_depth=dict(type='L1Loss', loss_weight=0.1),
loss_keypoints_depth=dict(type='L1Loss', loss_weight=0.1),
loss_combined_depth=dict(type='L1Loss', loss_weight=0.1),
loss_attr=None,
bbox_coder=dict(type='MonoFlexCoder', code_size=7),
norm_cfg=dict(type='BN'),
init_cfg=None,
init_bias=-2.19,
**kwargs):
num_classes: int,
in_channels: int,
use_edge_fusion: bool,
edge_fusion_inds: List[Tuple],
edge_heatmap_ratio: float,
filter_outside_objs: bool = True,
loss_cls: dict = dict(
type='mmdet.GaussianFocalLoss', loss_weight=1.0),
loss_bbox: dict = dict(type='mmdet.IoULoss', loss_weight=0.1),
loss_dir: dict = dict(type='MultiBinLoss', loss_weight=0.1),
loss_keypoints: dict = dict(
type='mmdet.L1Loss', loss_weight=0.1),
loss_dims: dict = dict(type='mmdet.L1Loss', loss_weight=0.1),
loss_offsets_2d: dict = dict(
type='mmdet.L1Loss', loss_weight=0.1),
loss_direct_depth: dict = dict(
type='mmdet.L1Loss', loss_weight=0.1),
loss_keypoints_depth: dict = dict(
type='mmdet.L1Loss', loss_weight=0.1),
loss_combined_depth: dict = dict(
type='mmdet.L1Loss', loss_weight=0.1),
loss_attr: Optional[dict] = None,
bbox_coder: dict = dict(type='MonoFlexCoder', code_size=7),
norm_cfg: Union[ConfigDict, dict] = dict(type='BN'),
init_cfg: Optional[Union[ConfigDict, dict]] = None,
init_bias: float = -2.19,
**kwargs) -> None:
self.use_edge_fusion = use_edge_fusion
self.edge_fusion_inds = edge_fusion_inds
super().__init__(
@@ -117,13 +129,13 @@ class MonoFlexHead(AnchorFreeMono3DHead):
self.filter_outside_objs = filter_outside_objs
self.edge_heatmap_ratio = edge_heatmap_ratio
self.init_bias = init_bias
self.loss_dir = build_loss(loss_dir)
self.loss_keypoints = build_loss(loss_keypoints)
self.loss_dims = build_loss(loss_dims)
self.loss_offsets2d = build_loss(loss_offsets2d)
self.loss_direct_depth = build_loss(loss_direct_depth)
self.loss_keypoints_depth = build_loss(loss_keypoints_depth)
self.loss_combined_depth = build_loss(loss_combined_depth)
self.loss_dir = MODELS.build(loss_dir)
self.loss_keypoints = MODELS.build(loss_keypoints)
self.loss_dims = MODELS.build(loss_dims)
self.loss_offsets_2d = MODELS.build(loss_offsets_2d)
self.loss_direct_depth = MODELS.build(loss_direct_depth)
self.loss_keypoints_depth = MODELS.build(loss_keypoints_depth)
self.loss_combined_depth = MODELS.build(loss_combined_depth)
self.bbox_coder = build_bbox_coder(bbox_coder)
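For context, the move from build_loss to MODELS.build follows the mmengine registry pattern: a config dict whose type names a registered class (here with the mmdet. scope prefix) is instantiated through the registry. A minimal sketch, assuming MODELS is the registry exposed by mmdet3d.registry:

from mmdet3d.registry import MODELS

# builds an nn.Module computing a weighted L1 loss, equivalent to the
# former build_loss(dict(type='L1Loss', loss_weight=0.1)) call
loss_dims = MODELS.build(dict(type='mmdet.L1Loss', loss_weight=0.1))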
def _init_edge_module(self):
@@ -185,13 +197,15 @@ class MonoFlexHead(AnchorFreeMono3DHead):
if self.use_edge_fusion:
self._init_edge_module()
def forward_train(self, x, input_metas, gt_bboxes, gt_labels, gt_bboxes_3d,
gt_labels_3d, centers2d, depths, attr_labels,
gt_bboxes_ignore, proposal_cfg, **kwargs):
def forward_train(self,
x: List[Tensor],
batch_data_samples: List[Det3DDataSample],
proposal_cfg: Optional[ConfigDict] = None,
**kwargs):
"""
Args:
x (list[Tensor]): Features from FPN.
input_metas (list[dict]): Meta information of each image, e.g.,
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
shape (num_gts, 4).
@@ -201,7 +215,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
shape (num_gts, self.bbox_code_size).
gt_labels_3d (list[Tensor]): 3D ground truth labels of each box,
shape (num_gts,).
centers2d (list[Tensor]): Projected 3D center of each box,
centers_2d (list[Tensor]): Projected 3D center of each box,
shape (num_gts, 2).
depths (list[Tensor]): Depth of projected 3D center of each box,
shape (num_gts,).
@@ -216,29 +230,75 @@ class MonoFlexHead(AnchorFreeMono3DHead):
losses: (dict[str, Tensor]): A dictionary of loss components.
proposal_list (list[Tensor]): Proposals of each image.
"""
outs = self(x, input_metas)
if gt_labels is None:
loss_inputs = outs + (gt_bboxes, gt_bboxes_3d, centers2d, depths,
attr_labels, input_metas)
"""
Args:
x (list[Tensor]): Features from FPN.
batch_data_samples (list[:obj:`Det3DDataSample`]): Each item
contains the meta information of each image and corresponding
annotations.
proposal_cfg (mmengine.Config, optional): Test / postprocessing
configuration; if None, test_cfg is used. Defaults to None.
Returns:
tuple or Tensor: When `proposal_cfg` is None, the detector is a \
normal one-stage detector, and the return value is the losses.
- losses: (dict[str, Tensor]): A dictionary of loss components.
When `proposal_cfg` is not None, the head is used as an
`rpn_head`, and the return value is a tuple that contains:
- losses: (dict[str, Tensor]): A dictionary of loss components.
- results_list (list[:obj:`InstanceData`]): Detection
results of each image after the post process.
Each item usually contains the following keys.
- scores (Tensor): Classification scores, has a shape
(num_instances, ).
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (:obj:`BaseInstance3DBoxes`): Contains a tensor
with shape (num_instances, C), where the last dimension C of a
3D box is (x, y, z, x_size, y_size, z_size, yaw, ...) and
C >= 7. C = 7 for KITTI and C = 9 for nuScenes, with 2 extra
dims for velocity.
"""
batch_gt_instances_3d = []
batch_gt_instances_ignore = []
batch_img_metas = []
for data_sample in batch_data_samples:
batch_img_metas.append(data_sample.metainfo)
batch_gt_instances_3d.append(data_sample.gt_instances_3d)
if 'ignored_instances' in data_sample:
batch_gt_instances_ignore.append(data_sample.ignored_instances)
else:
loss_inputs = outs + (gt_bboxes, gt_labels, gt_bboxes_3d,
gt_labels_3d, centers2d, depths, attr_labels,
input_metas)
losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
batch_gt_instances_ignore.append(None)
# monoflex head needs img_metas for feature extraction
outs = self(x, batch_img_metas)
loss_inputs = outs + (batch_gt_instances_3d, batch_img_metas,
batch_gt_instances_ignore)
losses = self.loss(*loss_inputs)
if proposal_cfg is None:
return losses
else:
proposal_list = self.get_bboxes(
*outs, input_metas, cfg=proposal_cfg)
return losses, proposal_list
batch_img_metas = [
data_sample.metainfo for data_sample in batch_data_samples
]
results_list = self.get_results(
*outs, batch_img_metas=batch_img_metas, cfg=proposal_cfg)
return losses, results_list
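The refactored entry point consumes Det3DDataSample objects rather than loose tensors; each sample carries image meta info plus a gt_instances_3d field. A minimal sketch of packing one sample for forward_train, assuming the attribute names used in the loop above (the values are placeholders):

import torch
from mmengine.data import InstanceData
from mmdet3d.core import Det3DDataSample

# the head reads meta info via data_sample.metainfo
data_sample = Det3DDataSample(
    metainfo=dict(img_shape=(110, 110), pad_shape=(128, 128)))

# per-instance annotations, gathered into batch_gt_instances_3d above
gt_instances_3d = InstanceData()
gt_instances_3d.labels_3d = torch.tensor([0])
gt_instances_3d.centers_2d = torch.tensor([[48., 45.]])
gt_instances_3d.depths = torch.tensor([20.0])
data_sample.gt_instances_3d = gt_instances_3d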
def forward(self, feats, input_metas):
def forward(self, feats: List[Tensor], batch_img_metas: List[dict]):
"""Forward features from the upstream network.
Args:
feats (list[Tensor]): Features from the upstream network, each is
a 4D-tensor.
input_metas (list[dict]): Meta information of each image, e.g.,
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
Returns:
@@ -250,21 +310,21 @@ class MonoFlexHead(AnchorFreeMono3DHead):
level, each is a 4D-tensor, the channel number is
num_points * bbox_code_size.
"""
mlvl_input_metas = [input_metas for i in range(len(feats))]
return multi_apply(self.forward_single, feats, mlvl_input_metas)
mlvl_batch_img_metas = [batch_img_metas for i in range(len(feats))]
return multi_apply(self.forward_single, feats, mlvl_batch_img_metas)
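Here multi_apply maps forward_single over the feature levels and regroups the per-level output tuples into per-output lists; it is roughly the following (matching the mmdet helper):

from functools import partial

def multi_apply(func, *args, **kwargs):
    # apply func to each parallel slice of args, then transpose the
    # per-call tuples: [(a1, b1), (a2, b2)] -> ([a1, a2], [b1, b2])
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))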
def forward_single(self, x, input_metas):
def forward_single(self, x: Tensor, batch_img_metas: List[dict]):
"""Forward features of a single scale level.
Args:
x (Tensor): Feature maps from a specific FPN feature level.
input_metas (list[dict]): Meta information of each image, e.g.,
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
Returns:
tuple: Scores for each class, bbox predictions.
"""
img_h, img_w = input_metas[0]['pad_shape'][:2]
img_h, img_w = batch_img_metas[0]['pad_shape'][:2]
batch_size, _, feat_h, feat_w = x.shape
downsample_ratio = img_h / feat_h
@@ -275,7 +335,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
if self.use_edge_fusion:
# calculate the edge indices for the batch data
edge_indices_list = get_edge_indices(
input_metas, downsample_ratio, device=x.device)
batch_img_metas, downsample_ratio, device=x.device)
edge_lens = [
edge_indices.shape[0] for edge_indices in edge_indices_list
]
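# Note: the edge indices enumerate pixel locations along the image
# boundary at feature scale; the EdgeFusionModule gathers boundary
# features at these locations and fuses them back into the
# classification and selected regression branches (see
# edge_fusion_inds), which mainly helps objects truncated by the
# image border.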
@@ -313,13 +373,15 @@ class MonoFlexHead(AnchorFreeMono3DHead):
return cls_score, bbox_pred
def get_bboxes(self, cls_scores, bbox_preds, input_metas):
@force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def get_results(self, cls_scores: List[Tensor], bbox_preds: List[Tensor],
batch_img_metas: List[dict]):
"""Generate bboxes from bbox head predictions.
Args:
cls_scores (list[Tensor]): Box scores for each scale level.
bbox_preds (list[Tensor]): Box regression for each scale.
input_metas (list[dict]): Meta information of each image, e.g.,
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
rescale (bool): If True, return boxes in original image space.
Returns:
@@ -329,18 +391,18 @@ class MonoFlexHead(AnchorFreeMono3DHead):
assert len(cls_scores) == len(bbox_preds) == 1
cam2imgs = torch.stack([
cls_scores[0].new_tensor(input_meta['cam2img'])
for input_meta in input_metas
for input_meta in batch_img_metas
])
batch_bboxes, batch_scores, batch_topk_labels = self.decode_heatmap(
cls_scores[0],
bbox_preds[0],
input_metas,
batch_img_metas,
cam2imgs=cam2imgs,
topk=100,
kernel=3)
result_list = []
for img_id in range(len(input_metas)):
for img_id in range(len(batch_img_metas)):
bboxes = batch_bboxes[img_id]
scores = batch_scores[img_id]
@@ -351,20 +413,29 @@ class MonoFlexHead(AnchorFreeMono3DHead):
scores = scores[keep_idx]
labels = labels[keep_idx]
bboxes = input_metas[img_id]['box_type_3d'](
bboxes = batch_img_metas[img_id]['box_type_3d'](
bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5))
attrs = None
result_list.append((bboxes, scores, labels, attrs))
results = InstanceData()
results.bboxes_3d = bboxes
results.scores_3d = scores
results.labels_3d = labels
if attrs is not None:
results.attr_labels = attrs
result_list.append(results)
return result_list
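Packing each image's detections into an InstanceData (instead of the former (bboxes, scores, labels, attrs) tuple) keeps the fields named and length-aligned and lets them be filtered together; a quick usage sketch:

import torch
from mmengine.data import InstanceData

results = InstanceData()
results.scores_3d = torch.tensor([0.9, 0.3])
results.labels_3d = torch.tensor([1, 0])
keep = results.scores_3d > 0.5  # a boolean index selects from every field
print(results[keep].labels_3d)  # tensor([1])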
def decode_heatmap(self,
cls_score,
reg_pred,
input_metas,
cam2imgs,
topk=100,
kernel=3):
cls_score: Tensor,
reg_pred: Tensor,
batch_img_metas: List[dict],
cam2imgs: Tensor,
topk: int = 100,
kernel: int = 3):
"""Transform outputs into detections raw bbox predictions.
Args:
@@ -372,7 +443,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
shape (B, num_classes, H, W).
reg_pred (Tensor): Box regression map.
shape (B, channel, H, W).
input_metas (List[dict]): Meta information of each image, e.g.,
batch_img_metas (List[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
cam2imgs (Tensor): Camera intrinsic matrix.
shape (N, 4, 4)
@@ -391,7 +462,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
- batch_topk_labels (Tensor): Categories of each 3D box.
shape (B, k)
"""
img_h, img_w = input_metas[0]['pad_shape'][:2]
img_h, img_w = batch_img_metas[0]['pad_shape'][:2]
batch_size, _, feat_h, feat_w = cls_score.shape
downsample_ratio = img_h / feat_h
@@ -404,13 +475,13 @@ class MonoFlexHead(AnchorFreeMono3DHead):
regression = transpose_and_gather_feat(reg_pred, batch_index)
regression = regression.view(-1, 8)
pred_base_centers2d = torch.cat(
pred_base_centers_2d = torch.cat(
[topk_xs.view(-1, 1),
topk_ys.view(-1, 1).float()], dim=1)
preds = self.bbox_coder.decode(regression, batch_topk_labels,
downsample_ratio, cam2imgs)
pred_locations = self.bbox_coder.decode_location(
pred_base_centers2d, preds['offsets2d'], preds['combined_depth'],
pred_base_centers_2d, preds['offsets_2d'], preds['combined_depth'],
cam2imgs, downsample_ratio)
pred_yaws = self.bbox_coder.decode_orientation(
preds['orientations']).unsqueeze(-1)
@@ -419,8 +490,8 @@ class MonoFlexHead(AnchorFreeMono3DHead):
batch_bboxes = batch_bboxes.view(batch_size, -1, self.bbox_code_size)
return batch_bboxes, batch_scores, batch_topk_labels
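decode_location lifts each base center (plus its offset) to a 3D location using the combined depth and the camera intrinsics. A sketch of the underlying pinhole back-projection, ignoring the downsample_ratio handling done inside the coder:

import torch

def backproject(centers_2d, depths, cam2imgs):
    # centers_2d: (N, 2) pixel coords on the input image; depths: (N,);
    # cam2imgs: (N, 4, 4) intrinsics with fx, fy on the diagonal and the
    # principal point (cx, cy) in the third column
    fx, fy = cam2imgs[:, 0, 0], cam2imgs[:, 1, 1]
    cx, cy = cam2imgs[:, 0, 2], cam2imgs[:, 1, 2]
    x = (centers_2d[:, 0] - cx) * depths / fx
    y = (centers_2d[:, 1] - cy) * depths / fy
    return torch.stack([x, y, depths], dim=-1)  # (N, 3) camera-frame points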
def get_predictions(self, pred_reg, labels3d, centers2d, reg_mask,
batch_indices, input_metas, downsample_ratio):
def get_predictions(self, pred_reg, labels3d, centers_2d, reg_mask,
batch_indices, batch_img_metas, downsample_ratio):
"""Prepare predictions for computing loss.
Args:
@@ -428,14 +499,14 @@ class MonoFlexHead(AnchorFreeMono3DHead):
shape (B, channel, H, W).
labels3d (Tensor): Labels of each 3D box.
shape (B * max_objs, )
centers2d (Tensor): Coords of each projected 3D box
centers_2d (Tensor): Coords of each projected 3D box
center on image. shape (N, 2)
reg_mask (Tensor): Indexes of the existence of the 3D box.
shape (B * max_objs, )
batch_indices (Tensor): Batch indices of the 3D box.
shape (N, 3)
input_metas (list[dict]): Meta information of each image,
e.g., image size, scaling factor, etc.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
downsample_ratio (int): The stride of feature map.
Returns:
@@ -444,50 +515,41 @@ class MonoFlexHead(AnchorFreeMono3DHead):
batch, channel = pred_reg.shape[0], pred_reg.shape[1]
w = pred_reg.shape[3]
cam2imgs = torch.stack([
centers2d.new_tensor(input_meta['cam2img'])
for input_meta in input_metas
centers_2d.new_tensor(img_meta['cam2img'])
for img_meta in batch_img_metas
])
# (batch_size, 4, 4) -> (N, 4, 4)
cam2imgs = cam2imgs[batch_indices, :, :]
centers2d_inds = centers2d[:, 1] * w + centers2d[:, 0]
centers2d_inds = centers2d_inds.view(batch, -1)
pred_regression = transpose_and_gather_feat(pred_reg, centers2d_inds)
centers_2d_inds = centers_2d[:, 1] * w + centers_2d[:, 0]
centers_2d_inds = centers_2d_inds.view(batch, -1)
pred_regression = transpose_and_gather_feat(pred_reg, centers_2d_inds)
pred_regression_pois = pred_regression.view(-1, channel)[reg_mask]
preds = self.bbox_coder.decode(pred_regression_pois, labels3d,
downsample_ratio, cam2imgs)
return preds
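transpose_and_gather_feat fetches the regression vectors at the flattened center indices computed above (y * w + x); it is roughly:

import torch

def transpose_and_gather_feat(feat, ind):
    # feat: (B, C, H, W) -> (B, H*W, C); ind: (B, K) long tensor of
    # flattened spatial indices; output: (B, K, C)
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), feat.size(2))
    return feat.gather(1, ind)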
def get_targets(self, gt_bboxes_list, gt_labels_list, gt_bboxes_3d_list,
gt_labels_3d_list, centers2d_list, depths_list, feat_shape,
img_shape, input_metas):
def get_targets(self, batch_gt_instances_3d: List[InstanceData],
feat_shape: Tuple[int], batch_img_metas: List[dict]):
"""Get training targets for batch images.
Args:
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each
image, shape (num_gt, 4).
gt_labels_list (list[Tensor]): Ground truth labels of each
box, shape (num_gt,).
gt_bboxes_3d_list (list[:obj:`CameraInstance3DBoxes`]): 3D
Ground truth bboxes of each image,
shape (num_gt, bbox_code_size).
gt_labels_3d_list (list[Tensor]): 3D Ground truth labels of
each box, shape (num_gt,).
centers2d_list (list[Tensor]): Projected 3D centers onto 2D
image, shape (num_gt, 2).
depths_list (list[Tensor]): Depth of projected 3D centers onto 2D
image, each has shape (num_gt, 1).
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances_3d. It usually includes ``bboxes``, ``labels``,
``bboxes_3d``, ``labels_3d``, ``depths``, ``centers_2d`` and
attributes.
feat_shape (tuple[int]): Feature map shape, in (B, _, H, W) format.
img_shape (tuple[int]): Image shape in [h, w] format.
input_metas (list[dict]): Meta information of each image, e.g.,
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
Returns:
tuple: The first element is the center heatmap target, the
second the averaging factor for the heatmap loss, and the
third a dict with the components below:
- base_centers2d_target (Tensor): Coords of each projected 3D box
center on image. shape (B * max_objs, 2), [dtype: int]
- base_centers_2d_target (Tensor): Coords of each projected
3D box center on image. shape (B * max_objs, 2),
[dtype: int]
- labels3d (Tensor): Labels of each 3D box.
shape (N, )
- reg_mask (Tensor): Mask of the existence of the 3D box.
@@ -504,14 +566,36 @@ class MonoFlexHead(AnchorFreeMono3DHead):
of each 3D box. shape (N, 3)
- orientations_target (Tensor): Orientation (encoded local yaw)
target of each 3D box. shape (N, )
- offsets2d_target (Tensor): Offsets target of each projected
- offsets_2d_target (Tensor): Offsets target of each projected
3D box. shape (N, 2)
- dimensions_target (Tensor): Dimensions target of each 3D box.
shape (N, 3)
- downsample_ratio (int): The stride of feature map.
"""
img_h, img_w = img_shape[:2]
gt_bboxes_list = [
gt_instances_3d.bboxes for gt_instances_3d in batch_gt_instances_3d
]
gt_labels_list = [
gt_instances_3d.labels for gt_instances_3d in batch_gt_instances_3d
]
gt_bboxes_3d_list = [
gt_instances_3d.bboxes_3d
for gt_instances_3d in batch_gt_instances_3d
]
gt_labels_3d_list = [
gt_instances_3d.labels_3d
for gt_instances_3d in batch_gt_instances_3d
]
centers_2d_list = [
gt_instances_3d.centers_2d
for gt_instances_3d in batch_gt_instances_3d
]
depths_list = [
gt_instances_3d.depths for gt_instances_3d in batch_gt_instances_3d
]
img_h, img_w = batch_img_metas[0]['pad_shape'][:2]
batch_size, _, feat_h, feat_w = feat_shape
width_ratio = float(feat_w / img_w) # 1/4
@@ -523,16 +607,16 @@ class MonoFlexHead(AnchorFreeMono3DHead):
if self.filter_outside_objs:
filter_outside_objs(gt_bboxes_list, gt_labels_list,
gt_bboxes_3d_list, gt_labels_3d_list,
centers2d_list, input_metas)
centers_2d_list, batch_img_metas)
# transform centers2d to base centers2d for regression and
# transform centers_2d to base centers_2d for regression and
# heatmap generation.
# centers2d = int(base_centers2d) + offsets2d
base_centers2d_list, offsets2d_list, trunc_mask_list = \
handle_proj_objs(centers2d_list, gt_bboxes_list, input_metas)
# centers_2d = int(base_centers_2d) + offsets_2d
base_centers_2d_list, offsets_2d_list, trunc_mask_list = \
handle_proj_objs(centers_2d_list, gt_bboxes_list, batch_img_metas)
keypoints2d_list, keypoints_mask_list, keypoints_depth_mask_list = \
get_keypoints(gt_bboxes_3d_list, centers2d_list, input_metas)
get_keypoints(gt_bboxes_3d_list, centers_2d_list, batch_img_metas)
center_heatmap_target = gt_bboxes_list[-1].new_zeros(
[batch_size, self.num_classes, feat_h, feat_w])
@@ -542,11 +626,11 @@ class MonoFlexHead(AnchorFreeMono3DHead):
gt_bboxes = gt_bboxes_list[batch_id] * width_ratio
gt_labels = gt_labels_list[batch_id]
# project base centers2d from input image to feat map
gt_base_centers2d = base_centers2d_list[batch_id] * width_ratio
# project base centers_2d from input image to feat map
gt_base_centers_2d = base_centers_2d_list[batch_id] * width_ratio
trunc_masks = trunc_mask_list[batch_id]
for j, base_center2d in enumerate(gt_base_centers2d):
for j, base_center2d in enumerate(gt_base_centers_2d):
if trunc_masks[j]:
# for outside objects, generate ellipse heatmap
base_center2d_x_int, base_center2d_y_int = \
@@ -579,40 +663,40 @@ class MonoFlexHead(AnchorFreeMono3DHead):
[base_center2d_x_int, base_center2d_y_int], radius)
avg_factor = max(1, center_heatmap_target.eq(1).sum())
num_ctrs = [centers2d.shape[0] for centers2d in centers2d_list]
num_ctrs = [centers_2d.shape[0] for centers_2d in centers_2d_list]
max_objs = max(num_ctrs)
batch_indices = [
centers2d_list[0].new_full((num_ctrs[i], ), i)
centers_2d_list[0].new_full((num_ctrs[i], ), i)
for i in range(batch_size)
]
batch_indices = torch.cat(batch_indices, dim=0)
reg_mask = torch.zeros(
(batch_size, max_objs),
dtype=torch.bool).to(base_centers2d_list[0].device)
gt_bboxes_3d = input_metas['box_type_3d'].cat(gt_bboxes_3d_list)
gt_bboxes_3d = gt_bboxes_3d.to(base_centers2d_list[0].device)
dtype=torch.bool).to(base_centers_2d_list[0].device)
gt_bboxes_3d = batch_img_metas[0]['box_type_3d'].cat(gt_bboxes_3d_list)
gt_bboxes_3d = gt_bboxes_3d.to(base_centers_2d_list[0].device)
# encode original local yaw to multibin format
orientations_target = self.bbox_coder.encode(gt_bboxes_3d)
batch_base_centers2d = base_centers2d_list[0].new_zeros(
batch_base_centers_2d = base_centers_2d_list[0].new_zeros(
(batch_size, max_objs, 2))
for i in range(batch_size):
reg_mask[i, :num_ctrs[i]] = 1
batch_base_centers2d[i, :num_ctrs[i]] = base_centers2d_list[i]
batch_base_centers_2d[i, :num_ctrs[i]] = base_centers_2d_list[i]
flatten_reg_mask = reg_mask.flatten()
# transform base centers2d from input scale to output scale
batch_base_centers2d = batch_base_centers2d.view(-1, 2) * width_ratio
# transform base centers_2d from input scale to output scale
batch_base_centers_2d = batch_base_centers_2d.view(-1, 2) * width_ratio
dimensions_target = gt_bboxes_3d.tensor[:, 3:6]
labels_3d = torch.cat(gt_labels_3d_list)
keypoints2d_target = torch.cat(keypoints2d_list)
keypoints_mask = torch.cat(keypoints_mask_list)
keypoints_depth_mask = torch.cat(keypoints_depth_mask_list)
offsets2d_target = torch.cat(offsets2d_list)
offsets_2d_target = torch.cat(offsets_2d_list)
bboxes2d = torch.cat(gt_bboxes_list)
# transform FCOS style bbox into [x1, y1, x2, y2] format.
@@ -621,7 +705,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
depths = torch.cat(depths_list)
target_labels = dict(
base_centers2d_target=batch_base_centers2d.int(),
base_centers_2d_target=batch_base_centers_2d.int(),
labels3d=labels_3d,
reg_mask=flatten_reg_mask,
batch_indices=batch_indices,
@@ -631,24 +715,18 @@ class MonoFlexHead(AnchorFreeMono3DHead):
keypoints_mask=keypoints_mask,
keypoints_depth_mask=keypoints_depth_mask,
orientations_target=orientations_target,
offsets2d_target=offsets2d_target,
offsets_2d_target=offsets_2d_target,
dimensions_target=dimensions_target,
downsample_ratio=1 / width_ratio)
return center_heatmap_target, avg_factor, target_labels
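For reference, the heatmap target stamps a Gaussian peak at every base center on the feature map, keeping the elementwise maximum so nearby objects do not erase each other; get_ellip_gaussian_2D generalizes this to two radii for truncated objects. A minimal circular sketch:

import torch

def draw_gaussian(heatmap, center, radius):
    # heatmap: (H, W) float tensor updated in place; center: (x, y) ints
    h, w = heatmap.shape
    sigma = (2 * radius + 1) / 6
    ys = torch.arange(h, dtype=torch.float32).view(-1, 1)
    xs = torch.arange(w, dtype=torch.float32).view(1, -1)
    gauss = torch.exp(-((xs - center[0]) ** 2 + (ys - center[1]) ** 2)
                      / (2 * sigma ** 2))
    torch.maximum(heatmap, gauss, out=heatmap)
    return heatmap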
def loss(self,
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
gt_bboxes_3d,
gt_labels_3d,
centers2d,
depths,
attr_labels,
input_metas,
gt_bboxes_ignore=None):
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
batch_gt_instances_3d: List[InstanceData],
batch_img_metas: List[dict],
batch_gt_instances_ignore: Optional[List[InstanceData]] = None):
"""Compute loss of the head.
Args:
@@ -657,48 +735,37 @@ class MonoFlexHead(AnchorFreeMono3DHead):
bbox_preds (list[Tensor]): Box regression maps, each a
4D-tensor whose channel number is bbox_code_size.
shape (B, 7, H, W).
gt_bboxes (list[Tensor]): Ground truth bboxes for each image.
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): Class indices corresponding to each box.
shape (num_gts, ).
gt_bboxes_3d (list[:obj:`CameraInstance3DBoxes`]): 3D boxes ground
truth. it is the flipped gt_bboxes
gt_labels_3d (list[Tensor]): Same as gt_labels.
centers2d (list[Tensor]): 2D centers on the image.
shape (num_gts, 2).
depths (list[Tensor]): Depth ground truth.
shape (num_gts, ).
attr_labels (list[Tensor]): Attributes indices of each box.
In kitti it's None.
input_metas (list[dict]): Meta information of each image, e.g.,
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances_3d. It usually includes ``bboxes``, ``labels``,
``bboxes_3d``, ``labels_3d``, ``depths``, ``centers_2d`` and
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
boxes can be ignored when computing the loss.
Default: None.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert len(cls_scores) == len(bbox_preds) == 1
assert attr_labels is None
assert gt_bboxes_ignore is None
assert batch_gt_instances_ignore is None
center2d_heatmap = cls_scores[0]
pred_reg = bbox_preds[0]
center2d_heatmap_target, avg_factor, target_labels = \
self.get_targets(gt_bboxes, gt_labels, gt_bboxes_3d,
gt_labels_3d, centers2d, depths,
self.get_targets(batch_gt_instances_3d,
center2d_heatmap.shape,
input_metas[0]['pad_shape'],
input_metas)
batch_img_metas)
preds = self.get_predictions(
pred_reg=pred_reg,
labels3d=target_labels['labels3d'],
centers2d=target_labels['base_centers2d_target'],
centers_2d=target_labels['base_centers_2d_target'],
reg_mask=target_labels['reg_mask'],
batch_indices=target_labels['batch_indices'],
input_metas=input_metas,
batch_img_metas=batch_img_metas,
downsample_ratio=target_labels['downsample_ratio'])
# heatmap loss
@@ -726,8 +793,8 @@ class MonoFlexHead(AnchorFreeMono3DHead):
target_labels['dimensions_target'])
# offsets for center heatmap
loss_offsets2d = self.loss_offsets2d(preds['offsets2d'],
target_labels['offsets2d_target'])
loss_offsets_2d = self.loss_offsets_2d(
preds['offsets_2d'], target_labels['offsets_2d_target'])
# directly regressed depth loss, weighted by the direct depth uncertainty
direct_depth_weights = torch.exp(-preds['direct_depth_uncertainty'])
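# Schematically, the uncertainty-weighted depth term follows the usual
# aleatoric formulation (a sketch; the exact elided computation may differ):
#   loss_direct_depth ~ |depth_pred - depth_gt| * exp(-u) + u
# so a large predicted uncertainty u down-weights the regression error
# but pays a penalty, preventing the weights from collapsing to zero.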
@@ -764,7 +831,7 @@ class MonoFlexHead(AnchorFreeMono3DHead):
loss_keypoints=loss_keypoints,
loss_dir=loss_dir,
loss_dims=loss_dims,
loss_offsets2d=loss_offsets2d,
loss_offsets_2d=loss_offsets_2d,
loss_direct_depth=loss_direct_depth,
loss_keypoints_depth=loss_keypoints_depth,
loss_combined_depth=loss_combined_depth)
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmdet3d.models.dense_heads import MonoFlexHead
class TestMonoFlexHead(TestCase):
def test_monoflex_head_loss(self):
"""Tests MonoFlex head loss and inference."""
input_metas = [dict(img_shape=(110, 110), pad_shape=(128, 128))]
monoflex_head = MonoFlexHead(
num_classes=3,
in_channels=64,
use_edge_fusion=True,
edge_fusion_inds=[(1, 0)],
edge_heatmap_ratio=1 / 8,
stacked_convs=0,
feat_channels=64,
use_direction_classifier=False,
diff_rad_by_sin=False,
pred_attrs=False,
pred_velo=False,
dir_offset=0,
strides=None,
group_reg_dims=((4, ), (2, ), (20, ), (3, ), (3, ), (8, 8), (1, ),
(1, )),
cls_branch=(256, ),
reg_branch=((256, ), (256, ), (256, ), (256, ), (256, ), (256, ),
(256, ), (256, )),
num_attrs=0,
bbox_code_size=7,
dir_branch=(),
attr_branch=(),
bbox_coder=dict(
type='MonoFlexCoder',
depth_mode='exp',
base_depth=(26.494627, 16.05988),
depth_range=[0.1, 100],
combine_depth=True,
uncertainty_range=[-10, 10],
base_dims=((3.8840, 1.5261, 1.6286, 0.4259, 0.1367, 0.1022),
(0.8423, 1.7607, 0.6602, 0.2349, 0.1133, 0.1427),
(1.7635, 1.7372, 0.5968, 0.1766, 0.0948, 0.1242)),
dims_mode='linear',
multibin=True,
num_dir_bins=4,
bin_centers=[0, np.pi / 2, np.pi, -np.pi / 2],
bin_margin=np.pi / 6,
code_size=7),
conv_bias=True,
dcn_on_last_conv=False)
# Monoflex head expects a single level of features per image
feats = [torch.rand([1, 64, 32, 32], dtype=torch.float32)]
# Test forward
cls_score, out_reg = monoflex_head.forward(feats, input_metas)
self.assertEqual(cls_score[0].shape, torch.Size([1, 3, 32, 32]),
'the shape of cls_score should be [1, 3, 32, 32]')
self.assertEqual(out_reg[0].shape, torch.Size([1, 50, 32, 32]),
'the shape of out_reg should be [1, 50, 32, 32]')
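Note that the 50 regression channels asserted above follow directly from the configured group_reg_dims: 4 + 2 + 20 + 3 + 3 + (8 + 8) + 1 + 1 = 50, one output channel per regression target across the eight groups.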