Commit e2ead7e9 authored by ZCMax, committed by ChaimZhu
Browse files

Refactor SMOKEHEAD

parent 5db1ead3
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch import torch
from mmcv.runner import force_fp32
from mmengine.config import ConfigDict
from torch import Tensor
from torch.nn import functional as F from torch.nn import functional as F
from mmdet3d.registry import MODELS from mmdet3d.registry import MODELS
...@@ -30,8 +35,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -30,8 +35,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
regression heatmap channels. regression heatmap channels.
ori_channel (list[int]): indices of orientation offset pred in ori_channel (list[int]): indices of orientation offset pred in
regression heatmap channels. regression heatmap channels.
bbox_coder (:obj:`CameraInstance3DBoxes`): Bbox coder bbox_coder (dict): Bbox coder for encoding and decoding boxes.
for encoding and decoding boxes.
loss_cls (dict, optional): Config of classification loss. loss_cls (dict, optional): Config of classification loss.
Default: loss_cls=dict(type='GaussionFocalLoss', loss_weight=1.0). Default: loss_cls=dict(type='GaussionFocalLoss', loss_weight=1.0).
loss_bbox (dict, optional): Config of localization loss. loss_bbox (dict, optional): Config of localization loss.
...@@ -47,18 +51,20 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -47,18 +51,20 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
""" # noqa: E501 """ # noqa: E501
def __init__(self, def __init__(self,
num_classes, num_classes: int,
in_channels, in_channels: int,
dim_channel, dim_channel: List[int],
ori_channel, ori_channel: List[int],
bbox_coder, bbox_coder: dict,
loss_cls=dict(type='GaussionFocalLoss', loss_weight=1.0), loss_cls: dict = dict(
loss_bbox=dict(type='L1Loss', loss_weight=0.1), type='GaussionFocalLoss', loss_weight=1.0),
loss_dir=None, loss_bbox: dict = dict(type='L1Loss', loss_weight=0.1),
loss_attr=None, loss_dir: Optional[dict] = None,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), loss_attr: Optional[dict] = None,
init_cfg=None, norm_cfg: dict = dict(
**kwargs): type='GN', num_groups=32, requires_grad=True),
init_cfg: Optional[Union[ConfigDict, dict]] = None,
**kwargs) -> None:
super().__init__( super().__init__(
num_classes, num_classes,
in_channels, in_channels,
...@@ -73,7 +79,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -73,7 +79,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
self.ori_channel = ori_channel self.ori_channel = ori_channel
self.bbox_coder = build_bbox_coder(bbox_coder) self.bbox_coder = build_bbox_coder(bbox_coder)
def forward(self, feats): def forward(self, feats: Tuple[Tensor]):
"""Forward features from the upstream network. """Forward features from the upstream network.
Args: Args:
...@@ -91,7 +97,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -91,7 +97,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
""" """
return multi_apply(self.forward_single, feats) return multi_apply(self.forward_single, feats)
def forward_single(self, x): def forward_single(self, x: Tensor) -> Union[Tensor, Tensor]:
"""Forward features of a single scale level. """Forward features of a single scale level.
Args: Args:
...@@ -112,13 +118,18 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -112,13 +118,18 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
bbox_pred[:, self.ori_channel, ...] = F.normalize(vector_ori) bbox_pred[:, self.ori_channel, ...] = F.normalize(vector_ori)
return cls_score, bbox_pred return cls_score, bbox_pred
def get_bboxes(self, cls_scores, bbox_preds, img_metas, rescale=None): @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
def get_results(self,
cls_scores,
bbox_preds,
batch_img_metas,
rescale=None):
"""Generate bboxes from bbox head predictions. """Generate bboxes from bbox head predictions.
Args: Args:
cls_scores (list[Tensor]): Box scores for each scale level. cls_scores (list[Tensor]): Box scores for each scale level.
bbox_preds (list[Tensor]): Box regression for each scale. bbox_preds (list[Tensor]): Box regression for each scale.
img_metas (list[dict]): Meta information of each image, e.g., batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc. image size, scaling factor, etc.
rescale (bool): If True, return boxes in original image space. rescale (bool): If True, return boxes in original image space.
...@@ -128,24 +139,24 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -128,24 +139,24 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
""" """
assert len(cls_scores) == len(bbox_preds) == 1 assert len(cls_scores) == len(bbox_preds) == 1
cam2imgs = torch.stack([ cam2imgs = torch.stack([
cls_scores[0].new_tensor(img_meta['cam2img']) cls_scores[0].new_tensor(img_metas['cam2img'])
for img_meta in img_metas for img_metas in batch_img_metas
]) ])
trans_mats = torch.stack([ trans_mats = torch.stack([
cls_scores[0].new_tensor(img_meta['trans_mat']) cls_scores[0].new_tensor(img_metas['trans_mat'])
for img_meta in img_metas for img_metas in batch_img_metas
]) ])
batch_bboxes, batch_scores, batch_topk_labels = self.decode_heatmap( batch_bboxes, batch_scores, batch_topk_labels = self.decode_heatmap(
cls_scores[0], cls_scores[0],
bbox_preds[0], bbox_preds[0],
img_metas, batch_img_metas,
cam2imgs=cam2imgs, cam2imgs=cam2imgs,
trans_mats=trans_mats, trans_mats=trans_mats,
topk=100, topk=100,
kernel=3) kernel=3)
result_list = [] result_list = []
for img_id in range(len(img_metas)): for img_id in range(len(batch_img_metas)):
bboxes = batch_bboxes[img_id] bboxes = batch_bboxes[img_id]
scores = batch_scores[img_id] scores = batch_scores[img_id]
...@@ -156,7 +167,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -156,7 +167,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
scores = scores[keep_idx] scores = scores[keep_idx]
labels = labels[keep_idx] labels = labels[keep_idx]
bboxes = img_metas[img_id]['box_type_3d']( bboxes = batch_img_metas[img_id]['box_type_3d'](
bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5)) bboxes, box_dim=self.bbox_code_size, origin=(0.5, 0.5, 0.5))
attrs = None attrs = None
result_list.append((bboxes, scores, labels, attrs)) result_list.append((bboxes, scores, labels, attrs))
...@@ -166,7 +177,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -166,7 +177,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
def decode_heatmap(self, def decode_heatmap(self,
cls_score, cls_score,
reg_pred, reg_pred,
img_metas, batch_img_metas,
cam2imgs, cam2imgs,
trans_mats, trans_mats,
topk=100, topk=100,
...@@ -178,7 +189,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -178,7 +189,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
shape (B, num_classes, H, W). shape (B, num_classes, H, W).
reg_pred (Tensor): Box regression map. reg_pred (Tensor): Box regression map.
shape (B, channel, H , W). shape (B, channel, H , W).
img_metas (List[dict]): Meta information of each image, e.g., batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc. image size, scaling factor, etc.
cam2imgs (Tensor): Camera intrinsic matrices. cam2imgs (Tensor): Camera intrinsic matrices.
shape (B, 4, 4) shape (B, 4, 4)
...@@ -199,7 +210,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -199,7 +210,7 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
- batch_topk_labels (Tensor): Categories of each 3D box. - batch_topk_labels (Tensor): Categories of each 3D box.
shape (B, k) shape (B, k)
""" """
img_h, img_w = img_metas[0]['pad_shape'][:2] img_h, img_w = batch_img_metas[0]['pad_shape'][:2]
bs, _, feat_h, feat_w = cls_score.shape bs, _, feat_h, feat_w = cls_score.shape
center_heatmap_pred = get_local_maximum(cls_score, kernel=kernel) center_heatmap_pred = get_local_maximum(cls_score, kernel=kernel)
...@@ -221,14 +232,15 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -221,14 +232,15 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
batch_bboxes = batch_bboxes.view(bs, -1, self.bbox_code_size) batch_bboxes = batch_bboxes.view(bs, -1, self.bbox_code_size)
return batch_bboxes, batch_scores, batch_topk_labels return batch_bboxes, batch_scores, batch_topk_labels
def get_predictions(self, labels3d, centers2d, gt_locations, gt_dimensions, def get_predictions(self, labels_3d, centers_2d, gt_locations,
gt_orientations, indices, img_metas, pred_reg): gt_dimensions, gt_orientations, indices,
batch_img_metas, pred_reg):
"""Prepare predictions for computing loss. """Prepare predictions for computing loss.
Args: Args:
labels3d (Tensor): Labels of each 3D box. labels_3d (Tensor): Labels of each 3D box.
shape (B, max_objs, ) shape (B, max_objs, )
centers2d (Tensor): Coords of each projected 3D box centers_2d (Tensor): Coords of each projected 3D box
center on image. shape (B * max_objs, 2) center on image. shape (B * max_objs, 2)
gt_locations (Tensor): Coords of each 3D box's location. gt_locations (Tensor): Coords of each 3D box's location.
shape (B * max_objs, 3) shape (B * max_objs, 3)
...@@ -238,8 +250,8 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -238,8 +250,8 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
shape (N, 1) shape (N, 1)
indices (Tensor): Indices of the existence of the 3D box. indices (Tensor): Indices of the existence of the 3D box.
shape (B * max_objs, ) shape (B * max_objs, )
img_metas (list[dict]): Meta information of each image, batch_img_metas (list[dict]): Meta information of each image, e.g.,
e.g., image size, scaling factor, etc. image size, scaling factor, etc.
pred_reg (Tensor): Box regression map. pred_reg (Tensor): Box regression map.
shape (B, channel, H , W). shape (B, channel, H , W).
...@@ -255,19 +267,19 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -255,19 +267,19 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
batch, channel = pred_reg.shape[0], pred_reg.shape[1] batch, channel = pred_reg.shape[0], pred_reg.shape[1]
w = pred_reg.shape[3] w = pred_reg.shape[3]
cam2imgs = torch.stack([ cam2imgs = torch.stack([
gt_locations.new_tensor(img_meta['cam2img']) gt_locations.new_tensor(img_metas['cam2img'])
for img_meta in img_metas for img_metas in batch_img_metas
]) ])
trans_mats = torch.stack([ trans_mats = torch.stack([
gt_locations.new_tensor(img_meta['trans_mat']) gt_locations.new_tensor(img_metas['trans_mat'])
for img_meta in img_metas for img_metas in batch_img_metas
]) ])
centers2d_inds = centers2d[:, 1] * w + centers2d[:, 0] centers_2d_inds = centers_2d[:, 1] * w + centers_2d[:, 0]
centers2d_inds = centers2d_inds.view(batch, -1) centers_2d_inds = centers_2d_inds.view(batch, -1)
pred_regression = transpose_and_gather_feat(pred_reg, centers2d_inds) pred_regression = transpose_and_gather_feat(pred_reg, centers_2d_inds)
pred_regression_pois = pred_regression.view(-1, channel) pred_regression_pois = pred_regression.view(-1, channel)
locations, dimensions, orientations = self.bbox_coder.decode( locations, dimensions, orientations = self.bbox_coder.decode(
pred_regression_pois, centers2d, labels3d, cam2imgs, trans_mats, pred_regression_pois, centers_2d, labels_3d, cam2imgs, trans_mats,
gt_locations) gt_locations)
locations, dimensions, orientations = locations[indices], dimensions[ locations, dimensions, orientations = locations[indices], dimensions[
...@@ -281,44 +293,35 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -281,44 +293,35 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
assert len(dimensions) == len(gt_dimensions) assert len(dimensions) == len(gt_dimensions)
assert len(orientations) == len(gt_orientations) assert len(orientations) == len(gt_orientations)
bbox3d_yaws = self.bbox_coder.encode(gt_locations, gt_dimensions, bbox3d_yaws = self.bbox_coder.encode(gt_locations, gt_dimensions,
orientations, img_metas) orientations, batch_img_metas)
bbox3d_dims = self.bbox_coder.encode(gt_locations, dimensions, bbox3d_dims = self.bbox_coder.encode(gt_locations, dimensions,
gt_orientations, img_metas) gt_orientations, batch_img_metas)
bbox3d_locs = self.bbox_coder.encode(locations, gt_dimensions, bbox3d_locs = self.bbox_coder.encode(locations, gt_dimensions,
gt_orientations, img_metas) gt_orientations, batch_img_metas)
pred_bboxes = dict(ori=bbox3d_yaws, dim=bbox3d_dims, loc=bbox3d_locs) pred_bboxes = dict(ori=bbox3d_yaws, dim=bbox3d_dims, loc=bbox3d_locs)
return pred_bboxes return pred_bboxes
def get_targets(self, gt_bboxes, gt_labels, gt_bboxes_3d, gt_labels_3d, def get_targets(self, batch_gt_instances_3d, feat_shape, batch_img_metas):
centers2d, feat_shape, img_shape, img_metas):
"""Get training targets for batch images. """Get training targets for batch images.
Args: Args:
gt_bboxes (list[Tensor]): Ground truth bboxes of each image, batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
shape (num_gt, 4). gt_instance_3d. It usually includes ``bboxes``, ``labels``,
gt_labels (list[Tensor]): Ground truth labels of each box, ``bboxes_3d``, ``labels_3d``, ``depths``, ``centers_2d`` and
shape (num_gt,). attributes.
gt_bboxes_3d (list[:obj:`CameraInstance3DBoxes`]): 3D Ground
truth bboxes of each image,
shape (num_gt, bbox_code_size).
gt_labels_3d (list[Tensor]): 3D Ground truth labels of each
box, shape (num_gt,).
centers2d (list[Tensor]): Projected 3D centers onto 2D image,
shape (num_gt, 2).
feat_shape (tuple[int]): Feature map shape with value, feat_shape (tuple[int]): Feature map shape with value,
shape (B, _, H, W). shape (B, _, H, W).
img_shape (tuple[int]): Image shape in [h, w] format. batch_img_metas (list[dict]): Meta information of each image, e.g.,
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc. image size, scaling factor, etc.
Returns: Returns:
tuple[Tensor, dict]: The Tensor value is the targets of tuple[Tensor, dict]: The Tensor value is the targets of
center heatmap, the dict has components below: center heatmap, the dict has components below:
- gt_centers2d (Tensor): Coords of each projected 3D box - gt_centers_2d (Tensor): Coords of each projected 3D box
center on image. shape (B * max_objs, 2) center on image. shape (B * max_objs, 2)
- gt_labels3d (Tensor): Labels of each 3D box. - gt_labels_3d (Tensor): Labels of each 3D box.
shape (B, max_objs, ) shape (B, max_objs, )
- indices (Tensor): Indices of the existence of the 3D box. - indices (Tensor): Indices of the existence of the 3D box.
shape (B * max_objs, ) shape (B * max_objs, )
...@@ -334,10 +337,30 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -334,10 +337,30 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
shape (N, 8, 3) shape (N, 8, 3)
""" """
gt_bboxes = [
gt_instances_3d.bboxes for gt_instances_3d in batch_gt_instances_3d
]
gt_labels = [
gt_instances_3d.labels for gt_instances_3d in batch_gt_instances_3d
]
gt_bboxes_3d = [
gt_instances_3d.bboxes_3d
for gt_instances_3d in batch_gt_instances_3d
]
gt_labels_3d = [
gt_instances_3d.labels_3d
for gt_instances_3d in batch_gt_instances_3d
]
centers_2d = [
gt_instances_3d.centers_2d
for gt_instances_3d in batch_gt_instances_3d
]
img_shape = batch_img_metas[0]['pad_shape']
reg_mask = torch.stack([ reg_mask = torch.stack([
gt_bboxes[0].new_tensor( gt_bboxes[0].new_tensor(
not img_meta['affine_aug'], dtype=torch.bool) not img_metas['affine_aug'], dtype=torch.bool)
for img_meta in img_metas for img_metas in batch_img_metas
]) ])
img_h, img_w = img_shape[:2] img_h, img_w = img_shape[:2]
...@@ -351,15 +374,15 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -351,15 +374,15 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
center_heatmap_target = gt_bboxes[-1].new_zeros( center_heatmap_target = gt_bboxes[-1].new_zeros(
[bs, self.num_classes, feat_h, feat_w]) [bs, self.num_classes, feat_h, feat_w])
gt_centers2d = centers2d.copy() gt_centers_2d = centers_2d.copy()
for batch_id in range(bs): for batch_id in range(bs):
gt_bbox = gt_bboxes[batch_id] gt_bbox = gt_bboxes[batch_id]
gt_label = gt_labels[batch_id] gt_label = gt_labels[batch_id]
# project centers2d from input image to feat map # project centers_2d from input image to feat map
gt_center2d = gt_centers2d[batch_id] * width_ratio gt_center_2d = gt_centers_2d[batch_id] * width_ratio
for j, center in enumerate(gt_center2d): for j, center in enumerate(gt_center_2d):
center_x_int, center_y_int = center.int() center_x_int, center_y_int = center.int()
scale_box_h = (gt_bbox[j][3] - gt_bbox[j][1]) * height_ratio scale_box_h = (gt_bbox[j][3] - gt_bbox[j][1]) * height_ratio
scale_box_w = (gt_bbox[j][2] - gt_bbox[j][0]) * width_ratio scale_box_w = (gt_bbox[j][2] - gt_bbox[j][0]) * width_ratio
...@@ -371,33 +394,33 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -371,33 +394,33 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
[center_x_int, center_y_int], radius) [center_x_int, center_y_int], radius)
avg_factor = max(1, center_heatmap_target.eq(1).sum()) avg_factor = max(1, center_heatmap_target.eq(1).sum())
num_ctrs = [center2d.shape[0] for center2d in centers2d] num_ctrs = [center_2d.shape[0] for center_2d in centers_2d]
max_objs = max(num_ctrs) max_objs = max(num_ctrs)
reg_inds = torch.cat( reg_inds = torch.cat(
[reg_mask[i].repeat(num_ctrs[i]) for i in range(bs)]) [reg_mask[i].repeat(num_ctrs[i]) for i in range(bs)])
inds = torch.zeros((bs, max_objs), inds = torch.zeros((bs, max_objs),
dtype=torch.bool).to(centers2d[0].device) dtype=torch.bool).to(centers_2d[0].device)
# put gt 3d bboxes to gpu # put gt 3d bboxes to gpu
gt_bboxes_3d = [ gt_bboxes_3d = [
gt_bbox_3d.to(centers2d[0].device) for gt_bbox_3d in gt_bboxes_3d gt_bbox_3d.to(centers_2d[0].device) for gt_bbox_3d in gt_bboxes_3d
] ]
batch_centers2d = centers2d[0].new_zeros((bs, max_objs, 2)) batch_centers_2d = centers_2d[0].new_zeros((bs, max_objs, 2))
batch_labels_3d = gt_labels_3d[0].new_zeros((bs, max_objs)) batch_labels_3d = gt_labels_3d[0].new_zeros((bs, max_objs))
batch_gt_locations = \ batch_gt_locations = \
gt_bboxes_3d[0].tensor.new_zeros((bs, max_objs, 3)) gt_bboxes_3d[0].tensor.new_zeros((bs, max_objs, 3))
for i in range(bs): for i in range(bs):
inds[i, :num_ctrs[i]] = 1 inds[i, :num_ctrs[i]] = 1
batch_centers2d[i, :num_ctrs[i]] = centers2d[i] batch_centers_2d[i, :num_ctrs[i]] = centers_2d[i]
batch_labels_3d[i, :num_ctrs[i]] = gt_labels_3d[i] batch_labels_3d[i, :num_ctrs[i]] = gt_labels_3d[i]
batch_gt_locations[i, :num_ctrs[i]] = \ batch_gt_locations[i, :num_ctrs[i]] = \
gt_bboxes_3d[i].tensor[:, :3] gt_bboxes_3d[i].tensor[:, :3]
inds = inds.flatten() inds = inds.flatten()
batch_centers2d = batch_centers2d.view(-1, 2) * width_ratio batch_centers_2d = batch_centers_2d.view(-1, 2) * width_ratio
batch_gt_locations = batch_gt_locations.view(-1, 3) batch_gt_locations = batch_gt_locations.view(-1, 3)
# filter the empty image, without gt_bboxes_3d # filter the empty image, without gt_bboxes_3d
...@@ -416,8 +439,8 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -416,8 +439,8 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
[gt_bbox_3d.corners for gt_bbox_3d in gt_bboxes_3d]) [gt_bbox_3d.corners for gt_bbox_3d in gt_bboxes_3d])
target_labels = dict( target_labels = dict(
gt_centers2d=batch_centers2d.long(), gt_centers_2d=batch_centers_2d.long(),
gt_labels3d=batch_labels_3d, gt_labels_3d=batch_labels_3d,
indices=inds, indices=inds,
reg_indices=reg_inds, reg_indices=reg_inds,
gt_locs=batch_gt_locations, gt_locs=batch_gt_locations,
...@@ -430,15 +453,9 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -430,15 +453,9 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
def loss(self, def loss(self,
cls_scores, cls_scores,
bbox_preds, bbox_preds,
gt_bboxes, batch_gt_instances_3d,
gt_labels, batch_img_metas,
gt_bboxes_3d, batch_gt_instances_ignore=None):
gt_labels_3d,
centers2d,
depths,
attr_labels,
img_metas,
gt_bboxes_ignore=None):
"""Compute loss of the head. """Compute loss of the head.
Args: Args:
...@@ -447,53 +464,42 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead): ...@@ -447,53 +464,42 @@ class SMOKEMono3DHead(AnchorFreeMono3DHead):
bbox_preds (list[Tensor]): Box dims is a 4D-tensor, the channel bbox_preds (list[Tensor]): Box dims is a 4D-tensor, the channel
number is bbox_code_size. number is bbox_code_size.
shape (B, 7, H, W). shape (B, 7, H, W).
gt_bboxes (list[Tensor]): Ground truth bboxes for each image. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. gt_instance_3d. It usually includes ``bboxes``, ``labels``,
gt_labels (list[Tensor]): Class indices corresponding to each box. ``bboxes_3d``, ``labels_3d``, ``depths``, ``centers_2d`` and
shape (num_gts, ). attributes.
gt_bboxes_3d (list[:obj:`CameraInstance3DBoxes`]): 3D boxes ground batch_img_metas (list[dict]): Meta information of each image, e.g.,
truth. it is the flipped gt_bboxes
gt_labels_3d (list[Tensor]): Same as gt_labels.
centers2d (list[Tensor]): 2D centers on the image.
shape (num_gts, 2).
depths (list[Tensor]): Depth ground truth.
shape (num_gts, ).
attr_labels (list[Tensor]): Attributes indices of each box.
In kitti it's None.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc. image size, scaling factor, etc.
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
boxes can be ignored when computing the loss. Batch of gt_instances_ignore. It includes ``bboxes`` attribute
Default: None. data that is ignored during training and testing.
Defaults to None.
Returns: Returns:
dict[str, Tensor]: A dictionary of loss components. dict[str, Tensor]: A dictionary of loss components.
""" """
assert len(cls_scores) == len(bbox_preds) == 1 assert len(cls_scores) == len(bbox_preds) == 1
assert attr_labels is None assert batch_gt_instances_ignore is None
assert gt_bboxes_ignore is None center_2d_heatmap = cls_scores[0]
center2d_heatmap = cls_scores[0]
pred_reg = bbox_preds[0] pred_reg = bbox_preds[0]
center2d_heatmap_target, avg_factor, target_labels = \ center_2d_heatmap_target, avg_factor, target_labels = \
self.get_targets(gt_bboxes, gt_labels, gt_bboxes_3d, self.get_targets(batch_gt_instances_3d,
gt_labels_3d, centers2d, center_2d_heatmap.shape,
center2d_heatmap.shape, batch_img_metas)
img_metas[0]['pad_shape'],
img_metas)
pred_bboxes = self.get_predictions( pred_bboxes = self.get_predictions(
labels3d=target_labels['gt_labels3d'], labels_3d=target_labels['gt_labels_3d'],
centers2d=target_labels['gt_centers2d'], centers_2d=target_labels['gt_centers_2d'],
gt_locations=target_labels['gt_locs'], gt_locations=target_labels['gt_locs'],
gt_dimensions=target_labels['gt_dims'], gt_dimensions=target_labels['gt_dims'],
gt_orientations=target_labels['gt_yaws'], gt_orientations=target_labels['gt_yaws'],
indices=target_labels['indices'], indices=target_labels['indices'],
img_metas=img_metas, batch_img_metas=batch_img_metas,
pred_reg=pred_reg) pred_reg=pred_reg)
loss_cls = self.loss_cls( loss_cls = self.loss_cls(
center2d_heatmap, center2d_heatmap_target, avg_factor=avg_factor) center_2d_heatmap, center_2d_heatmap_target, avg_factor=avg_factor)
reg_inds = target_labels['reg_indices'] reg_inds = target_labels['reg_indices']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment