# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32

from mmdet3d.models import DETECTORS
from mmdet3d.models import builder
from .bevdet4d import BEVDet4D


@DETECTORS.register_module()
class BEVDepth4D(BEVDet4D):
    """BEVDet4D detector whose training step adds a depth loss.

    Only ``forward_train`` is overridden here; everything else is inherited
    from ``BEVDet4D``.
    """

    def forward_train(self,
                      points=None,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      gt_labels=None,
                      gt_bboxes=None,
                      img_inputs=None,
                      proposals=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.

        Args:
            points (list[torch.Tensor], optional): Points of each sample.
                Defaults to None.
            img_metas (list[dict], optional): Meta information of each sample.
                Defaults to None.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
                Ground truth 3D boxes. Defaults to None.
            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
                of 3D boxes. Defaults to None.
            gt_labels (list[torch.Tensor], optional): Ground truth labels
                of 2D boxes in images. Not used by this method.
                Defaults to None.
            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
                images. Not used by this method. Defaults to None.
            img_inputs (optional): Image inputs forwarded unchanged to
                ``extract_feat``. Defaults to None.
                NOTE(review): the previous docstring called this ``img`` and
                described a tensor of shape (N, C, H, W); confirm the actual
                structure against ``extract_feat``.
            proposals (list[torch.Tensor], optional): Predicted proposals
                used for training Fast RCNN. Not used by this method.
                Defaults to None.
            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
                2D boxes in images to be ignored. Defaults to None.
            **kwargs: Extra keyword arguments, forwarded to ``extract_feat``.
                Must contain ``gt_depth``, the ground-truth depth used for
                the depth loss (presumably shaped
                (B, N_views, img_H, img_W) -- TODO confirm).

        Returns:
            dict: Losses of different branches: ``loss_depth`` plus every
            entry returned by ``forward_pts_train``.

        Raises:
            KeyError: If ``gt_depth`` is not supplied in ``kwargs``.
        """
        # Fail fast: check for the depth supervision before running the
        # feature extractor rather than after a full forward pass.
        try:
            gt_depth = kwargs['gt_depth']
        except KeyError:
            raise KeyError(
                "BEVDepth4D.forward_train requires 'gt_depth' in kwargs to "
                'compute the depth loss') from None

        # The second returned value (point features) is not needed here.
        img_feats, _, depth = self.extract_feat(
            points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)

        loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
        losses = dict(loss_depth=loss_depth)

        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                            gt_labels_3d, img_metas,
                                            gt_bboxes_ignore)
        losses.update(losses_pts)
        return losses