Commit d2b71343 authored by 雍大凯's avatar 雍大凯
Browse files

add code

parent 69e57885
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from .bevdet import BEVDet
from mmdet3d.models import builder
@DETECTORS.register_module()
class BEVDepth(BEVDet):
def __init__(self, img_backbone, img_neck, img_view_transformer, img_bev_encoder_backbone, img_bev_encoder_neck,
pts_bbox_head=None, **kwargs):
super(BEVDepth, self).__init__(img_backbone=img_backbone,
img_neck=img_neck,
img_view_transformer=img_view_transformer,
img_bev_encoder_backbone=img_bev_encoder_backbone,
img_bev_encoder_neck=img_bev_encoder_neck,
pts_bbox_head=pts_bbox_head
)
def image_encoder(self, img, stereo=False):
"""
Args:
img: (B, N, 3, H, W)
stereo: bool
Returns:
x: (B, N, C, fH, fW)
stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo) / None
"""
imgs = img
B, N, C, imH, imW = imgs.shape
imgs = imgs.view(B * N, C, imH, imW)
x = self.img_backbone(imgs)
stereo_feat = None
if stereo:
stereo_feat = x[0]
x = x[1:]
if self.with_img_neck:
x = self.img_neck(x)
if type(x) in [list, tuple]:
x = x[0]
_, output_dim, ouput_H, output_W = x.shape
x = x.view(B, N, output_dim, ouput_H, output_W)
return x, stereo_feat
@force_fp32()
def bev_encoder(self, x):
"""
Args:
x: (B, C, Dy, Dx)
Returns:
x: (B, C', 2*Dy, 2*Dx)
"""
x = self.img_bev_encoder_backbone(x)
x = self.img_bev_encoder_neck(x)
if type(x) in [list, tuple]:
x = x[0]
return x
def prepare_inputs(self, inputs):
# split the inputs into each frame
assert len(inputs) == 7
B, N, C, H, W = inputs[0].shape
imgs, sensor2egos, ego2globals, intrins, post_rots, post_trans, bda = \
inputs
sensor2egos = sensor2egos.view(B, N, 4, 4)
ego2globals = ego2globals.view(B, N, 4, 4)
# calculate the transformation from adj sensor to key ego
keyego2global = ego2globals[:, 0, ...].unsqueeze(1) # (B, 1, 4, 4)
global2keyego = torch.inverse(keyego2global.double()) # (B, 1, 4, 4)
sensor2keyegos = \
global2keyego @ ego2globals.double() @ sensor2egos.double() # (B, N_views, 4, 4)
sensor2keyegos = sensor2keyegos.float()
return [imgs, sensor2keyegos, ego2globals, intrins,
post_rots, post_trans, bda]
def extract_img_feat(self, img_inputs, img_metas, **kwargs):
""" Extract features of images.
img_inputs:
imgs: (B, N_views, 3, H, W)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
Returns:
x: [(B, C', H', W'), ]
depth: (B*N, D, fH, fW)
"""
imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda = self.prepare_inputs(img_inputs)
x, _ = self.image_encoder(imgs) # x: (B, N, C, fH, fW)
mlp_input = self.img_view_transformer.get_mlp_input(
sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda) # (B, N_views, 27)
x, depth = self.img_view_transformer([x, sensor2keyegos, ego2globals, intrins, post_rots,
post_trans, bda, mlp_input])
# x: (B, C, Dy, Dx)
# depth: (B*N, D, fH, fW)
x = self.bev_encoder(x)
return [x], depth
def extract_feat(self, points, img_inputs, img_metas, **kwargs):
"""Extract features from images and points."""
"""
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_inputs:
imgs: (B, N_views, 3, H, W)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
"""
img_feats, depth = self.extract_img_feat(img_inputs, img_metas, **kwargs)
pts_feats = None
return img_feats, pts_feats, depth
def forward_train(self,
points=None,
img_inputs=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
img_metas=None,
gt_bboxes=None,
gt_labels=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_inputs:
imgs: (B, N_views, 3, H, W) # N_views = 6 * (N_history + 1)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses = dict(loss_depth=loss_depth)
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_pts)
return losses
def forward_test(self,
points=None,
img_inputs=None,
img_metas=None,
**kwargs):
"""
Args:
points (list[torch.Tensor]): the outer list indicates test-time
augmentations and inner torch.Tensor should have a shape NxC,
which contains all points in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch
img (list[torch.Tensor], optional): the outer
list indicates test-time augmentations and inner
torch.Tensor should have a shape NxCxHxW, which contains
all images in the batch. Defaults to None.
"""
for var, name in [(img_inputs, 'img_inputs'),
(img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(img_inputs)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(img_inputs), len(img_metas)))
if not isinstance(img_inputs[0][0], list):
img_inputs = [img_inputs] if img_inputs is None else img_inputs
points = [points] if points is None else points
return self.simple_test(points[0], img_metas[0], img_inputs[0],
**kwargs)
else:
return self.aug_test(None, img_metas[0], img_inputs[0], **kwargs)
def aug_test(self, points, img_metas, img=None, rescale=False):
"""Test function without augmentaiton."""
assert False
def simple_test(self,
points,
img_metas,
img_inputs=None,
rescale=False,
**kwargs):
"""Test function without augmentaiton.
Returns:
bbox_list: List[dict0, dict1, ...] len = bs
dict: {
'pts_bbox': dict: {
'boxes_3d': (N, 9)
'scores_3d': (N, )
'labels_3d': (N, )
}
}
"""
img_feats, _, _ = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
bbox_list = [dict() for _ in range(len(img_metas))]
bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
# bbox_pts: List[dict0, dict1, ...], len = batch_size
# dict: {
# 'boxes_3d': (N, 9)
# 'scores_3d': (N, )
# 'labels_3d': (N, )
# }
for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
result_dict['pts_bbox'] = pts_bbox
return bbox_list
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
img_feats, _, _ = self.extract_feat(
points, img=img_inputs, img_metas=img_metas, **kwargs)
assert self.with_pts_bbox
outs = self.pts_bbox_head(img_feats)
return outs
\ No newline at end of file
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from mmdet3d.models import builder
from .bevdet4d import BEVDet4D
@DETECTORS.register_module()
class BEVDepth4D(BEVDet4D):
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses = dict(loss_depth=loss_depth)
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_pts)
return losses
\ No newline at end of file
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from mmdet3d.models import CenterPoint
from mmdet3d.models import builder
@DETECTORS.register_module()
class BEVDet(CenterPoint):
def __init__(self, img_backbone, img_neck, img_view_transformer, img_bev_encoder_backbone, img_bev_encoder_neck,
pts_bbox_head=None, **kwargs):
super(BEVDet, self).__init__(img_backbone=img_backbone, img_neck=img_neck, pts_bbox_head=pts_bbox_head,
**kwargs)
self.img_view_transformer = builder.build_neck(img_view_transformer)
self.img_bev_encoder_backbone = builder.build_backbone(img_bev_encoder_backbone)
self.img_bev_encoder_neck = builder.build_neck(img_bev_encoder_neck)
@torch.compile
def image_encoder(self, img, stereo=False):
"""
Args:
img: (B, N, 3, H, W)
stereo: bool
Returns:
x: (B, N, C, fH, fW)
stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo) / None
"""
imgs = img
B, N, C, imH, imW = imgs.shape
imgs = imgs.view(B * N, C, imH, imW)
x = self.img_backbone(imgs)
stereo_feat = None
if stereo:
stereo_feat = x[0]
x = x[1:]
if self.with_img_neck:
x = self.img_neck(x)
if type(x) in [list, tuple]:
x = x[0]
_, output_dim, ouput_H, output_W = x.shape
x = x.view(B, N, output_dim, ouput_H, output_W)
return x, stereo_feat
@torch.compile
@force_fp32()
def bev_encoder(self, x):
"""
Args:
x: (B, C, Dy, Dx)
Returns:
x: (B, C', 2*Dy, 2*Dx)
"""
x = self.img_bev_encoder_backbone(x)
x = self.img_bev_encoder_neck(x)
if type(x) in [list, tuple]:
x = x[0]
return x
@torch.compile
def prepare_inputs(self, inputs):
# split the inputs into each frame
assert len(inputs) == 7
B, N, C, H, W = inputs[0].shape
imgs, sensor2egos, ego2globals, intrins, post_rots, post_trans, bda = \
inputs
sensor2egos = sensor2egos.view(B, N, 4, 4)
ego2globals = ego2globals.view(B, N, 4, 4)
# calculate the transformation from adj sensor to key ego
keyego2global = ego2globals[:, 0, ...].unsqueeze(1) # (B, 1, 4, 4)
global2keyego = torch.inverse(keyego2global.double()) # (B, 1, 4, 4)
sensor2keyegos = \
global2keyego @ ego2globals.double() @ sensor2egos.double() # (B, N_views, 4, 4)
sensor2keyegos = sensor2keyegos.float()
return [imgs, sensor2keyegos, ego2globals, intrins,
post_rots, post_trans, bda]
def extract_img_feat(self, img_inputs, img_metas, **kwargs):
""" Extract features of images.
img_inputs:
imgs: (B, N_views, 3, H, W)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
Returns:
x: [(B, C', H', W'), ]
depth: (B*N, D, fH, fW)
"""
img_inputs = self.prepare_inputs(img_inputs)
x, _ = self.image_encoder(img_inputs[0]) # x: (B, N, C, fH, fW)
x, depth = self.img_view_transformer([x] + img_inputs[1:7])
# x: (B, C, Dy, Dx)
# depth: (B*N, D, fH, fW)
x = self.bev_encoder(x)
return [x], depth
def extract_feat(self, points, img_inputs, img_metas, **kwargs):
"""Extract features from images and points."""
"""
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_inputs:
imgs: (B, N_views, 3, H, W)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
"""
img_feats, depth = self.extract_img_feat(img_inputs, img_metas, **kwargs)
pts_feats = None
return img_feats, pts_feats, depth
def forward_train(self,
points=None,
img_inputs=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
img_metas=None,
gt_bboxes=None,
gt_labels=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_inputs:
imgs: (B, N_views, 3, H, W) # N_views = 6 * (N_history + 1)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
img_feats, pts_feats, _ = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
losses = dict()
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_pts)
return losses
def forward_test(self,
points=None,
img_inputs=None,
img_metas=None,
**kwargs):
"""
Args:
points (list[torch.Tensor]): the outer list indicates test-time
augmentations and inner torch.Tensor should have a shape NxC,
which contains all points in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch
img (list[torch.Tensor], optional): the outer
list indicates test-time augmentations and inner
torch.Tensor should have a shape NxCxHxW, which contains
all images in the batch. Defaults to None.
"""
for var, name in [(img_inputs, 'img_inputs'),
(img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(img_inputs)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(img_inputs), len(img_metas)))
if not isinstance(img_inputs[0][0], list):
img_inputs = [img_inputs] if img_inputs is None else img_inputs
points = [points] if points is None else points
return self.simple_test(points[0], img_metas[0], img_inputs[0],
**kwargs)
else:
return self.aug_test(None, img_metas[0], img_inputs[0], **kwargs)
def aug_test(self, points, img_metas, img=None, rescale=False):
"""Test function without augmentaiton."""
assert False
def simple_test(self,
points,
img_metas,
img_inputs=None,
rescale=False,
**kwargs):
"""Test function without augmentaiton.
Returns:
bbox_list: List[dict0, dict1, ...] len = bs
dict: {
'pts_bbox': dict: {
'boxes_3d': (N, 9)
'scores_3d': (N, )
'labels_3d': (N, )
}
}
"""
img_feats, _, _ = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
bbox_list = [dict() for _ in range(len(img_metas))]
bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
# bbox_pts: List[dict0, dict1, ...], len = batch_size
# dict: {
# 'boxes_3d': (N, 9)
# 'scores_3d': (N, )
# 'labels_3d': (N, )
# }
for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
result_dict['pts_bbox'] = pts_bbox
return bbox_list
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
img_feats, _, _ = self.extract_feat(
points, img=img_inputs, img_metas=img_metas, **kwargs)
assert self.with_pts_bbox
outs = self.pts_bbox_head(img_feats)
return outs
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from mmdet3d.models import builder
from .bevdet import BEVDet
@DETECTORS.register_module()
class BEVDet4D(BEVDet):
r"""BEVDet4D paradigm for multi-camera 3D object detection.
Please refer to the `paper <https://arxiv.org/abs/2203.17054>`_
Args:
pre_process (dict | None): Configuration dict of BEV pre-process net.
align_after_view_transfromation (bool): Whether to align the BEV
Feature after view transformation. By default, the BEV feature of
the previous frame is aligned during the view transformation.
num_adj (int): Number of adjacent frames.
with_prev (bool): Whether to set the BEV feature of previous frame as
all zero. By default, False.
"""
def __init__(self,
pre_process=None,
align_after_view_transfromation=False,
num_adj=1,
with_prev=True,
**kwargs):
super(BEVDet4D, self).__init__(**kwargs)
self.pre_process = pre_process is not None
if self.pre_process:
self.pre_process_net = builder.build_backbone(pre_process)
self.align_after_view_transfromation = align_after_view_transfromation
self.num_frame = num_adj + 1
self.with_prev = with_prev
self.grid = None
def gen_grid(self, input, sensor2keyegos, bda, bda_adj=None):
"""
Args:
input: (B, C, Dy, Dx) bev_feat
sensor2keyegos: List[
curr_sensor-->key_ego: (B, N_views, 4, 4)
prev_sensor-->key_ego: (B, N_views, 4, 4)
]
bda: (B, 3, 3)
bda_adj: None
Returns:
grid: (B, Dy, Dx, 2)
"""
B, C, H, W = input.shape
v = sensor2keyegos[0].shape[0] # N_views
if self.grid is None:
# generate grid
xs = torch.linspace(
0, W - 1, W, dtype=input.dtype,
device=input.device).view(1, W).expand(H, W) # (Dy, Dx)
ys = torch.linspace(
0, H - 1, H, dtype=input.dtype,
device=input.device).view(H, 1).expand(H, W) # (Dy, Dx)
grid = torch.stack((xs, ys, torch.ones_like(xs)), -1) # (Dy, Dx, 3) 3: (x, y, 1)
self.grid = grid
else:
grid = self.grid
# (Dy, Dx, 3) --> (1, Dy, Dx, 3) --> (B, Dy, Dx, 3) --> (B, Dy, Dx, 3, 1)) 3: (grid_x, grid_y, 1)
grid = grid.view(1, H, W, 3).expand(B, H, W, 3).view(B, H, W, 3, 1)
curr_sensor2keyego = sensor2keyegos[0][:, 0:1, :, :] # (B, 1, 4, 4)
prev_sensor2keyego = sensor2keyegos[1][:, 0:1, :, :] # (B, 1, 4, 4)
# add bev data augmentation
bda_ = torch.zeros((B, 1, 4, 4), dtype=grid.dtype).to(grid) # (B, 1, 4, 4)
bda_[:, :, :3, :3] = bda.unsqueeze(1)
bda_[:, :, 3, 3] = 1
curr_sensor2keyego = bda_.matmul(curr_sensor2keyego) # (B, 1, 4, 4)
if bda_adj is not None:
bda_ = torch.zeros((B, 1, 4, 4), dtype=grid.dtype).to(grid)
bda_[:, :, :3, :3] = bda_adj.unsqueeze(1)
bda_[:, :, 3, 3] = 1
prev_sensor2keyego = bda_.matmul(prev_sensor2keyego) # (B, 1, 4, 4)
# transformation from current ego frame to adjacent ego frame
# key_ego --> prev_cam_front --> prev_ego
keyego2adjego = curr_sensor2keyego.matmul(torch.inverse(prev_sensor2keyego))
keyego2adjego = keyego2adjego.unsqueeze(dim=1) # (B, 1, 1, 4, 4)
# (B, 1, 1, 3, 3)
keyego2adjego = keyego2adjego[..., [True, True, False, True], :][..., [True, True, False, True]]
# x = grid_x * vx + x_min; y = grid_y * vy + y_min;
# feat2bev:
# [[vx, 0, x_min],
# [0, vy, y_min],
# [0, 0, 1 ]]
feat2bev = torch.zeros((3, 3), dtype=grid.dtype).to(grid)
feat2bev[0, 0] = self.img_view_transformer.grid_interval[0]
feat2bev[1, 1] = self.img_view_transformer.grid_interval[1]
feat2bev[0, 2] = self.img_view_transformer.grid_lower_bound[0]
feat2bev[1, 2] = self.img_view_transformer.grid_lower_bound[1]
feat2bev[2, 2] = 1
feat2bev = feat2bev.view(1, 3, 3) # (1, 3, 3)
# curr_feat_grid --> key ego --> prev_cam --> prev_ego --> prev_feat_grid
tf = torch.inverse(feat2bev).matmul(keyego2adjego).matmul(feat2bev) # (B, 1, 1, 3, 3)
grid = tf.matmul(grid) # (B, Dy, Dx, 3, 1) 3: (grid_x, grid_y, 1)
normalize_factor = torch.tensor([W - 1.0, H - 1.0],
dtype=input.dtype,
device=input.device) # (2, )
# (B, Dy, Dx, 2)
grid = grid[:, :, :, :2, 0] / normalize_factor.view(1, 1, 1, 2) * 2.0 - 1.0
return grid
@force_fp32()
def shift_feature(self, input, sensor2keyegos, bda, bda_adj=None):
"""
Args:
input: (B, C, Dy, Dx) bev_feat
sensor2keyegos: List[
curr_sensor-->key_ego: (B, N_views, 4, 4)
prev_sensor-->key_ego: (B, N_views, 4, 4)
]
bda: (B, 3, 3)
bda_adj: None
Returns:
output: aligned bev feat (B, C, Dy, Dx).
"""
grid = self.gen_grid(input, sensor2keyegos, bda, bda_adj=bda_adj) # grid: (B, Dy, Dx, 2), 介于(-1, 1)
output = F.grid_sample(input, grid.to(input.dtype), align_corners=True) # (B, C, Dy, Dx)
return output
def prepare_bev_feat(self, img, sensor2egos, ego2globals, intrin, post_rot, post_tran,
bda, mlp_input):
"""
Args:
imgs: (B, N_views, 3, H, W)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
mlp_input:
Returns:
bev_feat: (B, C, Dy, Dx)
depth: (B*N, D, fH, fW)
"""
x, _ = self.image_encoder(img) # x: (B, N, C, fH, fW)
# bev_feat: (B, C * Dz(=1), Dy, Dx)
# depth: (B * N, D, fH, fW)
bev_feat, depth = self.img_view_transformer(
[x, sensor2egos, ego2globals, intrin, post_rot, post_tran, bda, mlp_input])
if self.pre_process:
bev_feat = self.pre_process_net(bev_feat)[0] # (B, C, Dy, Dx)
return bev_feat, depth
def extract_img_feat_sequential(self, inputs, feat_prev):
"""
Args:
inputs:
curr_img: (1, N_views, 3, H, W)
sensor2keyegos_curr: (N_prev, N_views, 4, 4)
ego2globals_curr: (N_prev, N_views, 4, 4)
intrins: (1, N_views, 3, 3)
sensor2keyegos_prev: (N_prev, N_views, 4, 4)
ego2globals_prev: (N_prev, N_views, 4, 4)
post_rots: (1, N_views, 3, 3)
post_trans: (1, N_views, 3, )
bda_curr: (N_prev, 3, 3)
feat_prev: (N_prev, C, Dy, Dx)
Returns:
"""
imgs, sensor2keyegos_curr, ego2globals_curr, intrins = inputs[:4]
sensor2keyegos_prev, _, post_rots, post_trans, bda = inputs[4:]
bev_feat_list = []
mlp_input = self.img_view_transformer.get_mlp_input(
sensor2keyegos_curr[0:1, ...], ego2globals_curr[0:1, ...],
intrins, post_rots, post_trans, bda[0:1, ...])
inputs_curr = (imgs, sensor2keyegos_curr[0:1, ...],
ego2globals_curr[0:1, ...], intrins, post_rots,
post_trans, bda[0:1, ...], mlp_input)
# (1, C, Dx, Dy), (1*N, D, fH, fW)
bev_feat, depth = self.prepare_bev_feat(*inputs_curr)
bev_feat_list.append(bev_feat)
# align the feat_prev
_, C, H, W = feat_prev.shape
# feat_prev: (N_prev, C, Dy, Dx)
feat_prev = \
self.shift_feature(feat_prev, # (N_prev, C, Dy, Dx)
[sensor2keyegos_curr, # (N_prev, N_views, 4, 4)
sensor2keyegos_prev], # (N_prev, N_views, 4, 4)
bda # (N_prev, 3, 3)
)
bev_feat_list.append(feat_prev.view(1, (self.num_frame - 1) * C, H, W)) # (1, N_prev*C, Dy, Dx)
bev_feat = torch.cat(bev_feat_list, dim=1) # (1, N_frames*C, Dy, Dx)
x = self.bev_encoder(bev_feat)
return [x], depth
def prepare_inputs(self, img_inputs, stereo=False):
"""
Args:
img_inputs:
imgs: (B, N, 3, H, W) # N = 6 * (N_history + 1)
sensor2egos: (B, N, 4, 4)
ego2globals: (B, N, 4, 4)
intrins: (B, N, 3, 3)
post_rots: (B, N, 3, 3)
post_trans: (B, N, 3)
bda_rot: (B, 3, 3)
stereo: bool
Returns:
imgs: List[(B, N_views, C, H, W), (B, N_views, C, H, W), ...] len = N_frames
sensor2keyegos: List[(B, N_views, 4, 4), (B, N_views, 4, 4), ...]
ego2globals: List[(B, N_views, 4, 4), (B, N_views, 4, 4), ...]
intrins: List[(B, N_views, 3, 3), (B, N_views, 3, 3), ...]
post_rots: List[(B, N_views, 3, 3), (B, N_views, 3, 3), ...]
post_trans: List[(B, N_views, 3), (B, N_views, 3), ...]
bda: (B, 3, 3)
"""
B, N, C, H, W = img_inputs[0].shape
N = N // self.num_frame # N_views = 6
imgs = img_inputs[0].view(B, N, self.num_frame, C, H, W) # (B, N_views, N_frames, C, H, W)
imgs = torch.split(imgs, 1, 2)
imgs = [t.squeeze(2) for t in imgs] # List[(B, N_views, C, H, W), (B, N_views, C, H, W), ...]
sensor2egos, ego2globals, intrins, post_rots, post_trans, bda = \
img_inputs[1:7]
sensor2egos = sensor2egos.view(B, self.num_frame, N, 4, 4)
ego2globals = ego2globals.view(B, self.num_frame, N, 4, 4)
# calculate the transformation from sensor to key ego
# key_ego --> global (B, 1, 1, 4, 4)
keyego2global = ego2globals[:, 0, 0, ...].unsqueeze(1).unsqueeze(1)
# global --> key_ego (B, 1, 1, 4, 4)
global2keyego = torch.inverse(keyego2global.double())
# sensor --> ego --> global --> key_ego
sensor2keyegos = \
global2keyego @ ego2globals.double() @ sensor2egos.double() # (B, N_frames, N_views, 4, 4)
sensor2keyegos = sensor2keyegos.float()
# -------------------- for stereo --------------------------
curr2adjsensor = None
if stereo:
# (B, N_frames, N_views, 4, 4), (B, N_frames, N_views, 4, 4)
sensor2egos_cv, ego2globals_cv = sensor2egos, ego2globals
sensor2egos_curr = \
sensor2egos_cv[:, :self.temporal_frame, ...].double() # (B, N_temporal=2, N_views, 4, 4)
ego2globals_curr = \
ego2globals_cv[:, :self.temporal_frame, ...].double() # (B, N_temporal=2, N_views, 4, 4)
sensor2egos_adj = \
sensor2egos_cv[:, 1:self.temporal_frame + 1, ...].double() # (B, N_temporal=2, N_views, 4, 4)
ego2globals_adj = \
ego2globals_cv[:, 1:self.temporal_frame + 1, ...].double() # (B, N_temporal=2, N_views, 4, 4)
# curr_sensor --> curr_ego --> global --> prev_ego --> prev_sensor
curr2adjsensor = \
torch.inverse(ego2globals_adj @ sensor2egos_adj) \
@ ego2globals_curr @ sensor2egos_curr # (B, N_temporal=2, N_views, 4, 4)
curr2adjsensor = curr2adjsensor.float() # (B, N_temporal=2, N_views, 4, 4)
curr2adjsensor = torch.split(curr2adjsensor, 1, 1)
curr2adjsensor = [p.squeeze(1) for p in curr2adjsensor]
curr2adjsensor.extend([None for _ in range(self.extra_ref_frames)])
# curr2adjsensor: List[(B, N_views, 4, 4), (B, N_views, 4, 4), None]
assert len(curr2adjsensor) == self.num_frame
# -------------------- for stereo --------------------------
extra = [
sensor2keyegos, # (B, N_frames, N_views, 4, 4)
ego2globals, # (B, N_frames, N_views, 4, 4)
intrins.view(B, self.num_frame, N, 3, 3), # (B, N_frames, N_views, 3, 3)
post_rots.view(B, self.num_frame, N, 3, 3), # (B, N_frames, N_views, 3, 3)
post_trans.view(B, self.num_frame, N, 3) # (B, N_frames, N_views, 3)
]
extra = [torch.split(t, 1, 1) for t in extra]
extra = [[p.squeeze(1) for p in t] for t in extra]
sensor2keyegos, ego2globals, intrins, post_rots, post_trans = extra
return imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, \
bda, curr2adjsensor
def extract_img_feat(self,
img_inputs,
img_metas,
pred_prev=False,
sequential=False,
**kwargs):
"""
Args:
img_inputs:
imgs: (B, N, 3, H, W) # N = 6 * (N_history + 1)
sensor2egos: (B, N, 4, 4)
ego2globals: (B, N, 4, 4)
intrins: (B, N, 3, 3)
post_rots: (B, N, 3, 3)
post_trans: (B, N, 3)
bda_rot: (B, 3, 3)
img_metas:
**kwargs:
Returns:
x: [(B, C', H', W'), ]
depth: (B*N_views, D, fH, fW)
"""
if sequential:
return self.extract_img_feat_sequential(img_inputs, kwargs['feat_prev'])
imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, \
bda, _ = self.prepare_inputs(img_inputs)
"""Extract features of images."""
bev_feat_list = []
depth_list = []
key_frame = True # back propagation for key frame only
for img, sensor2keyego, ego2global, intrin, post_rot, post_tran in zip(
imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans):
if key_frame or self.with_prev:
if self.align_after_view_transfromation:
sensor2keyego, ego2global = sensor2keyegos[0], ego2globals[0]
mlp_input = self.img_view_transformer.get_mlp_input(
sensor2keyegos[0], ego2globals[0], intrin, post_rot, post_tran, bda) # (B, N_views, 27)
inputs_curr = (img, sensor2keyego, ego2global, intrin, post_rot,
post_tran, bda, mlp_input)
if key_frame:
# bev_feat: (B, C, Dy, Dx)
# depth: (B*N_views, D, fH, fW)
bev_feat, depth = self.prepare_bev_feat(*inputs_curr)
else:
with torch.no_grad():
bev_feat, depth = self.prepare_bev_feat(*inputs_curr)
else:
# https://github.com/HuangJunJie2017/BEVDet/issues/275
bev_feat = torch.zeros_like(bev_feat_list[0])
depth = None
bev_feat_list.append(bev_feat)
depth_list.append(depth)
key_frame = False
# bev_feat_list: List[(B, C, Dy, Dx), (B, C, Dy, Dx), ...]
# depth_list: List[(B*N_views, D, fH, fW), (B*N_views, D, fH, fW), ...]
if pred_prev:
assert self.align_after_view_transfromation
assert sensor2keyegos[0].shape[0] == 1 # batch_size = 1
feat_prev = torch.cat(bev_feat_list[1:], dim=0)
# (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
ego2globals_curr = \
ego2globals[0].repeat(self.num_frame - 1, 1, 1, 1)
# (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
sensor2keyegos_curr = \
sensor2keyegos[0].repeat(self.num_frame - 1, 1, 1, 1)
ego2globals_prev = torch.cat(ego2globals[1:], dim=0) # (N_prev, N_views, 4, 4)
sensor2keyegos_prev = torch.cat(sensor2keyegos[1:], dim=0) # (N_prev, N_views, 4, 4)
bda_curr = bda.repeat(self.num_frame - 1, 1, 1) # (N_prev, 3, 3)
return feat_prev, [imgs[0], # (1, N_views, 3, H, W)
sensor2keyegos_curr, # (N_prev, N_views, 4, 4)
ego2globals_curr, # (N_prev, N_views, 4, 4)
intrins[0], # (1, N_views, 3, 3)
sensor2keyegos_prev, # (N_prev, N_views, 4, 4)
ego2globals_prev, # (N_prev, N_views, 4, 4)
post_rots[0], # (1, N_views, 3, 3)
post_trans[0], # (1, N_views, 3, )
bda_curr] # (N_prev, 3, 3)
if self.align_after_view_transfromation:
for adj_id in range(1, self.num_frame):
bev_feat_list[adj_id] = self.shift_feature(
bev_feat_list[adj_id], # (B, C, Dy, Dx)
[sensor2keyegos[0], # (B, N_views, 4, 4)
sensor2keyegos[adj_id] # (B, N_views, 4, 4)
],
bda # (B, 3, 3)
) # (B, C, Dy, Dx)
bev_feat = torch.cat(bev_feat_list, dim=1) # (B, N_frames*C, Dy, Dx)
x = self.bev_encoder(bev_feat)
return [x], depth_list[0]
# Copyright (c) Phigent Robotics. All rights reserved.
from ...ops import TRTBEVPoolv2
from .bevdet import BEVDet
from .bevdepth import BEVDepth
from .bevdepth4d import BEVDepth4D
from .bevstereo4d import BEVStereo4D
from mmdet3d.models import DETECTORS
from mmdet3d.models.builder import build_head
import torch.nn.functional as F
from mmdet3d.core import bbox3d2result
import numpy as np
from multiprocessing.dummy import Pool as ThreadPool
from ...ops import nearest_assign
# pool = ThreadPool(processes=4) # 创建线程池
# for pano
grid_config_occ = {
'x': [-40, 40, 0.4],
'y': [-40, 40, 0.4],
'z': [-1, 5.4, 6.4],
'depth': [1.0, 45.0, 1.0],
}
# det
det_class_name = ['car', 'truck', 'trailer', 'bus', 'construction_vehicle',
'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone',
'barrier']
# occ
occ_class_names = [
'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
'driveable_surface', 'other_flat', 'sidewalk',
'terrain', 'manmade', 'vegetation', 'free'
]
det_ind = [2, 3, 4, 5, 6, 7, 9, 10]
occ_ind = [5, 3, 0, 4, 6, 7, 2, 1]
detind2occind = {
0:4,
1:10,
2:9,
3:3,
4:5,
5:2,
6:6,
7:7,
8:8,
9:1,
}
occind2detind = {
4:0,
10:1,
9:2,
3:3,
5:4,
2:5,
6:6,
7:7,
8:8,
1:9,
}
occind2detind_cuda = [-1, -1, 5, 3, 0, 4, 6, 7, -1, 2, 1]
inst_occ = np.ones([200, 200, 16])*0
import torch
X1, Y1, Z1 = 200, 200, 16
coords_x = torch.arange(X1).float()
coords_y = torch.arange(Y1).float()
coords_z = torch.arange(Z1).float()
coords = torch.stack(torch.meshgrid([coords_x, coords_y, coords_z])).permute(1, 2, 3, 0) # W, H, D, 3
# coords = coords.cpu().numpy()
st = [grid_config_occ['x'][0], grid_config_occ['y'][0], grid_config_occ['z'][0]]
sx = [grid_config_occ['x'][2], grid_config_occ['y'][2], 0.4]
@DETECTORS.register_module()
class BEVDetOCC(BEVDet):
def __init__(self,
occ_head=None,
upsample=False,
**kwargs):
super(BEVDetOCC, self).__init__(**kwargs)
self.occ_head = build_head(occ_head)
self.pts_bbox_head = None
self.upsample = upsample
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
losses = dict()
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
loss_occ = self.forward_occ_train(occ_bev_feature, voxel_semantics, mask_camera)
losses.update(loss_occ)
return losses
@torch.compile(mode="reduce-overhead")
def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
outs = self.occ_head(img_feats)
# assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
loss_occ = self.occ_head.loss(
outs, # (B, Dx, Dy, Dz, n_cls)
voxel_semantics, # (B, Dx, Dy, Dz)
mask_camera, # (B, Dx, Dy, Dz)
)
return loss_occ
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
occ_list = self.simple_test_occ(occ_bev_feature, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_list
def simple_test_occ(self, img_feats, img_metas=None):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
img_metas:
Returns:
occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
outs = self.occ_head(img_feats)
if not hasattr(self.occ_head, "get_occ_gpu"):
occ_preds = self.occ_head.get_occ(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
else:
occ_preds = self.occ_head.get_occ_gpu(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_preds
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs = self.occ_head(occ_bev_feature)
return outs
@DETECTORS.register_module()
class BEVDepthOCC(BEVDepth):
def __init__(self,
occ_head=None,
upsample=False,
**kwargs):
super(BEVDepthOCC, self).__init__(**kwargs)
self.occ_head = build_head(occ_head)
self.pts_bbox_head = None
self.upsample = upsample
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
losses = dict()
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses['loss_depth'] = loss_depth
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
loss_occ = self.forward_occ_train(occ_bev_feature, voxel_semantics, mask_camera)
losses.update(loss_occ)
return losses
def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
outs = self.occ_head(img_feats)
# assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
loss_occ = self.occ_head.loss(
outs, # (B, Dx, Dy, Dz, n_cls)
voxel_semantics, # (B, Dx, Dy, Dz)
mask_camera, # (B, Dx, Dy, Dz)
)
return loss_occ
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
occ_list = self.simple_test_occ(occ_bev_feature, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_list
def simple_test_occ(self, img_feats, img_metas=None):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
img_metas:
Returns:
occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
outs = self.occ_head(img_feats)
# occ_preds = self.occ_head.get_occ(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
occ_preds = self.occ_head.get_occ_gpu(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_preds
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs = self.occ_head(occ_bev_feature)
return outs
@DETECTORS.register_module()
class BEVDepthPano(BEVDepthOCC):
def __init__(self,
aux_centerness_head=None,
**kwargs):
super(BEVDepthPano, self).__init__(**kwargs)
self.aux_centerness_head = None
if aux_centerness_head:
train_cfg = kwargs['train_cfg']
test_cfg = kwargs['test_cfg']
pts_train_cfg = train_cfg.pts if train_cfg else None
aux_centerness_head.update(train_cfg=pts_train_cfg)
pts_test_cfg = test_cfg.pts if test_cfg else None
aux_centerness_head.update(test_cfg=pts_test_cfg)
self.aux_centerness_head = build_head(aux_centerness_head)
if 'inst_class_ids' in kwargs:
self.inst_class_ids = kwargs['inst_class_ids']
else:
self.inst_class_ids = [2, 3, 4, 5, 6, 7, 9, 10]
X1, Y1, Z1 = 200, 200, 16
coords_x = torch.arange(X1).float()
coords_y = torch.arange(Y1).float()
coords_z = torch.arange(Z1).float()
self.coords = torch.stack(torch.meshgrid([coords_x, coords_y, coords_z])).permute(1, 2, 3, 0) # W, H, D, 3
self.st = torch.tensor([grid_config_occ['x'][0], grid_config_occ['y'][0], grid_config_occ['z'][0]])
self.sx = torch.tensor([grid_config_occ['x'][2], grid_config_occ['y'][2], 0.4])
self.is_to_d = False
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
losses = dict()
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses['loss_depth'] = loss_depth
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
loss_occ = self.forward_occ_train(occ_bev_feature, voxel_semantics, mask_camera)
losses.update(loss_occ)
losses_aux_centerness = self.forward_aux_centerness_train([occ_bev_feature], gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_aux_centerness)
return losses
def forward_aux_centerness_train(self,
pts_feats,
gt_bboxes_3d,
gt_labels_3d,
img_metas,
gt_bboxes_ignore=None):
outs = self.aux_centerness_head(pts_feats)
loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
losses = self.aux_centerness_head.loss(*loss_inputs)
return losses
def simple_test_aux_centerness(self, x, img_metas, rescale=False, **kwargs):
"""Test function of point cloud branch."""
# outs = self.aux_centerness_head(x)
tx = self.aux_centerness_head.shared_conv(x[0]) # (B, C'=share_conv_channel, H, W)
outs_inst_center_reg = self.aux_centerness_head.task_heads[0].reg(tx)
outs_inst_center_height = self.aux_centerness_head.task_heads[0].height(tx)
outs_inst_center_heatmap = self.aux_centerness_head.task_heads[0].heatmap(tx)
outs = ([{
"reg" : outs_inst_center_reg,
"height" : outs_inst_center_height,
"heatmap" : outs_inst_center_heatmap,
}],)
# # bbox_list = self.aux_centerness_head.get_bboxes(
# # outs, img_metas, rescale=rescale)
# # bbox_results = [
# # bbox3d2result(bboxes, scores, labels)
# # for bboxes, scores, labels in bbox_list
# # ]
ins_cen_list = self.aux_centerness_head.get_centers(
outs, img_metas, rescale=rescale)
# return bbox_results, ins_cen_list
return None, ins_cen_list
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
result_list = [dict() for _ in range(len(img_metas))]
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
w_pano = kwargs['w_pano'] if 'w_pano' in kwargs else True
if w_pano == True:
bbox_pts, ins_cen_list = self.simple_test_aux_centerness([occ_bev_feature], img_metas, rescale=rescale, **kwargs)
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
occ_list = self.simple_test_occ(occ_bev_feature, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
for result_dict, occ_pred in zip(result_list, occ_list):
result_dict['pred_occ'] = occ_pred
w_panoproc = kwargs['w_panoproc'] if 'w_panoproc' in kwargs else True # 37.53 ms
if w_panoproc == True:
# # for pano
inst_xyz = ins_cen_list[0][0]
if self.is_to_d == False:
self.st = self.st.to(inst_xyz)
self.sx = self.sx.to(inst_xyz)
self.coords = self.coords.to(inst_xyz)
self.is_to_d = True
inst_xyz = ((inst_xyz - self.st) / self.sx).int()
inst_cls = ins_cen_list[2][0].int()
inst_num = 18 # 37.62 ms
# inst_occ = torch.tensor(occ_pred).to(inst_cls)
# inst_occ = occ_pred.clone().detach()
inst_occ = occ_pred.clone().detach() # 37.61 ms
if len(inst_cls) > 0:
cls_sort, indices = inst_cls.sort()
l2s = {}
if len(inst_cls) == 1:
l2s[cls_sort[0].item()] = 0
l2s[cls_sort[0].item()] = 0
# # tind_list = cls_sort[1:] - cls_sort[:-1]!=0
# # for tind in range(len(tind_list)):
# # if tind_list[tind] == True:
# # l2s[cls_sort[1+tind].item()] = tind + 1
tind_list = (cls_sort[1:] - cls_sort[:-1])!=0
if tind_list.__len__() > 0:
for tind in torch.range(0,len(tind_list)-1)[tind_list]:
l2s[cls_sort[1+int(tind.item())].item()] = int(tind.item()) + 1
is_cuda = True
# is_cuda = False
if is_cuda == True:
inst_id_list = indices + inst_num
l2s_key = indices.new_tensor([detind2occind[k] for k in l2s.keys()]).to(torch.int)
inst_occ = nearest_assign(
occ_pred.to(torch.int),
l2s_key.to(torch.int),
indices.new_tensor(occind2detind_cuda).to(torch.int),
inst_cls.to(torch.int),
inst_xyz.to(torch.int),
inst_id_list.to(torch.int)
)
else:
for cls_label_num_in_occ in self.inst_class_ids:
mask = occ_pred == cls_label_num_in_occ
if mask.sum() == 0:
continue
else:
cls_label_num_in_inst = occind2detind[cls_label_num_in_occ]
select_mask = inst_cls==cls_label_num_in_inst
if sum(select_mask) > 0:
indices = self.coords[mask]
inst_index_same_cls = inst_xyz[select_mask]
select_ind = ((indices[:,None,:] - inst_index_same_cls[None,:,:])**2).sum(-1).argmin(axis=1).int()
inst_occ[mask] = select_ind + inst_num + l2s[cls_label_num_in_inst]
result_list[0]['pano_inst'] = inst_occ #.cpu().numpy()
return result_list
@DETECTORS.register_module()
class BEVDepth4DOCC(BEVDepth4D):
def __init__(self,
occ_head=None,
upsample=False,
**kwargs):
super(BEVDepth4DOCC, self).__init__(**kwargs)
self.occ_head = build_head(occ_head)
self.pts_bbox_head = None
self.upsample = upsample
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
losses = dict()
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses['loss_depth'] = loss_depth
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
loss_occ = self.forward_occ_train(img_feats[0], voxel_semantics, mask_camera)
losses.update(loss_occ)
return losses
def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
outs = self.occ_head(img_feats)
assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
loss_occ = self.occ_head.loss(
outs, # (B, Dx, Dy, Dz, n_cls)
voxel_semantics, # (B, Dx, Dy, Dz)
mask_camera, # (B, Dx, Dy, Dz)
)
return loss_occ
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_list = self.simple_test_occ(img_feats[0], img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_list
def simple_test_occ(self, img_feats, img_metas=None):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
img_metas:
Returns:
occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
outs = self.occ_head(img_feats)
# occ_preds = self.occ_head.get_occ(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
occ_preds = self.occ_head.get_occ_gpu(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_preds
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs = self.occ_head(occ_bev_feature)
return outs
@DETECTORS.register_module()
class BEVDepth4DPano(BEVDepth4DOCC):
def __init__(self,
aux_centerness_head=None,
**kwargs):
super(BEVDepth4DPano, self).__init__(**kwargs)
self.aux_centerness_head = None
if aux_centerness_head:
train_cfg = kwargs['train_cfg']
test_cfg = kwargs['test_cfg']
pts_train_cfg = train_cfg.pts if train_cfg else None
aux_centerness_head.update(train_cfg=pts_train_cfg)
pts_test_cfg = test_cfg.pts if test_cfg else None
aux_centerness_head.update(test_cfg=pts_test_cfg)
self.aux_centerness_head = build_head(aux_centerness_head)
if 'inst_class_ids' in kwargs:
self.inst_class_ids = kwargs['inst_class_ids']
else:
self.inst_class_ids = [2, 3, 4, 5, 6, 7, 9, 10]
X1, Y1, Z1 = 200, 200, 16
coords_x = torch.arange(X1).float()
coords_y = torch.arange(Y1).float()
coords_z = torch.arange(Z1).float()
self.coords = torch.stack(torch.meshgrid([coords_x, coords_y, coords_z])).permute(1, 2, 3, 0) # W, H, D, 3
self.st = torch.tensor([grid_config_occ['x'][0], grid_config_occ['y'][0], grid_config_occ['z'][0]])
self.sx = torch.tensor([grid_config_occ['x'][2], grid_config_occ['y'][2], 0.4])
self.is_to_d = False
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
losses = dict()
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses['loss_depth'] = loss_depth
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
loss_occ = self.forward_occ_train(img_feats[0], voxel_semantics, mask_camera)
losses.update(loss_occ)
losses_aux_centerness = self.forward_aux_centerness_train([img_feats[0]], gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_aux_centerness)
return losses
def forward_aux_centerness_train(self,
pts_feats,
gt_bboxes_3d,
gt_labels_3d,
img_metas,
gt_bboxes_ignore=None):
outs = self.aux_centerness_head(pts_feats)
loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
losses = self.aux_centerness_head.loss(*loss_inputs)
return losses
def simple_test_aux_centerness(self, x, img_metas, rescale=False, **kwargs):
"""Test function of point cloud branch."""
outs = self.aux_centerness_head(x)
bbox_list = self.aux_centerness_head.get_bboxes(
outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
ins_cen_list = self.aux_centerness_head.get_centers(
outs, img_metas, rescale=rescale)
return bbox_results, ins_cen_list
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
result_list = [dict() for _ in range(len(img_metas))]
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
w_pano = kwargs['w_pano'] if 'w_pano' in kwargs else True
if w_pano == True:
bbox_pts, ins_cen_list = self.simple_test_aux_centerness([occ_bev_feature], img_metas, rescale=rescale, **kwargs)
occ_list = self.simple_test_occ(occ_bev_feature, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
for result_dict, occ_pred in zip(result_list, occ_list):
result_dict['pred_occ'] = occ_pred
w_panoproc = kwargs['w_panoproc'] if 'w_panoproc' in kwargs else True
if w_panoproc == True:
# # for pano
inst_xyz = ins_cen_list[0][0]
if self.is_to_d == False:
self.st = self.st.to(inst_xyz)
self.sx = self.sx.to(inst_xyz)
self.coords = self.coords.to(inst_xyz)
self.is_to_d = True
inst_xyz = ((inst_xyz - self.st) / self.sx).int()
inst_cls = ins_cen_list[2][0].int()
inst_num = 18 # 37.62 ms
# inst_occ = torch.tensor(occ_pred).to(inst_cls)
# inst_occ = occ_pred.clone().detach()
inst_occ = occ_pred.clone().detach() # 37.61 ms
if len(inst_cls) > 0:
cls_sort, indices = inst_cls.sort()
l2s = {}
if len(inst_cls) == 1:
l2s[cls_sort[0].item()] = 0
l2s[cls_sort[0].item()] = 0
# # tind_list = cls_sort[1:] - cls_sort[:-1]!=0
# # for tind in range(len(tind_list)):
# # if tind_list[tind] == True:
# # l2s[cls_sort[1+tind].item()] = tind + 1
tind_list = (cls_sort[1:] - cls_sort[:-1])!=0
if tind_list.__len__() > 0:
for tind in torch.range(0,len(tind_list)-1)[tind_list]:
l2s[cls_sort[1+int(tind.item())].item()] = int(tind.item()) + 1
is_cuda = True
# is_cuda = False
if is_cuda == True:
inst_id_list = indices + inst_num
l2s_key = indices.new_tensor([detind2occind[k] for k in l2s.keys()]).to(torch.int)
inst_occ = nearest_assign(
occ_pred.to(torch.int),
l2s_key.to(torch.int),
indices.new_tensor(occind2detind_cuda).to(torch.int),
inst_cls.to(torch.int),
inst_xyz.to(torch.int),
inst_id_list.to(torch.int)
)
else:
for cls_label_num_in_occ in self.inst_class_ids:
mask = occ_pred == cls_label_num_in_occ
if mask.sum() == 0:
continue
else:
cls_label_num_in_inst = occind2detind[cls_label_num_in_occ]
select_mask = inst_cls==cls_label_num_in_inst
if sum(select_mask) > 0:
indices = self.coords[mask]
inst_index_same_cls = inst_xyz[select_mask]
select_ind = ((indices[:,None,:] - inst_index_same_cls[None,:,:])**2).sum(-1).argmin(axis=1).int()
inst_occ[mask] = select_ind + inst_num + l2s[cls_label_num_in_inst]
result_list[0]['pano_inst'] = inst_occ #.cpu().numpy()
return result_list
@DETECTORS.register_module()
class BEVStereo4DOCC(BEVStereo4D):
def __init__(self,
occ_head=None,
upsample=False,
**kwargs):
super(BEVStereo4DOCC, self).__init__(**kwargs)
self.occ_head = build_head(occ_head)
self.pts_bbox_head = None
self.upsample = upsample
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
losses = dict()
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses['loss_depth'] = loss_depth
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
loss_occ = self.forward_occ_train(img_feats[0], voxel_semantics, mask_camera)
losses.update(loss_occ)
return losses
def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
outs = self.occ_head(img_feats)
assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
loss_occ = self.occ_head.loss(
outs, # (B, Dx, Dy, Dz, n_cls)
voxel_semantics, # (B, Dx, Dy, Dz)
mask_camera, # (B, Dx, Dy, Dz)
)
return loss_occ
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_list = self.simple_test_occ(img_feats[0], img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_list
def simple_test_occ(self, img_feats, img_metas=None):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
img_metas:
Returns:
occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
outs = self.occ_head(img_feats)
# occ_preds = self.occ_head.get_occ(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
occ_preds = self.occ_head.get_occ_gpu(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_preds
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs = self.occ_head(occ_bev_feature)
return outs
@DETECTORS.register_module()
class BEVDetOCCTRT(BEVDetOCC):
def __init__(self,
wocc=True,
wdet3d=True,
uni_train=True,
**kwargs):
super(BEVDetOCCTRT, self).__init__(**kwargs)
self.wocc = wocc
self.wdet3d = wdet3d
self.uni_train = uni_train
def result_serialize(self, outs_det3d=None, outs_occ=None):
outs_ = []
if outs_det3d is not None:
for out in outs_det3d:
for key in ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']:
outs_.append(out[0][key])
if outs_occ is not None:
outs_.append(outs_occ)
return outs_
def result_deserialize(self, outs):
outs_ = []
keys = ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']
for head_id in range(len(outs) // 6):
outs_head = [dict()]
for kid, key in enumerate(keys):
outs_head[0][key] = outs[head_id * 6 + kid]
outs_.append(outs_head)
return outs_
def forward_part1(
self,
img,
):
x = self.img_backbone(img)
x = self.img_neck(x)
x = self.img_view_transformer.depth_net(x[0])
depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
tran_feat = x[:, self.img_view_transformer.D:(
self.img_view_transformer.D +
self.img_view_transformer.out_channels)]
tran_feat = tran_feat.permute(0, 2, 3, 1)
# depth = depth.reshape(-1)
# tran_feat = tran_feat.flatten(0,2)
return tran_feat.flatten(0,2), depth.reshape(-1)
def forward_part2(
self,
tran_feat,
depth,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
):
tran_feat = tran_feat.reshape(6, 16, 44, 64)
depth = depth.reshape(6, 16, 44, 44)
x = TRTBEVPoolv2.apply(depth.contiguous(), tran_feat.contiguous(),
ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths,
int(self.img_view_transformer.grid_size[0].item()),
int(self.img_view_transformer.grid_size[1].item()),
int(self.img_view_transformer.grid_size[2].item())
) # -> [1, 64, 200, 200]
return x.reshape(-1)
def forward_part3(
self,
x
):
x = x.reshape(1, 200, 200, 64)
x = x.permute(0, 3, 1, 2).contiguous()
# return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
bev_feature = self.img_bev_encoder_backbone(x)
occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
outs_occ = None
if self.wocc == True:
if self.uni_train == True:
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs_occ = self.occ_head(occ_bev_feature)
outs_det3d = None
if self.wdet3d == True:
outs_det3d = self.pts_bbox_head([occ_bev_feature])
outs = self.result_serialize(outs_det3d, outs_occ)
return outs
def forward_ori(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
):
x = self.img_backbone(img)
x = self.img_neck(x)
x = self.img_view_transformer.depth_net(x[0])
depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
tran_feat = x[:, self.img_view_transformer.D:(
self.img_view_transformer.D +
self.img_view_transformer.out_channels)]
tran_feat = tran_feat.permute(0, 2, 3, 1)
x = TRTBEVPoolv2.apply(depth.contiguous(), tran_feat.contiguous(),
ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths,
int(self.img_view_transformer.grid_size[0].item()),
int(self.img_view_transformer.grid_size[1].item()),
int(self.img_view_transformer.grid_size[2].item())
)
x = x.permute(0, 3, 1, 2).contiguous()
# return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
bev_feature = self.img_bev_encoder_backbone(x)
occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
outs_occ = None
if self.wocc == True:
if self.uni_train == True:
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs_occ = self.occ_head(occ_bev_feature)
outs_det3d = None
if self.wdet3d == True:
outs_det3d = self.pts_bbox_head([occ_bev_feature])
outs = self.result_serialize(outs_det3d, outs_occ)
return outs
def forward_with_argmax(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
):
outs = self.forward_ori(
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
)
pred_occ_label = outs[0].argmax(-1)
return pred_occ_label
def get_bev_pool_input(self, input):
input = self.prepare_inputs(input)
coor = self.img_view_transformer.get_lidar_coor(*input[1:7])
return self.img_view_transformer.voxel_pooling_prepare_v2(coor)
@DETECTORS.register_module()
class BEVDepthOCCTRT(BEVDetOCC):
def __init__(self,
wocc=True,
wdet3d=True,
uni_train=True,
**kwargs):
super(BEVDepthOCCTRT, self).__init__(**kwargs)
self.wocc = wocc
self.wdet3d = wdet3d
self.uni_train = uni_train
def result_serialize(self, outs_det3d=None, outs_occ=None):
outs_ = []
if outs_det3d is not None:
for out in outs_det3d:
for key in ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']:
outs_.append(out[0][key])
if outs_occ is not None:
outs_.append(outs_occ)
return outs_
def result_deserialize(self, outs):
outs_ = []
keys = ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']
for head_id in range(len(outs) // 6):
outs_head = [dict()]
for kid, key in enumerate(keys):
outs_head[0][key] = outs[head_id * 6 + kid]
outs_.append(outs_head)
return outs_
def forward_ori(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
):
x = self.img_backbone(img)
x = self.img_neck(x)
x = self.img_view_transformer.depth_net(x[0], mlp_input)
depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
tran_feat = x[:, self.img_view_transformer.D:(
self.img_view_transformer.D +
self.img_view_transformer.out_channels)]
tran_feat = tran_feat.permute(0, 2, 3, 1)
x = TRTBEVPoolv2.apply(depth.contiguous(), tran_feat.contiguous(),
ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths,
int(self.img_view_transformer.grid_size[0].item()),
int(self.img_view_transformer.grid_size[1].item()),
int(self.img_view_transformer.grid_size[2].item())
)
x = x.permute(0, 3, 1, 2).contiguous()
# return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
bev_feature = self.img_bev_encoder_backbone(x)
occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
outs_occ = None
if self.wocc == True:
if self.uni_train == True:
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs_occ = self.occ_head(occ_bev_feature)
outs_det3d = None
if self.wdet3d == True:
outs_det3d = self.pts_bbox_head([occ_bev_feature])
outs = self.result_serialize(outs_det3d, outs_occ)
return outs
def forward_with_argmax(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
):
outs = self.forward_ori(
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
)
pred_occ_label = outs[0].argmax(-1)
return pred_occ_label
def get_bev_pool_input(self, input):
input = self.prepare_inputs(input)
coor = self.img_view_transformer.get_lidar_coor(*input[1:7])
mlp_input = self.img_view_transformer.get_mlp_input(*input[1:7])
# sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda) # (B, N_views, 27)
return self.img_view_transformer.voxel_pooling_prepare_v2(coor), mlp_input
@DETECTORS.register_module()
class BEVDepthPanoTRT(BEVDepthPano):
def __init__(self,
wocc=True,
wdet3d=True,
uni_train=True,
**kwargs):
super(BEVDepthPanoTRT, self).__init__(**kwargs)
self.wocc = wocc
self.wdet3d = wdet3d
self.uni_train = uni_train
def result_serialize(self, outs_det3d=None, outs_occ=None):
outs_ = []
if outs_det3d is not None:
for out in outs_det3d:
for key in ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']:
outs_.append(out[0][key])
if outs_occ is not None:
outs_.append(outs_occ)
return outs_
def result_deserialize(self, outs):
outs_ = []
keys = ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']
for head_id in range(len(outs) // 6):
outs_head = [dict()]
for kid, key in enumerate(keys):
outs_head[0][key] = outs[head_id * 6 + kid]
outs_.append(outs_head)
return outs_
def forward_part1(
self,
img,
mlp_input,
):
x = self.img_backbone(img)
x = self.img_neck(x)
x = self.img_view_transformer.depth_net(x[0], mlp_input)
depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
tran_feat = x[:, self.img_view_transformer.D:(
self.img_view_transformer.D +
self.img_view_transformer.out_channels)]
tran_feat = tran_feat.permute(0, 2, 3, 1)
# depth = depth.reshape(-1)
# tran_feat = tran_feat.flatten(0,2)
return tran_feat.flatten(0,2), depth.reshape(-1)
def forward_part2(
self,
tran_feat,
depth,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
):
tran_feat = tran_feat.reshape(6, 16, 44, 64)
depth = depth.reshape(6, 16, 44, 44)
x = TRTBEVPoolv2.apply(depth.contiguous(), tran_feat.contiguous(),
ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths,
int(self.img_view_transformer.grid_size[0].item()),
int(self.img_view_transformer.grid_size[1].item()),
int(self.img_view_transformer.grid_size[2].item())
) # -> [1, 64, 200, 200]
return x.reshape(-1)
def forward_part3(
self,
x
):
x = x.reshape(1, 200, 200, 64)
x = x.permute(0, 3, 1, 2).contiguous()
# return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
bev_feature = self.img_bev_encoder_backbone(x)
occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
outs_occ = None
if self.wocc == True:
if self.uni_train == True:
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs_occ = self.occ_head(occ_bev_feature)
outs_det3d = None
if self.wdet3d == True:
outs_det3d = self.pts_bbox_head([occ_bev_feature])
outs = self.result_serialize(outs_det3d, outs_occ)
# outs_inst_center = self.aux_centerness_head([occ_bev_feature])
x = self.aux_centerness_head.shared_conv(occ_bev_feature) # (B, C'=share_conv_channel, H, W)
# 运行不同task_head,
outs_inst_center_reg = self.aux_centerness_head.task_heads[0].reg(x)
outs.append(outs_inst_center_reg)
outs_inst_center_height = self.aux_centerness_head.task_heads[0].height(x)
outs.append(outs_inst_center_height)
outs_inst_center_heatmap = self.aux_centerness_head.task_heads[0].heatmap(x)
outs.append(outs_inst_center_heatmap)
def forward_ori(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
):
x = self.img_backbone(img)
x = self.img_neck(x)
x = self.img_view_transformer.depth_net(x[0], mlp_input)
depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
tran_feat = x[:, self.img_view_transformer.D:(
self.img_view_transformer.D +
self.img_view_transformer.out_channels)]
tran_feat = tran_feat.permute(0, 2, 3, 1)
x = TRTBEVPoolv2.apply(depth.contiguous(), tran_feat.contiguous(),
ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths,
int(self.img_view_transformer.grid_size[0].item()),
int(self.img_view_transformer.grid_size[1].item()),
int(self.img_view_transformer.grid_size[2].item())
)
x = x.permute(0, 3, 1, 2).contiguous()
# return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
bev_feature = self.img_bev_encoder_backbone(x)
occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
outs_occ = None
if self.wocc == True:
if self.uni_train == True:
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs_occ = self.occ_head(occ_bev_feature)
outs_det3d = None
if self.wdet3d == True:
outs_det3d = self.pts_bbox_head([occ_bev_feature])
outs = self.result_serialize(outs_det3d, outs_occ)
# outs_inst_center = self.aux_centerness_head([occ_bev_feature])
x = self.aux_centerness_head.shared_conv(occ_bev_feature) # (B, C'=share_conv_channel, H, W)
# 运行不同task_head,
outs_inst_center_reg = self.aux_centerness_head.task_heads[0].reg(x)
outs.append(outs_inst_center_reg)
outs_inst_center_height = self.aux_centerness_head.task_heads[0].height(x)
outs.append(outs_inst_center_height)
outs_inst_center_heatmap = self.aux_centerness_head.task_heads[0].heatmap(x)
outs.append(outs_inst_center_heatmap)
return outs
def forward_with_argmax(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
):
outs = self.forward_ori(
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
)
pred_occ_label = outs[0].argmax(-1)
return pred_occ_label, *outs[1:]
def get_bev_pool_input(self, input):
input = self.prepare_inputs(input)
coor = self.img_view_transformer.get_lidar_coor(*input[1:7])
mlp_input = self.img_view_transformer.get_mlp_input(*input[1:7])
# sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda) # (B, N_views, 27)
return self.img_view_transformer.voxel_pooling_prepare_v2(coor), mlp_input
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from mmdet3d.models import builder
from .bevdepth4d import BEVDepth4D
from mmdet.models.backbones.resnet import ResNet
@DETECTORS.register_module()
class BEVStereo4D(BEVDepth4D):
def __init__(self, **kwargs):
super(BEVStereo4D, self).__init__(**kwargs)
self.extra_ref_frames = 1
self.temporal_frame = self.num_frame
self.num_frame += self.extra_ref_frames
def extract_stereo_ref_feat(self, x):
"""
Args:
x: (B, N_views, 3, H, W)
Returns:
x: (B*N_views, C_stereo, fH_stereo, fW_stereo)
"""
B, N, C, imH, imW = x.shape
x = x.view(B * N, C, imH, imW) # (B*N_views, 3, H, W)
if isinstance(self.img_backbone, ResNet):
if self.img_backbone.deep_stem:
x = self.img_backbone.stem(x)
else:
x = self.img_backbone.conv1(x)
x = self.img_backbone.norm1(x)
x = self.img_backbone.relu(x)
x = self.img_backbone.maxpool(x)
for i, layer_name in enumerate(self.img_backbone.res_layers):
res_layer = getattr(self.img_backbone, layer_name)
x = res_layer(x)
return x
else:
x = self.img_backbone.patch_embed(x)
hw_shape = (self.img_backbone.patch_embed.DH,
self.img_backbone.patch_embed.DW)
if self.img_backbone.use_abs_pos_embed:
x = x + self.img_backbone.absolute_pos_embed
x = self.img_backbone.drop_after_pos(x)
for i, stage in enumerate(self.img_backbone.stages):
x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
out = out.view(-1, *out_hw_shape,
self.img_backbone.num_features[i])
out = out.permute(0, 3, 1, 2).contiguous()
return out
def prepare_bev_feat(self, img, sensor2keyego, ego2global, intrin,
post_rot, post_tran, bda, mlp_input, feat_prev_iv,
k2s_sensor, extra_ref_frame):
"""
Args:
img: (B, N_views, 3, H, W)
sensor2keyego: (B, N_views, 4, 4)
ego2global: (B, N_views, 4, 4)
intrin: (B, N_views, 3, 3)
post_rot: (B, N_views, 3, 3)
post_tran: (B, N_views, 3)
bda: (B, 3, 3)
mlp_input: (B, N_views, 27)
feat_prev_iv: (B*N_views, C_stereo, fH_stereo, fW_stereo) or None
k2s_sensor: (B, N_views, 4, 4) or None
extra_ref_frame:
Returns:
bev_feat: (B, C, Dy, Dx)
depth: (B*N, D, fH, fW)
stereo_feat: (B*N_views, C_stereo, fH_stereo, fW_stereo)
"""
if extra_ref_frame:
stereo_feat = self.extract_stereo_ref_feat(img) # (B*N_views, C_stereo, fH_stereo, fW_stereo)
return None, None, stereo_feat
# x: (B, N_views, C, fH, fW)
# stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo)
x, stereo_feat = self.image_encoder(img, stereo=True)
# 建立cost volume 所需的信息.
metas = dict(k2s_sensor=k2s_sensor, # (B, N_views, 4, 4)
intrins=intrin, # (B, N_views, 3, 3)
post_rots=post_rot, # (B, N_views, 3, 3)
post_trans=post_tran, # (B, N_views, 3)
frustum=self.img_view_transformer.cv_frustum.to(x), # (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
cv_downsample=4,
downsample=self.img_view_transformer.downsample,
grid_config=self.img_view_transformer.grid_config,
cv_feat_list=[feat_prev_iv, stereo_feat]
)
# bev_feat: (B, C * Dz(=1), Dy, Dx)
# depth: (B * N, D, fH, fW)
bev_feat, depth = self.img_view_transformer(
[x, sensor2keyego, ego2global, intrin, post_rot, post_tran, bda,
mlp_input], metas)
if self.pre_process:
bev_feat = self.pre_process_net(bev_feat)[0] # (B, C, Dy, Dx)
return bev_feat, depth, stereo_feat
def extract_img_feat_sequential(self, inputs, feat_prev):
"""
Args:
inputs:
curr_img: (1, N_views, 3, H, W)
sensor2keyegos_curr: (N_prev, N_views, 4, 4)
ego2globals_curr: (N_prev, N_views, 4, 4)
intrins: (1, N_views, 3, 3)
sensor2keyegos_prev: (N_prev, N_views, 4, 4)
ego2globals_prev: (N_prev, N_views, 4, 4)
post_rots: (1, N_views, 3, 3)
post_trans: (1, N_views, 3, )
bda_curr: (N_prev, 3, 3)
feat_prev_iv:
curr2adjsensor: (1, N_views, 4, 4)
feat_prev: (N_prev, C, Dy, Dx)
Returns:
"""
imgs, sensor2keyegos_curr, ego2globals_curr, intrins = inputs[:4]
sensor2keyegos_prev, _, post_rots, post_trans, bda = inputs[4:9]
feat_prev_iv, curr2adjsensor = inputs[9:]
bev_feat_list = []
mlp_input = self.img_view_transformer.get_mlp_input(
sensor2keyegos_curr[0:1, ...], ego2globals_curr[0:1, ...],
intrins, post_rots, post_trans, bda[0:1, ...])
inputs_curr = (imgs, sensor2keyegos_curr[0:1, ...],
ego2globals_curr[0:1, ...], intrins, post_rots,
post_trans, bda[0:1, ...], mlp_input, feat_prev_iv,
curr2adjsensor, False)
# (1, C, Dx, Dy), (1*N, D, fH, fW)
bev_feat, depth, _ = self.prepare_bev_feat(*inputs_curr)
bev_feat_list.append(bev_feat)
# align the feat_prev
_, C, H, W = feat_prev.shape
# feat_prev: (N_prev, C, Dy, Dx)
feat_prev = \
self.shift_feature(feat_prev, # (N_prev, C, Dy, Dx)
[sensor2keyegos_curr, # (N_prev, N_views, 4, 4)
sensor2keyegos_prev], # (N_prev, N_views, 4, 4)
bda # (N_prev, 3, 3)
)
bev_feat_list.append(feat_prev.view(1, (self.num_frame - 2) * C, H, W)) # (1, N_prev*C, Dy, Dx)
bev_feat = torch.cat(bev_feat_list, dim=1) # (1, N_frames*C, Dy, Dx)
x = self.bev_encoder(bev_feat)
return [x], depth
def extract_img_feat(self,
img_inputs,
img_metas,
pred_prev=False,
sequential=False,
**kwargs):
"""
Args:
img_inputs:
imgs: (B, N, 3, H, W) # N = 6 * (N_history + 1)
sensor2egos: (B, N, 4, 4)
ego2globals: (B, N, 4, 4)
intrins: (B, N, 3, 3)
post_rots: (B, N, 3, 3)
post_trans: (B, N, 3)
bda_rot: (B, 3, 3)
img_metas:
**kwargs:
Returns:
x: [(B, C', H', W'), ]
depth: (B*N_views, D, fH, fW)
"""
if sequential:
return self.extract_img_feat_sequential(img_inputs, kwargs['feat_prev'])
imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, \
bda, curr2adjsensor = self.prepare_inputs(img_inputs, stereo=True)
"""Extract features of images."""
bev_feat_list = []
depth_key_frame = None
feat_prev_iv = None
for fid in range(self.num_frame-1, -1, -1):
img, sensor2keyego, ego2global, intrin, post_rot, post_tran = \
imgs[fid], sensor2keyegos[fid], ego2globals[fid], intrins[fid], \
post_rots[fid], post_trans[fid]
key_frame = fid == 0
extra_ref_frame = fid == self.num_frame-self.extra_ref_frames
if key_frame or self.with_prev:
if self.align_after_view_transfromation:
sensor2keyego, ego2global = sensor2keyegos[0], ego2globals[0]
mlp_input = self.img_view_transformer.get_mlp_input(
sensor2keyegos[0], ego2globals[0], intrin,
post_rot, post_tran, bda) # (B, N_views, 27)
inputs_curr = (img, sensor2keyego, ego2global, intrin,
post_rot, post_tran, bda, mlp_input,
feat_prev_iv, curr2adjsensor[fid],
extra_ref_frame)
if key_frame:
bev_feat, depth, feat_curr_iv = \
self.prepare_bev_feat(*inputs_curr)
depth_key_frame = depth
else:
with torch.no_grad():
bev_feat, depth, feat_curr_iv = \
self.prepare_bev_feat(*inputs_curr)
if not extra_ref_frame:
bev_feat_list.append(bev_feat)
if not key_frame:
feat_prev_iv = feat_curr_iv
if pred_prev:
assert self.align_after_view_transfromation
assert sensor2keyegos[0].shape[0] == 1 # batch_size = 1
feat_prev = torch.cat(bev_feat_list[1:], dim=0)
# (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
ego2globals_curr = \
ego2globals[0].repeat(self.num_frame - 2, 1, 1, 1)
# (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
sensor2keyegos_curr = \
sensor2keyegos[0].repeat(self.num_frame - 2, 1, 1, 1)
ego2globals_prev = torch.cat(ego2globals[1:-1], dim=0) # (N_prev, N_views, 4, 4)
sensor2keyegos_prev = torch.cat(sensor2keyegos[1:-1], dim=0) # (N_prev, N_views, 4, 4)
bda_curr = bda.repeat(self.num_frame - 2, 1, 1) # (N_prev, 3, 3)
return feat_prev, [imgs[0], # (1, N_views, 3, H, W)
sensor2keyegos_curr, # (N_prev, N_views, 4, 4)
ego2globals_curr, # (N_prev, N_views, 4, 4)
intrins[0], # (1, N_views, 3, 3)
sensor2keyegos_prev, # (N_prev, N_views, 4, 4)
ego2globals_prev, # (N_prev, N_views, 4, 4)
post_rots[0], # (1, N_views, 3, 3)
post_trans[0], # (1, N_views, 3, )
bda_curr, # (N_prev, 3, 3)
feat_prev_iv,
curr2adjsensor[0]]
if not self.with_prev:
bev_feat_key = bev_feat_list[0]
if len(bev_feat_key.shape) == 4:
b, c, h, w = bev_feat_key.shape
bev_feat_list = \
[torch.zeros([b,
c * (self.num_frame -
self.extra_ref_frames - 1),
h, w]).to(bev_feat_key), bev_feat_key]
else:
b, c, z, h, w = bev_feat_key.shape
bev_feat_list = \
[torch.zeros([b,
c * (self.num_frame -
self.extra_ref_frames - 1), z,
h, w]).to(bev_feat_key), bev_feat_key]
if self.align_after_view_transfromation:
for adj_id in range(self.num_frame-2):
bev_feat_list[adj_id] = self.shift_feature(
bev_feat_list[adj_id], # (B, C, Dy, Dx)
[sensor2keyegos[0], # (B, N_views, 4, 4)
sensor2keyegos[self.num_frame-2-adj_id]], # (B, N_views, 4, 4)
bda # (B, 3, 3)
) # (B, C, Dy, Dx)
bev_feat = torch.cat(bev_feat_list, dim=1)
x = self.bev_encoder(bev_feat)
return [x], depth_key_frame
from .cross_entropy_loss import CrossEntropyLoss
from .focal_loss import CustomFocalLoss
__all__ = ['CrossEntropyLoss', 'CustomFocalLoss']
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss
def cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
Returns:
torch.Tensor: The calculated loss
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index
# element-wise losses
loss = F.cross_entropy(
pred,
label,
weight=class_weight,
reduction='none',
ignore_index=ignore_index)
# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = label.numel() - (label == ignore_index).sum().item()
# apply weights and do the reduction
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(
valid_mask & (labels < label_channels), as_tuple=False)
if inds.numel() > 0:
bin_labels[inds, labels[inds]] = 1
valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
label_channels).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
bin_label_weights *= valid_mask
return bin_labels, bin_label_weights, valid_mask
def binary_cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the binary CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
When the shape of pred is (N, 1), label will be expanded to
one-hot format, and when the shape of pred is (N, ), label
will not be expanded to one-hot format.
label (torch.Tensor): The learning label of the prediction,
with shape (N, ).
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
Returns:
torch.Tensor: The calculated loss.
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index
if pred.dim() != label.dim():
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.size(-1), ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
# The inplace writing method will have a mismatched broadcast
# shape error if the weight and valid_mask dimensions
# are inconsistent such as (B,N,1) and (B,N,C).
weight = weight * valid_mask
else:
weight = valid_mask
# average loss over non-ignored elements
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = valid_mask.sum().item()
# weighted element-wise losses
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
loss = weight_reduce_loss(
loss, weight, reduction=reduction, avg_factor=avg_factor)
return loss
def mask_cross_entropy(pred,
target,
label,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None,
**kwargs):
"""Calculate the CrossEntropy loss for masks.
Args:
pred (torch.Tensor): The prediction with shape (N, C, *), C is the
number of classes. The trailing * indicates arbitrary shape.
target (torch.Tensor): The learning label of the prediction.
label (torch.Tensor): ``label`` indicates the class label of the mask
corresponding object. This will be used to select the mask in the
of the class which the object belongs to when the mask prediction
if not class-agnostic.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (None): Placeholder, to be consistent with other loss.
Default: None.
Returns:
torch.Tensor: The calculated loss
Example:
>>> N, C = 3, 11
>>> H, W = 2, 2
>>> pred = torch.randn(N, C, H, W) * 1000
>>> target = torch.rand(N, H, W)
>>> label = torch.randint(0, C, size=(N,))
>>> reduction = 'mean'
>>> avg_factor = None
>>> class_weights = None
>>> loss = mask_cross_entropy(pred, target, label, reduction,
>>> avg_factor, class_weights)
>>> assert loss.shape == (1,)
"""
assert ignore_index is None, 'BCE loss does not support ignore_index'
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None
num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, weight=class_weight, reduction='mean')[None]
@LOSSES.register_module(force=True)
class CrossEntropyLoss(nn.Module):
def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
class_weight=None,
ignore_index=None,
loss_weight=1.0,
avg_non_ignore=False):
"""CrossEntropyLoss.
Args:
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
of softmax. Defaults to False.
use_mask (bool, optional): Whether to use mask cross entropy loss.
Defaults to False.
reduction (str, optional): . Defaults to 'mean'.
Options are "none", "mean" and "sum".
class_weight (list[float], optional): Weight of each class.
Defaults to None.
ignore_index (int | None): The label index to be ignored.
Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
avg_non_ignore (bool): The flag decides to whether the loss is
only averaged over non-ignored targets. Default: False.
"""
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.ignore_index = ignore_index
self.avg_non_ignore = avg_non_ignore
if ((ignore_index is not None) and not self.avg_non_ignore
and self.reduction == 'mean'):
warnings.warn(
'Default ``avg_non_ignore`` is False, if you would like to '
'ignore the certain label and average loss over non-ignore '
'labels, which is the same with PyTorch official '
'cross_entropy, set ``avg_non_ignore=True``.')
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
def extra_repr(self):
"""Extra repr."""
s = f'avg_non_ignore={self.avg_non_ignore}'
return s
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=None,
**kwargs):
"""Forward function.
Args:
cls_score (torch.Tensor): The prediction.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The method used to reduce the
loss. Options are "none", "mean" and "sum".
ignore_index (int | None): The label index to be ignored.
If not None, it will override the default value. Default: None.
Returns:
torch.Tensor: The calculated loss.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if ignore_index is None:
ignore_index = self.ignore_index
if self.class_weight is not None:
class_weight = cls_score.new_tensor(
self.class_weight, device=cls_score.device)
else:
class_weight = None
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
ignore_index=ignore_index,
avg_non_ignore=self.avg_non_ignore,
**kwargs)
return loss_cls
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss
import numpy as np
# This method is only for debugging
def py_sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the
number of classes
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = loss * weight
loss = loss.sum(-1).mean()
# loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
def py_focal_loss_with_prob(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Different from `py_sigmoid_focal_loss`, this function accepts probability
as input.
Args:
pred (torch.Tensor): The prediction probability with shape (N, C),
C is the number of classes.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
target = target.type_as(pred)
pt = (1 - pred) * target + pred * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy(
pred, target, reduction='none') * focal_weight
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
def sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
r"""A wrapper of cuda version `Focal Loss
<https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
alpha, None, 'none')
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = loss * weight
loss = loss.sum(-1).mean()
# loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
@LOSSES.register_module()
class CustomFocalLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=100.0,
activated=False):
"""`Focal Loss <https://arxiv.org/abs/1708.02002>`_
Args:
use_sigmoid (bool, optional): Whether to the prediction is
used for sigmoid or softmax. Defaults to True.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and
"sum".
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
activated (bool, optional): Whether the input is activated.
If True, it means the input has been activated and can be
treated as probabilities. Else, it should be treated as logits.
Defaults to False.
"""
super(CustomFocalLoss, self).__init__()
assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
self.use_sigmoid = use_sigmoid
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight
self.activated = activated
H, W = 200, 200
xy, yx = torch.meshgrid([torch.arange(H) - H / 2, torch.arange(W) - W / 2])
c = torch.stack([xy, yx], 2)
c = torch.norm(c, 2, -1)
c_max = c.max()
self.c = (c / c_max + 1).cuda()
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
ignore_index=255,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
B, H, W, D = target.shape
c = self.c[None, :, :, None].repeat(B, 1, 1, D).reshape(-1)
visible_mask = (target != ignore_index).reshape(-1).nonzero().squeeze(-1)
weight_mask = weight[None, :] * c[visible_mask, None]
# visible_mask[:, None]
num_classes = pred.size(1)
pred = pred.permute(0, 2, 3, 4, 1).reshape(-1, num_classes)[visible_mask]
target = target.reshape(-1)[visible_mask]
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.use_sigmoid:
if self.activated:
calculate_loss_func = py_focal_loss_with_prob
else:
if torch.cuda.is_available() and pred.is_cuda:
calculate_loss_func = sigmoid_focal_loss
else:
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
calculate_loss_func = py_sigmoid_focal_loss
loss_cls = self.loss_weight * calculate_loss_func(
pred,
target.to(torch.long),
weight_mask,
gamma=self.gamma,
alpha=self.alpha,
reduction=reduction,
avg_factor=avg_factor)
else:
raise NotImplementedError
return loss_cls
# -*- coding:utf-8 -*-
# author: Xinge
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""
from __future__ import print_function, division
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
from itertools import ifilterfalse
except ImportError: # py3k
from itertools import filterfalse as ifilterfalse
from torch.cuda.amp import autocast
def lovasz_grad(gt_sorted):
"""
Computes gradient of the Lovasz extension w.r.t sorted errors
See Alg. 1 in paper
"""
p = len(gt_sorted)
gts = gt_sorted.sum()
intersection = gts - gt_sorted.float().cumsum(0)
union = gts + (1 - gt_sorted).float().cumsum(0)
jaccard = 1. - intersection / union
if p > 1: # cover 1-pixel case
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
"""
IoU for foreground class
binary: 1 foreground, 0 background
"""
if not per_image:
preds, labels = (preds,), (labels,)
ious = []
for pred, label in zip(preds, labels):
intersection = ((label == 1) & (pred == 1)).sum()
union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
if not union:
iou = EMPTY
else:
iou = float(intersection) / float(union)
ious.append(iou)
iou = mean(ious) # mean accross images if per_image
return 100 * iou
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
"""
Array of IoU for each (non ignored) class
"""
if not per_image:
preds, labels = (preds,), (labels,)
ious = []
for pred, label in zip(preds, labels):
iou = []
for i in range(C):
if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
intersection = ((label == i) & (pred == i)).sum()
union = ((label == i) | ((pred == i) & (label != ignore))).sum()
if not union:
iou.append(EMPTY)
else:
iou.append(float(intersection) / float(union))
ious.append(iou)
ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image
return 100 * np.array(ious)
# --------------------------- BINARY LOSSES ---------------------------
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
"""
Binary Lovasz hinge loss
logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
per_image: compute the loss per image instead of per batch
ignore: void class id
"""
if per_image:
loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
for log, lab in zip(logits, labels))
else:
loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
return loss
def lovasz_hinge_flat(logits, labels):
"""
Binary Lovasz hinge loss
logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
labels: [P] Tensor, binary ground truth labels (0 or 1)
ignore: label to ignore
"""
if len(labels) == 0:
# only void pixels, the gradients should be 0
return logits.sum() * 0.
signs = 2. * labels.float() - 1.
errors = (1. - logits * Variable(signs))
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
perm = perm.data
gt_sorted = labels[perm]
grad = lovasz_grad(gt_sorted)
loss = torch.dot(F.relu(errors_sorted), Variable(grad))
return loss
def flatten_binary_scores(scores, labels, ignore=None):
"""
Flattens predictions in the batch (binary case)
Remove labels equal to 'ignore'
"""
scores = scores.view(-1)
labels = labels.view(-1)
if ignore is None:
return scores, labels
valid = (labels != ignore)
vscores = scores[valid]
vlabels = labels[valid]
return vscores, vlabels
class StableBCELoss(torch.nn.modules.Module):
def __init__(self):
super(StableBCELoss, self).__init__()
def forward(self, input, target):
neg_abs = - input.abs()
loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
return loss.mean()
def binary_xloss(logits, labels, ignore=None):
"""
Binary Cross entropy loss
logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
ignore: void class id
"""
logits, labels = flatten_binary_scores(logits, labels, ignore)
loss = StableBCELoss()(logits, Variable(labels.float()))
return loss
# --------------------------- MULTICLASS LOSSES ---------------------------
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
"""
Multi-class Lovasz-Softmax loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
per_image: compute the loss per image instead of per batch
ignore: void class labels
"""
if per_image:
loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
for prob, lab in zip(probas, labels))
else:
with autocast(False):
loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
return loss
def lovasz_softmax_flat(probas, labels, classes='present'):
"""
Multi-class Lovasz-Softmax loss
probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
labels: [P] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
"""
if probas.numel() == 0:
# only void pixels, the gradients should be 0
return probas * 0.
C = probas.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
for c in class_to_sum:
fg = (labels == c).float() # foreground for class c
if (classes is 'present' and fg.sum() == 0):
continue
if C == 1:
if len(classes) > 1:
raise ValueError('Sigmoid output possible only with 1 class')
class_pred = probas[:, 0]
else:
class_pred = probas[:, c]
errors = (Variable(fg) - class_pred).abs()
errors_sorted, perm = torch.sort(errors, 0, descending=True)
perm = perm.data
fg_sorted = fg[perm]
losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
return mean(losses)
def flatten_probas(probas, labels, ignore=None):
"""
Flattens predictions in the batch
"""
if probas.dim() == 2:
if ignore is not None:
valid = (labels != ignore)
probas = probas[valid]
labels = labels[valid]
return probas, labels
elif probas.dim() == 3:
# assumes output of a sigmoid layer
B, H, W = probas.size()
probas = probas.view(B, 1, H, W)
elif probas.dim() == 5:
#3D segmentation
B, C, L, H, W = probas.size()
probas = probas.contiguous().view(B, C, L, H*W)
B, C, H, W = probas.size()
probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
labels = labels.view(-1)
if ignore is None:
return probas, labels
valid = (labels != ignore)
vprobas = probas[valid.nonzero().squeeze()]
vlabels = labels[valid]
return vprobas, vlabels
def xloss(logits, labels, ignore=None):
"""
Cross entropy loss
"""
return F.cross_entropy(logits, Variable(labels), ignore_index=255)
def jaccard_loss(probas, labels,ignore=None, smooth = 100, bk_class = None):
"""
Something wrong with this loss
Multi-class Lovasz-Softmax loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
per_image: compute the loss per image instead of per batch
ignore: void class labels
"""
vprobas, vlabels = flatten_probas(probas, labels, ignore)
true_1_hot = torch.eye(vprobas.shape[1])[vlabels]
if bk_class:
one_hot_assignment = torch.ones_like(vlabels)
one_hot_assignment[vlabels == bk_class] = 0
one_hot_assignment = one_hot_assignment.float().unsqueeze(1)
true_1_hot = true_1_hot*one_hot_assignment
true_1_hot = true_1_hot.to(vprobas.device)
intersection = torch.sum(vprobas * true_1_hot)
cardinality = torch.sum(vprobas + true_1_hot)
loss = (intersection + smooth / (cardinality - intersection + smooth)).mean()
return (1-loss)*smooth
def hinge_jaccard_loss(probas, labels,ignore=None, classes = 'present', hinge = 0.1, smooth =100):
"""
Multi-class Hinge Jaccard loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
ignore: void class labels
"""
vprobas, vlabels = flatten_probas(probas, labels, ignore)
C = vprobas.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
for c in class_to_sum:
if c in vlabels:
c_sample_ind = vlabels == c
cprobas = vprobas[c_sample_ind,:]
non_c_ind =np.array([a for a in class_to_sum if a != c])
class_pred = cprobas[:,c]
max_non_class_pred = torch.max(cprobas[:,non_c_ind],dim = 1)[0]
TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.) + smooth
FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min = -hinge)+hinge)
if (~c_sample_ind).sum() == 0:
FP = 0
else:
nonc_probas = vprobas[~c_sample_ind,:]
class_pred = nonc_probas[:,c]
max_non_class_pred = torch.max(nonc_probas[:,non_c_ind],dim = 1)[0]
FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.)
losses.append(1 - TP/(TP+FP+FN))
if len(losses) == 0: return 0
return mean(losses)
# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
return x != x
def mean(l, ignore_nan=False, empty=0):
"""
nanmean compatible with generators.
"""
l = iter(l)
if ignore_nan:
l = ifilterfalse(isnan, l)
try:
n = 1
acc = next(l)
except StopIteration:
if empty == 'raise':
raise ValueError('Empty mean')
return empty
for n, v in enumerate(l, 2):
acc += v
if n == 1:
return acc
return acc / n
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# from mmcv.runner import BaseModule, force_fp32
from torch.cuda.amp import autocast
semantic_kitti_class_frequencies = np.array(
[
5.41773033e09,
1.57835390e07,
1.25136000e05,
1.18809000e05,
6.46799000e05,
8.21951000e05,
2.62978000e05,
2.83696000e05,
2.04750000e05,
6.16887030e07,
4.50296100e06,
4.48836500e07,
2.26992300e06,
5.68402180e07,
1.57196520e07,
1.58442623e08,
2.06162300e06,
3.69705220e07,
1.15198800e06,
3.34146000e05,
]
)
kitti_class_names = [
"empty",
"car",
"bicycle",
"motorcycle",
"truck",
"other-vehicle",
"person",
"bicyclist",
"motorcyclist",
"road",
"parking",
"sidewalk",
"other-ground",
"building",
"fence",
"vegetation",
"trunk",
"terrain",
"pole",
"traffic-sign",
]
def inverse_sigmoid(x, sign='A'):
x = x.to(torch.float32)
while x >= 1-1e-5:
x = x - 1e-5
while x< 1e-5:
x = x + 1e-5
return -torch.log((1 / x) - 1)
def KL_sep(p, target):
"""
KL divergence on nonzeros classes
"""
nonzeros = target != 0
nonzero_p = p[nonzeros]
kl_term = F.kl_div(torch.log(nonzero_p), target[nonzeros], reduction="sum")
return kl_term
def geo_scal_loss(pred, ssc_target, ignore_index=255, non_empty_idx=0):
# Get softmax probabilities
pred = F.softmax(pred, dim=1)
# Compute empty and nonempty probabilities
empty_probs = pred[:, non_empty_idx]
nonempty_probs = 1 - empty_probs
# Remove unknown voxels
mask = ssc_target != ignore_index
nonempty_target = ssc_target != non_empty_idx
nonempty_target = nonempty_target[mask].float()
nonempty_probs = nonempty_probs[mask]
empty_probs = empty_probs[mask]
eps = 1e-5
intersection = (nonempty_target * nonempty_probs).sum()
precision = intersection / (nonempty_probs.sum()+eps)
recall = intersection / (nonempty_target.sum()+eps)
spec = ((1 - nonempty_target) * (empty_probs)).sum() / ((1 - nonempty_target).sum()+eps)
with autocast(False):
return (
F.binary_cross_entropy_with_logits(inverse_sigmoid(precision, 'A'), torch.ones_like(precision))
+ F.binary_cross_entropy_with_logits(inverse_sigmoid(recall, 'B'), torch.ones_like(recall))
+ F.binary_cross_entropy_with_logits(inverse_sigmoid(spec, 'C'), torch.ones_like(spec))
)
def sem_scal_loss(pred_, ssc_target, ignore_index=255):
# Get softmax probabilities
with autocast(False):
pred = F.softmax(pred_, dim=1) # (B, n_class, Dx, Dy, Dz)
loss = 0
count = 0
mask = ssc_target != ignore_index
n_classes = pred.shape[1]
begin = 0
for i in range(begin, n_classes-1):
# Get probability of class i
p = pred[:, i] # (B, Dx, Dy, Dz)
# Remove unknown voxels
target_ori = ssc_target # (B, Dx, Dy, Dz)
p = p[mask]
target = ssc_target[mask]
completion_target = torch.ones_like(target)
completion_target[target != i] = 0
completion_target_ori = torch.ones_like(target_ori).float()
completion_target_ori[target_ori != i] = 0
if torch.sum(completion_target) > 0:
count += 1.0
nominator = torch.sum(p * completion_target)
loss_class = 0
if torch.sum(p) > 0:
precision = nominator / (torch.sum(p)+ 1e-5)
loss_precision = F.binary_cross_entropy_with_logits(
inverse_sigmoid(precision, 'D'), torch.ones_like(precision)
)
loss_class += loss_precision
if torch.sum(completion_target) > 0:
recall = nominator / (torch.sum(completion_target) +1e-5)
# loss_recall = F.binary_cross_entropy(recall, torch.ones_like(recall))
loss_recall = F.binary_cross_entropy_with_logits(inverse_sigmoid(recall, 'E'), torch.ones_like(recall))
loss_class += loss_recall
if torch.sum(1 - completion_target) > 0:
specificity = torch.sum((1 - p) * (1 - completion_target)) / (
torch.sum(1 - completion_target) + 1e-5
)
loss_specificity = F.binary_cross_entropy_with_logits(
inverse_sigmoid(specificity, 'F'), torch.ones_like(specificity)
)
loss_class += loss_specificity
loss += loss_class
# print(i, loss_class, loss_recall, loss_specificity)
l = loss/count
if torch.isnan(l):
from IPython import embed
embed()
exit()
return l
def CE_ssc_loss(pred, target, class_weights=None, ignore_index=255):
"""
:param: prediction: the predicted tensor, must be [BS, C, ...]
"""
criterion = nn.CrossEntropyLoss(
weight=class_weights, ignore_index=ignore_index, reduction="mean"
)
# from IPython import embed
# embed()
# exit()
with autocast(False):
loss = criterion(pred, target.long())
return loss
def vel_loss(pred, gt):
with autocast(False):
return F.l1_loss(pred, gt)
from .depthnet import DepthNet
__all__ = ['DepthNet']
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.backbones.resnet import BasicBlock
from mmcv.cnn import build_conv_layer
from torch.cuda.amp.autocast_mode import autocast
from torch.utils.checkpoint import checkpoint
class _ASPPModule(nn.Module):
def __init__(self, inplanes, planes, kernel_size, padding, dilation,
BatchNorm):
super(_ASPPModule, self).__init__()
self.atrous_conv = nn.Conv2d(
inplanes,
planes,
kernel_size=kernel_size,
stride=1,
padding=padding,
dilation=dilation,
bias=False)
self.bn = BatchNorm(planes)
self.relu = nn.ReLU()
self._init_weight()
def forward(self, x):
x = self.atrous_conv(x)
x = self.bn(x)
return self.relu(x)
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class ASPP(nn.Module):
def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
super(ASPP, self).__init__()
dilations = [1, 6, 12, 18]
self.aspp1 = _ASPPModule(
inplanes,
mid_channels,
1,
padding=0,
dilation=dilations[0],
BatchNorm=BatchNorm)
self.aspp2 = _ASPPModule(
inplanes,
mid_channels,
3,
padding=dilations[1],
dilation=dilations[1],
BatchNorm=BatchNorm)
self.aspp3 = _ASPPModule(
inplanes,
mid_channels,
3,
padding=dilations[2],
dilation=dilations[2],
BatchNorm=BatchNorm)
self.aspp4 = _ASPPModule(
inplanes,
mid_channels,
3,
padding=dilations[3],
dilation=dilations[3],
BatchNorm=BatchNorm)
self.global_avg_pool = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False),
BatchNorm(mid_channels),
nn.ReLU(),
)
self.conv1 = nn.Conv2d(
int(mid_channels * 5), inplanes, 1, bias=False)
self.bn1 = BatchNorm(inplanes)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
self._init_weight()
def forward(self, x):
"""
Args:
x: (B*N, C, fH, fW)
Returns:
x: (B*N, C, fH, fW)
"""
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x5 = self.global_avg_pool(x)
x5 = F.interpolate(
x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x1, x2, x3, x4, x5), dim=1) # (B*N, 5*C', fH, fW)
x = self.conv1(x) # (B*N, C, fH, fW)
x = self.bn1(x)
x = self.relu(x)
return self.dropout(x)
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.ReLU,
drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop)
def forward(self, x):
"""
Args:
x: (B*N_views, 27)
Returns:
x: (B*N_views, C)
"""
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class SELayer(nn.Module):
def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
super().__init__()
self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True)
self.act1 = act_layer()
self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True)
self.gate = gate_layer()
def forward(self, x, x_se):
"""
Args:
x: (B*N_views, C_mid, fH, fW)
x_se: (B*N_views, C_mid, 1, 1)
Returns:
x: (B*N_views, C_mid, fH, fW)
"""
x_se = self.conv_reduce(x_se) # (B*N_views, C_mid, 1, 1)
x_se = self.act1(x_se) # (B*N_views, C_mid, 1, 1)
x_se = self.conv_expand(x_se) # (B*N_views, C_mid, 1, 1)
return x * self.gate(x_se) # (B*N_views, C_mid, fH, fW)
class DepthNet(nn.Module):
def __init__(self,
in_channels,
mid_channels,
context_channels,
depth_channels,
use_dcn=True,
use_aspp=True,
with_cp=False,
stereo=False,
bias=0.0,
aspp_mid_channels=-1):
super(DepthNet, self).__init__()
self.reduce_conv = nn.Sequential(
nn.Conv2d(
in_channels, mid_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
)
# 生成context feature
self.context_conv = nn.Conv2d(
mid_channels, context_channels, kernel_size=1, stride=1, padding=0)
self.bn = nn.BatchNorm1d(27)
self.depth_mlp = Mlp(in_features=27, hidden_features=mid_channels, out_features=mid_channels)
self.depth_se = SELayer(channels=mid_channels) # NOTE: add camera-aware
self.context_mlp = Mlp(in_features=27, hidden_features=mid_channels, out_features=mid_channels)
self.context_se = SELayer(channels=mid_channels) # NOTE: add camera-aware
depth_conv_input_channels = mid_channels
downsample = None
if stereo:
depth_conv_input_channels += depth_channels
downsample = nn.Conv2d(depth_conv_input_channels,
mid_channels, 1, 1, 0)
cost_volumn_net = []
for stage in range(int(2)):
cost_volumn_net.extend([
nn.Conv2d(depth_channels, depth_channels, kernel_size=3,
stride=2, padding=1),
nn.BatchNorm2d(depth_channels)])
self.cost_volumn_net = nn.Sequential(*cost_volumn_net)
self.bias = bias
# 3个残差blocks
depth_conv_list = [BasicBlock(depth_conv_input_channels, mid_channels,
downsample=downsample),
BasicBlock(mid_channels, mid_channels),
BasicBlock(mid_channels, mid_channels)]
if use_aspp:
if aspp_mid_channels < 0:
aspp_mid_channels = mid_channels
depth_conv_list.append(ASPP(mid_channels, aspp_mid_channels))
if use_dcn:
depth_conv_list.append(
build_conv_layer(
cfg=dict(
type='DCN',
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
padding=1,
groups=4,
im2col_step=128,
)))
depth_conv_list.append(
nn.Conv2d(
mid_channels,
depth_channels,
kernel_size=1,
stride=1,
padding=0))
self.depth_conv = nn.Sequential(*depth_conv_list)
self.with_cp = with_cp
self.depth_channels = depth_channels
# ----------------------------------------- 用于建立cost volume ----------------------------------
def gen_grid(self, metas, B, N, D, H, W, hi, wi):
"""
Args:
metas: dict{
k2s_sensor: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
frustum: (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
cv_downsample: 4,
downsample: self.img_view_transformer.downsample=16,
grid_config: self.img_view_transformer.grid_config,
cv_feat_list: [feat_prev_iv, stereo_feat]
}
B: batchsize
N: N_views
D: D
H: fH_stereo
W: fW_stereo
hi: H_img
wi: W_img
Returns:
grid: (B*N_views, D*fH_stereo, fW_stereo, 2)
"""
frustum = metas['frustum'] # (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
# 逆图像增广:
points = frustum - metas['post_trans'].view(B, N, 1, 1, 1, 3)
points = torch.inverse(metas['post_rots']).view(B, N, 1, 1, 1, 3, 3) \
.matmul(points.unsqueeze(-1)) # (B, N_views, D, fH_stereo, fW_stereo, 3, 1)
# (u, v, d) --> (du, dv, d)
# (B, N_views, D, fH_stereo, fW_stereo, 3, 1)
points = torch.cat(
(points[..., :2, :] * points[..., 2:3, :], points[..., 2:3, :]), 5)
# cur_pixel --> curr_camera --> prev_camera
rots = metas['k2s_sensor'][:, :, :3, :3].contiguous()
trans = metas['k2s_sensor'][:, :, :3, 3].contiguous()
combine = rots.matmul(torch.inverse(metas['intrins']))
points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points)
points += trans.view(B, N, 1, 1, 1, 3, 1) # (B, N_views, D, fH_stereo, fW_stereo, 3, 1)
neg_mask = points[..., 2, 0] < 1e-3
# prev_camera --> prev_pixel
points = metas['intrins'].view(B, N, 1, 1, 1, 3, 3).matmul(points)
# (du, dv, d) --> (u, v) (B, N_views, D, fH_stereo, fW_stereo, 2, 1)
points = points[..., :2, :] / points[..., 2:3, :]
# 图像增广
points = metas['post_rots'][..., :2, :2].view(B, N, 1, 1, 1, 2, 2).matmul(
points).squeeze(-1)
points += metas['post_trans'][..., :2].view(B, N, 1, 1, 1, 2) # (B, N_views, D, fH_stereo, fW_stereo, 2)
px = points[..., 0] / (wi - 1.0) * 2.0 - 1.0
py = points[..., 1] / (hi - 1.0) * 2.0 - 1.0
px[neg_mask] = -2
py[neg_mask] = -2
grid = torch.stack([px, py], dim=-1) # (B, N_views, D, fH_stereo, fW_stereo, 2)
grid = grid.view(B * N, D * H, W, 2) # (B*N_views, D*fH_stereo, fW_stereo, 2)
return grid
def calculate_cost_volumn(self, metas):
"""
Args:
metas: dict{
k2s_sensor: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
frustum: (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
cv_downsample: 4,
downsample: self.img_view_transformer.downsample=16,
grid_config: self.img_view_transformer.grid_config,
cv_feat_list: [feat_prev_iv, stereo_feat]
}
Returns:
cost_volumn: (B*N_views, D, fH_stereo, fW_stereo)
"""
prev, curr = metas['cv_feat_list'] # (B*N_views, C_stereo, fH_stereo, fW_stereo)
group_size = 4
_, c, hf, wf = curr.shape #
hi, wi = hf * 4, wf * 4 # H_img, W_img
B, N, _ = metas['post_trans'].shape
D, H, W, _ = metas['frustum'].shape
grid = self.gen_grid(metas, B, N, D, H, W, hi, wi).to(curr.dtype) # (B*N_views, D*fH_stereo, fW_stereo, 2)
prev = prev.view(B * N, -1, H, W) # (B*N_views, C_stereo, fH_stereo, fW_stereo)
curr = curr.view(B * N, -1, H, W) # (B*N_views, C_stereo, fH_stereo, fW_stereo)
cost_volumn = 0
# process in group wise to save memory
for fid in range(curr.shape[1] // group_size):
# (B*N_views, group_size, fH_stereo, fW_stereo)
prev_curr = prev[:, fid * group_size:(fid + 1) * group_size, ...]
wrap_prev = F.grid_sample(prev_curr, grid,
align_corners=True,
padding_mode='zeros') # (B*N_views, group_size, D*fH_stereo, fW_stereo)
# (B*N_views, group_size, fH_stereo, fW_stereo)
curr_tmp = curr[:, fid * group_size:(fid + 1) * group_size, ...]
# (B*N_views, group_size, 1, fH_stereo, fW_stereo) - (B*N_views, group_size, D, fH_stereo, fW_stereo)
# --> (B*N_views, group_size, D, fH_stereo, fW_stereo)
# https://github.com/HuangJunJie2017/BEVDet/issues/278
cost_volumn_tmp = curr_tmp.unsqueeze(2) - \
wrap_prev.view(B * N, -1, D, H, W)
cost_volumn_tmp = cost_volumn_tmp.abs().sum(dim=1) # (B*N_views, D, fH_stereo, fW_stereo)
cost_volumn += cost_volumn_tmp # (B*N_views, D, fH_stereo, fW_stereo)
if not self.bias == 0:
invalid = wrap_prev[:, 0, ...].view(B * N, D, H, W) == 0
cost_volumn[invalid] = cost_volumn[invalid] + self.bias
# matching cost --> prob
cost_volumn = - cost_volumn
cost_volumn = cost_volumn.softmax(dim=1)
return cost_volumn
# ----------------------------------------- 用于建立cost volume --------------------------------------
def forward(self, x, mlp_input, stereo_metas=None):
"""
Args:
x: (B*N_views, C, fH, fW)
mlp_input: (B, N_views, 27)
stereo_metas: None or dict{
k2s_sensor: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
frustum: (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
cv_downsample: 4,
downsample: self.img_view_transformer.downsample=16,
grid_config: self.img_view_transformer.grid_config,
cv_feat_list: [feat_prev_iv, stereo_feat]
}
Returns:
x: (B*N_views, D+C_context, fH, fW)
"""
mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1])) # (B*N_views, 27)
x = self.reduce_conv(x) # (B*N_views, C_mid, fH, fW)
# (B*N_views, 27) --> (B*N_views, C_mid) --> (B*N_views, C_mid, 1, 1)
context_se = self.context_mlp(mlp_input)[..., None, None]
context = self.context_se(x, context_se) # (B*N_views, C_mid, fH, fW)
context = self.context_conv(context) # (B*N_views, C_context, fH, fW)
# (B*N_views, 27) --> (B*N_views, C_mid) --> (B*N_views, C_mid, 1, 1)
depth_se = self.depth_mlp(mlp_input)[..., None, None]
depth = self.depth_se(x, depth_se) # (B*N_views, C_mid, fH, fW)
if not stereo_metas is None:
if stereo_metas['cv_feat_list'][0] is None:
BN, _, H, W = x.shape
scale_factor = float(stereo_metas['downsample'])/\
stereo_metas['cv_downsample']
cost_volumn = \
torch.zeros((BN, self.depth_channels,
int(H*scale_factor),
int(W*scale_factor))).to(x)
else:
with torch.no_grad():
# https://github.com/HuangJunJie2017/BEVDet/issues/278
cost_volumn = self.calculate_cost_volumn(stereo_metas) # (B*N_views, D, fH_stereo, fW_stereo)
cost_volumn = self.cost_volumn_net(cost_volumn) # (B*N_views, D, fH, fW)
depth = torch.cat([depth, cost_volumn], dim=1) # (B*N_views, C_mid+D, fH, fW)
if self.with_cp:
depth = checkpoint(self.depth_conv, depth)
else:
# 3*res blocks +ASPP/DCN + Conv(c_mid-->D)
depth = self.depth_conv(depth) # x: (B*N_views, C_mid, fH, fW) --> (B*N_views, D, fH, fW)
return torch.cat([depth, context], dim=1)
class DepthAggregation(nn.Module):
"""pixel cloud feature extraction."""
def __init__(self, in_channels, mid_channels, out_channels):
super(DepthAggregation, self).__init__()
self.reduce_conv = nn.Sequential(
nn.Conv2d(
in_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
)
self.conv = nn.Sequential(
nn.Conv2d(
mid_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(
mid_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
)
self.out_conv = nn.Sequential(
nn.Conv2d(
mid_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True),
# nn.BatchNorm3d(out_channels),
# nn.ReLU(inplace=True),
)
@autocast(False)
def forward(self, x):
x = checkpoint(self.reduce_conv, x)
short_cut = x
x = checkpoint(self.conv, x)
x = short_cut + x
x = self.out_conv(x)
return x
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment