Commit 19472568 authored by 雍大凯's avatar 雍大凯
Browse files

将子模块转换为普通目录

parent 51e55208
from .maptr_head import MapTRHead
from .maptrv2_head import MapTRv2Head
\ No newline at end of file
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import HEADS, build_loss
from mmdet.models.dense_heads import DETRHead
from mmdet3d.core.bbox.coders import build_bbox_coder
from mmcv.runner import force_fp32, auto_fp16
from mmcv.cnn import Linear, bias_init_with_prob, xavier_init, constant_init
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.core.bbox.transforms import bbox_xyxy_to_cxcywh, bbox_cxcywh_to_xyxy
from mmdet.core import (multi_apply, multi_apply, reduce_mean)
from mmcv.utils import TORCH_VERSION, digit_version
def normalize_2d_bbox(bboxes, pc_range):
patch_h = pc_range[4]-pc_range[1]
patch_w = pc_range[3]-pc_range[0]
cxcywh_bboxes = bbox_xyxy_to_cxcywh(bboxes)
cxcywh_bboxes[...,0:1] = cxcywh_bboxes[..., 0:1] - pc_range[0]
cxcywh_bboxes[...,1:2] = cxcywh_bboxes[...,1:2] - pc_range[1]
factor = bboxes.new_tensor([patch_w, patch_h,patch_w,patch_h])
normalized_bboxes = cxcywh_bboxes / factor
return normalized_bboxes
def normalize_2d_pts(pts, pc_range):
patch_h = pc_range[4]-pc_range[1]
patch_w = pc_range[3]-pc_range[0]
new_pts = pts.clone()
new_pts[...,0:1] = pts[..., 0:1] - pc_range[0]
new_pts[...,1:2] = pts[...,1:2] - pc_range[1]
factor = pts.new_tensor([patch_w, patch_h])
normalized_pts = new_pts / factor
return normalized_pts
def denormalize_2d_bbox(bboxes, pc_range):
bboxes = bbox_cxcywh_to_xyxy(bboxes)
bboxes[..., 0::2] = (bboxes[..., 0::2]*(pc_range[3] -
pc_range[0]) + pc_range[0])
bboxes[..., 1::2] = (bboxes[..., 1::2]*(pc_range[4] -
pc_range[1]) + pc_range[1])
return bboxes
def denormalize_2d_pts(pts, pc_range):
new_pts = pts.clone()
new_pts[...,0:1] = (pts[..., 0:1]*(pc_range[3] -
pc_range[0]) + pc_range[0])
new_pts[...,1:2] = (pts[...,1:2]*(pc_range[4] -
pc_range[1]) + pc_range[1])
return new_pts
@HEADS.register_module()
class MapTRHead(DETRHead):
"""Head of Detr3D.
Args:
with_box_refine (bool): Whether to refine the reference points
in the decoder. Defaults to False.
as_two_stage (bool) : Whether to generate the proposal from
the outputs of encoder.
transformer (obj:`ConfigDict`): ConfigDict is used for building
the Encoder and Decoder.
bev_h, bev_w (int): spatial shape of BEV queries.
"""
def __init__(self,
*args,
with_box_refine=False,
as_two_stage=False,
transformer=None,
bbox_coder=None,
num_cls_fcs=2,
code_weights=None,
bev_h=30,
bev_w=30,
num_vec=20,
num_pts_per_vec=2,
num_pts_per_gt_vec=2,
query_embed_type='all_pts',
transform_method='minmax',
gt_shift_pts_pattern='v0',
dir_interval=1,
loss_pts=dict(type='ChamferDistance',
loss_src_weight=1.0,
loss_dst_weight=1.0),
loss_dir=dict(type='PtsDirCosLoss', loss_weight=2.0),
**kwargs):
self.bev_h = bev_h
self.bev_w = bev_w
self.fp16_enabled = False
self.with_box_refine = with_box_refine
self.as_two_stage = as_two_stage
self.bev_encoder_type = transformer.encoder.type
if self.as_two_stage:
transformer['as_two_stage'] = self.as_two_stage
if 'code_size' in kwargs:
self.code_size = kwargs['code_size']
else:
self.code_size = 10
if code_weights is not None:
self.code_weights = code_weights
else:
self.code_weights = [1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
self.bbox_coder = build_bbox_coder(bbox_coder)
self.pc_range = self.bbox_coder.pc_range
self.real_w = self.pc_range[3] - self.pc_range[0]
self.real_h = self.pc_range[4] - self.pc_range[1]
self.num_cls_fcs = num_cls_fcs - 1
self.query_embed_type = query_embed_type
self.transform_method = transform_method
self.gt_shift_pts_pattern = gt_shift_pts_pattern
num_query = num_vec * num_pts_per_vec
self.num_query = num_query
self.num_vec = num_vec
self.num_pts_per_vec = num_pts_per_vec
self.num_pts_per_gt_vec = num_pts_per_gt_vec
self.dir_interval = dir_interval
super(MapTRHead, self).__init__(
*args, transformer=transformer, **kwargs)
self.code_weights = nn.Parameter(torch.tensor(
self.code_weights, requires_grad=False), requires_grad=False)
self.loss_pts = build_loss(loss_pts)
self.loss_dir = build_loss(loss_dir)
num_query = num_vec * num_pts_per_vec
self.num_query = num_query
self.num_vec = num_vec
self.num_pts_per_vec = num_pts_per_vec
self.num_pts_per_gt_vec = num_pts_per_gt_vec
self._init_layers()
def _init_layers(self):
"""Initialize classification branch and regression branch of head."""
cls_branch = []
# cls_branch.append(Linear(self.embed_dims * 2, self.embed_dims))
# cls_branch.append(nn.LayerNorm(self.embed_dims))
# cls_branch.append(nn.ReLU(inplace=True))
for _ in range(self.num_reg_fcs):
cls_branch.append(Linear(self.embed_dims, self.embed_dims))
cls_branch.append(nn.LayerNorm(self.embed_dims))
cls_branch.append(nn.ReLU(inplace=True))
cls_branch.append(Linear(self.embed_dims, self.cls_out_channels))
fc_cls = nn.Sequential(*cls_branch)
reg_branch = []
for _ in range(self.num_reg_fcs):
reg_branch.append(Linear(self.embed_dims, self.embed_dims))
reg_branch.append(nn.ReLU())
reg_branch.append(Linear(self.embed_dims, self.code_size))
reg_branch = nn.Sequential(*reg_branch)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
# last reg_branch is used to generate proposal from
# encode feature map when as_two_stage is True.
num_pred = (self.transformer.decoder.num_layers + 1) if \
self.as_two_stage else self.transformer.decoder.num_layers
if self.with_box_refine:
self.cls_branches = _get_clones(fc_cls, num_pred)
self.reg_branches = _get_clones(reg_branch, num_pred)
else:
self.cls_branches = nn.ModuleList(
[fc_cls for _ in range(num_pred)])
self.reg_branches = nn.ModuleList(
[reg_branch for _ in range(num_pred)])
if not self.as_two_stage:
if self.bev_encoder_type == 'BEVFormerEncoder':
self.bev_embedding = nn.Embedding(
self.bev_h * self.bev_w, self.embed_dims)
else:
self.bev_embedding = None
if self.query_embed_type == 'all_pts':
self.query_embedding = nn.Embedding(self.num_query,
self.embed_dims * 2)
elif self.query_embed_type == 'instance_pts':
self.query_embedding = None
self.instance_embedding = nn.Embedding(self.num_vec, self.embed_dims * 2)
self.pts_embedding = nn.Embedding(self.num_pts_per_vec, self.embed_dims * 2)
def init_weights(self):
"""Initialize weights of the DeformDETR head."""
self.transformer.init_weights()
if self.loss_cls.use_sigmoid:
bias_init = bias_init_with_prob(0.01)
for m in self.cls_branches:
nn.init.constant_(m[-1].bias, bias_init)
# for m in self.reg_branches:
# constant_init(m[-1], 0, bias=0)
# nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], 0.)
# @auto_fp16(apply_to=('mlvl_feats'))
@force_fp32(apply_to=('mlvl_feats', 'prev_bev'))
def forward(self, mlvl_feats, lidar_feat, img_metas, prev_bev=None, only_bev=False):
"""Forward function.
Args:
mlvl_feats (tuple[Tensor]): Features from the upstream
network, each is a 5D-tensor with shape
(B, N, C, H, W).
prev_bev: previous bev featues
only_bev: only compute BEV features with encoder.
Returns:
all_cls_scores (Tensor): Outputs from the classification head, \
shape [nb_dec, bs, num_query, cls_out_channels]. Note \
cls_out_channels should includes background.
all_bbox_preds (Tensor): Sigmoid outputs from the regression \
head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy). \
Shape [nb_dec, bs, num_query, 9].
"""
bs, num_cam, _, _, _ = mlvl_feats[0].shape
dtype = mlvl_feats[0].dtype
# import pdb;pdb.set_trace()
if self.query_embed_type == 'all_pts':
object_query_embeds = self.query_embedding.weight.to(dtype)
elif self.query_embed_type == 'instance_pts':
pts_embeds = self.pts_embedding.weight.unsqueeze(0)
instance_embeds = self.instance_embedding.weight.unsqueeze(1)
object_query_embeds = (pts_embeds + instance_embeds).flatten(0, 1).to(dtype)
if self.bev_embedding is not None:
bev_queries = self.bev_embedding.weight.to(dtype)
bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
device=bev_queries.device).to(dtype)
bev_pos = self.positional_encoding(bev_mask).to(dtype)
else:
bev_queries = None
bev_mask = None
bev_pos = None
if only_bev: # only use encoder to obtain BEV features, TODO: refine the workaround
return self.transformer.get_bev_features(
mlvl_feats,
lidar_feat,
bev_queries,
self.bev_h,
self.bev_w,
grid_length=(self.real_h / self.bev_h,
self.real_w / self.bev_w),
bev_pos=bev_pos,
img_metas=img_metas,
prev_bev=prev_bev,
)
else:
outputs = self.transformer(
mlvl_feats,
lidar_feat,
bev_queries,
object_query_embeds,
self.bev_h,
self.bev_w,
grid_length=(self.real_h / self.bev_h,
self.real_w / self.bev_w),
bev_pos=bev_pos,
reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501
cls_branches=self.cls_branches if self.as_two_stage else None,
img_metas=img_metas,
prev_bev=prev_bev
)
bev_embed, hs, init_reference, inter_references = outputs
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
outputs_pts_coords = []
for lvl in range(hs.shape[0]):
if lvl == 0:
# import pdb;pdb.set_trace()
reference = init_reference
else:
reference = inter_references[lvl - 1]
reference = inverse_sigmoid(reference)
# import pdb;pdb.set_trace()
# vec_embedding = hs[lvl].reshape(bs, self.num_vec, -1)
outputs_class = self.cls_branches[lvl](hs[lvl]
.view(bs,self.num_vec, self.num_pts_per_vec,-1)
.mean(2))
tmp = self.reg_branches[lvl](hs[lvl])
# TODO: check the shape of reference
assert reference.shape[-1] == 2
tmp[..., 0:2] += reference[..., 0:2]
# tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
tmp = tmp.sigmoid() # cx,cy,w,h
# import pdb;pdb.set_trace()
# tmp[..., 0:1] = (tmp[..., 0:1] * (self.pc_range[3] -
# self.pc_range[0]) + self.pc_range[0])
# tmp[..., 1:2] = (tmp[..., 1:2] * (self.pc_range[4] -
# self.pc_range[1]) + self.pc_range[1])
# tmp = tmp.reshape(bs, self.num_vec,-1)
# TODO: check if using sigmoid
outputs_coord, outputs_pts_coord = self.transform_box(tmp)
outputs_classes.append(outputs_class)
outputs_coords.append(outputs_coord)
outputs_pts_coords.append(outputs_pts_coord)
outputs_classes = torch.stack(outputs_classes)
outputs_coords = torch.stack(outputs_coords)
outputs_pts_coords = torch.stack(outputs_pts_coords)
outs = {
'bev_embed': bev_embed,
'all_cls_scores': outputs_classes,
'all_bbox_preds': outputs_coords,
'all_pts_preds': outputs_pts_coords,
'enc_cls_scores': None,
'enc_bbox_preds': None,
'enc_pts_preds': None
}
return outs
def transform_box(self, pts, y_first=False):
"""
Converting the points set into bounding box.
Args:
pts: the input points sets (fields), each points
set (fields) is represented as 2n scalar.
y_first: if y_fisrt=True, the point set is represented as
[y1, x1, y2, x2 ... yn, xn], otherwise the point set is
represented as [x1, y1, x2, y2 ... xn, yn].
Returns:
The bbox [cx, cy, w, h] transformed from points.
"""
pts_reshape = pts.view(pts.shape[0], self.num_vec,
self.num_pts_per_vec,2)
pts_y = pts_reshape[:, :, :, 0] if y_first else pts_reshape[:, :, :, 1]
pts_x = pts_reshape[:, :, :, 1] if y_first else pts_reshape[:, :, :, 0]
if self.transform_method == 'minmax':
# import pdb;pdb.set_trace()
xmin = pts_x.min(dim=2, keepdim=True)[0]
xmax = pts_x.max(dim=2, keepdim=True)[0]
ymin = pts_y.min(dim=2, keepdim=True)[0]
ymax = pts_y.max(dim=2, keepdim=True)[0]
bbox = torch.cat([xmin, ymin, xmax, ymax], dim=2)
bbox = bbox_xyxy_to_cxcywh(bbox)
else:
raise NotImplementedError
return bbox, pts_reshape
def _get_target_single(self,
cls_score,
bbox_pred,
pts_pred,
gt_labels,
gt_bboxes,
gt_shifts_pts,
gt_bboxes_ignore=None):
""""Compute regression and classification targets for one image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_score (Tensor): Box score logits from a single decoder layer
for one image. Shape [num_query, cls_out_channels].
bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
for one image, with normalized coordinate (cx, cy, w, h) and
shape [num_query, 4].
gt_bboxes (Tensor): Ground truth bboxes for one image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (Tensor): Ground truth class indices for one image
with shape (num_gts, ).
gt_bboxes_ignore (Tensor, optional): Bounding boxes
which can be ignored. Default None.
Returns:
tuple[Tensor]: a tuple containing the following for one image.
- labels (Tensor): Labels of each image.
- label_weights (Tensor]): Label weights of each image.
- bbox_targets (Tensor): BBox targets of each image.
- bbox_weights (Tensor): BBox weights of each image.
- pos_inds (Tensor): Sampled positive indices for each image.
- neg_inds (Tensor): Sampled negative indices for each image.
"""
# import pdb;pdb.set_trace()
num_bboxes = bbox_pred.size(0)
# assigner and sampler
gt_c = gt_bboxes.shape[-1]
# import pdb;pdb.set_trace()
assign_result, order_index = self.assigner.assign(bbox_pred, cls_score, pts_pred,
gt_bboxes, gt_labels, gt_shifts_pts,
gt_bboxes_ignore)
sampling_result = self.sampler.sample(assign_result, bbox_pred,
gt_bboxes)
# pts_sampling_result = self.sampler.sample(assign_result, pts_pred,
# gt_pts)
# import pdb;pdb.set_trace()
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
# label targets
labels = gt_bboxes.new_full((num_bboxes,),
self.num_classes,
dtype=torch.long)
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
label_weights = gt_bboxes.new_ones(num_bboxes)
# bbox targets
bbox_targets = torch.zeros_like(bbox_pred)[..., :gt_c]
bbox_weights = torch.zeros_like(bbox_pred)
bbox_weights[pos_inds] = 1.0
# pts targets
# import pdb;pdb.set_trace()
# pts_targets = torch.zeros_like(pts_pred)
# num_query, num_order, num_points, num_coords
if order_index is None:
# import pdb;pdb.set_trace()
assigned_shift = gt_labels[sampling_result.pos_assigned_gt_inds]
else:
assigned_shift = order_index[sampling_result.pos_inds, sampling_result.pos_assigned_gt_inds]
pts_targets = pts_pred.new_zeros((pts_pred.size(0),
pts_pred.size(1), pts_pred.size(2)))
pts_weights = torch.zeros_like(pts_targets)
pts_weights[pos_inds] = 1.0
# DETR
bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
pts_targets[pos_inds] = gt_shifts_pts[sampling_result.pos_assigned_gt_inds,assigned_shift,:,:]
return (labels, label_weights, bbox_targets, bbox_weights,
pts_targets, pts_weights,
pos_inds, neg_inds)
def get_targets(self,
cls_scores_list,
bbox_preds_list,
pts_preds_list,
gt_bboxes_list,
gt_labels_list,
gt_shifts_pts_list,
gt_bboxes_ignore_list=None):
""""Compute regression and classification targets for a batch image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_scores_list (list[Tensor]): Box score logits from a single
decoder layer for each image with shape [num_query,
cls_out_channels].
bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
decoder layer for each image, with normalized coordinate
(cx, cy, w, h) and shape [num_query, 4].
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
tuple: a tuple containing the following targets.
- labels_list (list[Tensor]): Labels for all images.
- label_weights_list (list[Tensor]): Label weights for all \
images.
- bbox_targets_list (list[Tensor]): BBox targets for all \
images.
- bbox_weights_list (list[Tensor]): BBox weights for all \
images.
- num_total_pos (int): Number of positive samples in all \
images.
- num_total_neg (int): Number of negative samples in all \
images.
"""
assert gt_bboxes_ignore_list is None, \
'Only supports for gt_bboxes_ignore setting to None.'
num_imgs = len(cls_scores_list)
gt_bboxes_ignore_list = [
gt_bboxes_ignore_list for _ in range(num_imgs)
]
(labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, pts_targets_list, pts_weights_list,
pos_inds_list, neg_inds_list) = multi_apply(
self._get_target_single, cls_scores_list, bbox_preds_list,pts_preds_list,
gt_labels_list, gt_bboxes_list, gt_shifts_pts_list, gt_bboxes_ignore_list)
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
return (labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, pts_targets_list, pts_weights_list,
num_total_pos, num_total_neg)
def loss_single(self,
cls_scores,
bbox_preds,
pts_preds,
gt_bboxes_list,
gt_labels_list,
gt_shifts_pts_list,
gt_bboxes_ignore_list=None):
""""Loss function for outputs from a single decoder layer of a single
feature level.
Args:
cls_scores (Tensor): Box score logits from a single decoder layer
for all images. Shape [bs, num_query, cls_out_channels].
bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
for all images, with normalized coordinate (cx, cy, w, h) and
shape [bs, num_query, 4].
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_pts_list (list[Tensor]): Ground truth pts for each image
with shape (num_gts, fixed_num, 2) in [x,y] format.
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components for outputs from
a single decoder layer.
"""
num_imgs = cls_scores.size(0)
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
pts_preds_list = [pts_preds[i] for i in range(num_imgs)]
# import pdb;pdb.set_trace()
cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,pts_preds_list,
gt_bboxes_list, gt_labels_list,gt_shifts_pts_list,
gt_bboxes_ignore_list)
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
pts_targets_list, pts_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
# import pdb;pdb.set_trace()
labels = torch.cat(labels_list, 0)
label_weights = torch.cat(label_weights_list, 0)
bbox_targets = torch.cat(bbox_targets_list, 0)
bbox_weights = torch.cat(bbox_weights_list, 0)
pts_targets = torch.cat(pts_targets_list, 0)
pts_weights = torch.cat(pts_weights_list, 0)
# classification loss
cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor = num_total_pos * 1.0 + \
num_total_neg * self.bg_cls_weight
if self.sync_cls_avg_factor:
cls_avg_factor = reduce_mean(
cls_scores.new_tensor([cls_avg_factor]))
cls_avg_factor = max(cls_avg_factor, 1)
loss_cls = self.loss_cls(
cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
# Compute the average number of gt boxes accross all gpus, for
# normalization purposes
num_total_pos = loss_cls.new_tensor([num_total_pos])
num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
# import pdb;pdb.set_trace()
# regression L1 loss
bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
normalized_bbox_targets = normalize_2d_bbox(bbox_targets, self.pc_range)
# normalized_bbox_targets = bbox_targets
isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
bbox_weights = bbox_weights * self.code_weights
loss_bbox = self.loss_bbox(
bbox_preds[isnotnan, :4], normalized_bbox_targets[isnotnan,
:4], bbox_weights[isnotnan, :4],
avg_factor=num_total_pos)
# regression pts CD loss
# pts_preds = pts_preds
# import pdb;pdb.set_trace()
# num_samples, num_order, num_pts, num_coords
normalized_pts_targets = normalize_2d_pts(pts_targets, self.pc_range)
# num_samples, num_pts, num_coords
pts_preds = pts_preds.reshape(-1, pts_preds.size(-2),pts_preds.size(-1))
if self.num_pts_per_vec != self.num_pts_per_gt_vec:
pts_preds = pts_preds.permute(0,2,1)
pts_preds = F.interpolate(pts_preds, size=(self.num_pts_per_gt_vec), mode='linear',
align_corners=True)
pts_preds = pts_preds.permute(0,2,1).contiguous()
# import pdb;pdb.set_trace()
loss_pts = self.loss_pts(
pts_preds[isnotnan,:,:], normalized_pts_targets[isnotnan,
:,:],
pts_weights[isnotnan,:,:],
avg_factor=num_total_pos)
dir_weights = pts_weights[:, :-self.dir_interval,0]
denormed_pts_preds = denormalize_2d_pts(pts_preds, self.pc_range)
denormed_pts_preds_dir = denormed_pts_preds[:,self.dir_interval:,:] - denormed_pts_preds[:,:-self.dir_interval,:]
pts_targets_dir = pts_targets[:, self.dir_interval:,:] - pts_targets[:,:-self.dir_interval,:]
# dir_weights = pts_weights[:, indice,:-1,0]
# import pdb;pdb.set_trace()
loss_dir = self.loss_dir(
denormed_pts_preds_dir[isnotnan,:,:], pts_targets_dir[isnotnan,
:,:],
dir_weights[isnotnan,:],
avg_factor=num_total_pos)
bboxes = denormalize_2d_bbox(bbox_preds, self.pc_range)
# regression IoU loss, defaultly GIoU loss
loss_iou = self.loss_iou(
bboxes[isnotnan, :4], bbox_targets[isnotnan, :4], bbox_weights[isnotnan, :4],
avg_factor=num_total_pos)
if digit_version(TORCH_VERSION) >= digit_version('1.8'):
loss_cls = torch.nan_to_num(loss_cls)
loss_bbox = torch.nan_to_num(loss_bbox)
loss_iou = torch.nan_to_num(loss_iou)
loss_pts = torch.nan_to_num(loss_pts)
loss_dir = torch.nan_to_num(loss_dir)
return loss_cls, loss_bbox, loss_iou, loss_pts, loss_dir
@force_fp32(apply_to=('preds_dicts'))
def loss(self,
gt_bboxes_list,
gt_labels_list,
preds_dicts,
gt_bboxes_ignore=None,
img_metas=None):
""""Loss function.
Args:
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
preds_dicts:
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_bbox_preds (Tensor): Sigmoid regression
outputs of all decode layers. Each is a 4D-tensor with
normalized coordinate format (cx, cy, w, h) and shape
[nb_dec, bs, num_query, 4].
enc_cls_scores (Tensor): Classification scores of
points on encode feature map , has shape
(N, h*w, num_classes). Only be passed when as_two_stage is
True, otherwise is None.
enc_bbox_preds (Tensor): Regression results of each points
on the encode feature map, has shape (N, h*w, 4). Only be
passed when as_two_stage is True, otherwise is None.
gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert gt_bboxes_ignore is None, \
f'{self.__class__.__name__} only supports ' \
f'for gt_bboxes_ignore setting to None.'
gt_vecs_list = copy.deepcopy(gt_bboxes_list)
# import pdb;pdb.set_trace()
all_cls_scores = preds_dicts['all_cls_scores']
all_bbox_preds = preds_dicts['all_bbox_preds']
all_pts_preds = preds_dicts['all_pts_preds']
enc_cls_scores = preds_dicts['enc_cls_scores']
enc_bbox_preds = preds_dicts['enc_bbox_preds']
enc_pts_preds = preds_dicts['enc_pts_preds']
num_dec_layers = len(all_cls_scores)
device = gt_labels_list[0].device
# gt_bboxes_list = [torch.cat(
# (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),
# dim=1).to(device) for gt_bboxes in gt_bboxes_list]
# import pdb;pdb.set_trace()
# gt_bboxes_list = [
# gt_bboxes.to(device) for gt_bboxes in gt_bboxes_list]
gt_bboxes_list = [
gt_bboxes.bbox.to(device) for gt_bboxes in gt_vecs_list]
gt_pts_list = [
gt_bboxes.fixed_num_sampled_points.to(device) for gt_bboxes in gt_vecs_list]
if self.gt_shift_pts_pattern == 'v0':
gt_shifts_pts_list = [
gt_bboxes.shift_fixed_num_sampled_points.to(device) for gt_bboxes in gt_vecs_list]
elif self.gt_shift_pts_pattern == 'v1':
gt_shifts_pts_list = [
gt_bboxes.shift_fixed_num_sampled_points_v1.to(device) for gt_bboxes in gt_vecs_list]
elif self.gt_shift_pts_pattern == 'v2':
gt_shifts_pts_list = [
gt_bboxes.shift_fixed_num_sampled_points_v2.to(device) for gt_bboxes in gt_vecs_list]
elif self.gt_shift_pts_pattern == 'v3':
gt_shifts_pts_list = [
gt_bboxes.shift_fixed_num_sampled_points_v3.to(device) for gt_bboxes in gt_vecs_list]
elif self.gt_shift_pts_pattern == 'v4':
gt_shifts_pts_list = [
gt_bboxes.shift_fixed_num_sampled_points_v4.to(device) for gt_bboxes in gt_vecs_list]
else:
raise NotImplementedError
all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
all_gt_pts_list = [gt_pts_list for _ in range(num_dec_layers)]
all_gt_shifts_pts_list = [gt_shifts_pts_list for _ in range(num_dec_layers)]
all_gt_bboxes_ignore_list = [
gt_bboxes_ignore for _ in range(num_dec_layers)
]
# import pdb;pdb.set_trace()
losses_cls, losses_bbox, losses_iou, losses_pts, losses_dir = multi_apply(
self.loss_single, all_cls_scores, all_bbox_preds,all_pts_preds,
all_gt_bboxes_list, all_gt_labels_list,all_gt_shifts_pts_list,
all_gt_bboxes_ignore_list)
loss_dict = dict()
# loss of proposal generated from encode feature map.
if enc_cls_scores is not None:
binary_labels_list = [
torch.zeros_like(gt_labels_list[i])
for i in range(len(all_gt_labels_list))
]
# TODO bug here
enc_loss_cls, enc_losses_bbox, enc_losses_iou, enc_losses_pts, enc_losses_dir = \
self.loss_single(enc_cls_scores, enc_bbox_preds, enc_pts_preds,
gt_bboxes_list, binary_labels_list, gt_pts_list,gt_bboxes_ignore)
loss_dict['enc_loss_cls'] = enc_loss_cls
loss_dict['enc_loss_bbox'] = enc_losses_bbox
loss_dict['enc_losses_iou'] = enc_losses_iou
loss_dict['enc_losses_pts'] = enc_losses_pts
loss_dict['enc_losses_dir'] = enc_losses_dir
# loss from the last decoder layer
loss_dict['loss_cls'] = losses_cls[-1]
loss_dict['loss_bbox'] = losses_bbox[-1]
loss_dict['loss_iou'] = losses_iou[-1]
loss_dict['loss_pts'] = losses_pts[-1]
loss_dict['loss_dir'] = losses_dir[-1]
# loss from other decoder layers
num_dec_layer = 0
for loss_cls_i, loss_bbox_i, loss_iou_i, loss_pts_i, loss_dir_i in zip(losses_cls[:-1],
losses_bbox[:-1],
losses_iou[:-1],
losses_pts[:-1],
losses_dir[:-1]):
loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
loss_dict[f'd{num_dec_layer}.loss_pts'] = loss_pts_i
loss_dict[f'd{num_dec_layer}.loss_dir'] = loss_dir_i
num_dec_layer += 1
return loss_dict
@force_fp32(apply_to=('preds_dicts'))
def get_bboxes(self, preds_dicts, img_metas, rescale=False):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
img_metas (list[dict]): Point cloud and image's meta info.
Returns:
list[dict]: Decoded bbox, scores and labels after nms.
"""
# bboxes: xmin, ymin, xmax, ymax
preds_dicts = self.bbox_coder.decode(preds_dicts)
num_samples = len(preds_dicts)
ret_list = []
for i in range(num_samples):
preds = preds_dicts[i]
bboxes = preds['bboxes']
# bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
# code_size = bboxes.shape[-1]
# bboxes = img_metas[i]['box_type_3d'](bboxes, code_size)
scores = preds['scores']
labels = preds['labels']
pts = preds['pts']
ret_list.append([bboxes, scores, labels, pts])
return ret_list
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import HEADS, build_loss
from mmdet.models.dense_heads import DETRHead
from mmdet3d.core.bbox.coders import build_bbox_coder
from mmcv.runner import force_fp32, auto_fp16
from mmcv.cnn import Linear, bias_init_with_prob, xavier_init, constant_init
from mmdet.models.utils.transformer import inverse_sigmoid
from mmdet.core.bbox.transforms import bbox_xyxy_to_cxcywh, bbox_cxcywh_to_xyxy
from mmdet.core import (multi_apply, multi_apply, reduce_mean)
from mmcv.utils import TORCH_VERSION, digit_version
def denormalize_3d_pts(pts, pc_range):
new_pts = pts.clone()
new_pts[...,0:1] = (pts[..., 0:1]*(pc_range[3] -
pc_range[0]) + pc_range[0])
new_pts[...,1:2] = (pts[...,1:2]*(pc_range[4] -
pc_range[1]) + pc_range[1])
new_pts[...,2:3] = (pts[...,2:3]*(pc_range[5] -
pc_range[2]) + pc_range[2])
return new_pts
#@torch.compile(mode="max-autotune-no-cudagraphs")
def normalize_3d_pts(pts, pc_range):
patch_h = pc_range[4]-pc_range[1]
patch_w = pc_range[3]-pc_range[0]
patch_z = pc_range[5]-pc_range[2]
new_pts = pts.clone()
new_pts[...,0:1] = pts[..., 0:1] - pc_range[0]
new_pts[...,1:2] = pts[...,1:2] - pc_range[1]
new_pts[...,2:3] = pts[...,2:3] - pc_range[2]
factor = pts.new_tensor([patch_w, patch_h,patch_z])
normalized_pts = new_pts / factor
return normalized_pts
#@torch.compile(mode="max-autotune-no-cudagraphs")
def normalize_2d_bbox(bboxes, pc_range):
patch_h = pc_range[4]-pc_range[1]
patch_w = pc_range[3]-pc_range[0]
cxcywh_bboxes = bbox_xyxy_to_cxcywh(bboxes)
cxcywh_bboxes[...,0:1] = cxcywh_bboxes[..., 0:1] - pc_range[0]
cxcywh_bboxes[...,1:2] = cxcywh_bboxes[...,1:2] - pc_range[1]
factor = bboxes.new_tensor([patch_w, patch_h,patch_w,patch_h])
normalized_bboxes = cxcywh_bboxes / factor
return normalized_bboxes
#@torch.compile(mode="max-autotune-no-cudagraphs")
def normalize_2d_pts(pts, pc_range):
patch_h = pc_range[4]-pc_range[1]
patch_w = pc_range[3]-pc_range[0]
new_pts = pts.clone()
new_pts[...,0:1] = pts[..., 0:1] - pc_range[0]
new_pts[...,1:2] = pts[...,1:2] - pc_range[1]
factor = pts.new_tensor([patch_w, patch_h])
normalized_pts = new_pts / factor
return normalized_pts
#@torch.compile(mode="max-autotune-no-cudagraphs")
def denormalize_2d_bbox(bboxes, pc_range):
bboxes = bbox_cxcywh_to_xyxy(bboxes)
bboxes[..., 0::2] = (bboxes[..., 0::2]*(pc_range[3] -
pc_range[0]) + pc_range[0])
bboxes[..., 1::2] = (bboxes[..., 1::2]*(pc_range[4] -
pc_range[1]) + pc_range[1])
return bboxes
#@torch.compile(mode="max-autotune-no-cudagraphs")
def denormalize_2d_pts(pts, pc_range):
new_pts = pts.clone()
new_pts[...,0:1] = (pts[..., 0:1]*(pc_range[3] -
pc_range[0]) + pc_range[0])
new_pts[...,1:2] = (pts[...,1:2]*(pc_range[4] -
pc_range[1]) + pc_range[1])
return new_pts
@HEADS.register_module()
class MapTRv2Head(DETRHead):
"""Head of Detr3D.
Args:
with_box_refine (bool): Whether to refine the reference points
in the decoder. Defaults to False.
as_two_stage (bool) : Whether to generate the proposal from
the outputs of encoder.
transformer (obj:`ConfigDict`): ConfigDict is used for building
the Encoder and Decoder.
bev_h, bev_w (int): spatial shape of BEV queries.
"""
def __init__(self,
*args,
with_box_refine=False,
as_two_stage=False,
transformer=None,
bbox_coder=None,
num_cls_fcs=2,
code_weights=None,
bev_h=30,
bev_w=30,
# num_vec=20,
num_vec_one2one=50,
num_vec_one2many=0,
k_one2many=0,
lambda_one2many=1,
num_pts_per_vec=2,
num_pts_per_gt_vec=2,
query_embed_type='all_pts',
transform_method='minmax',
gt_shift_pts_pattern='v0',
dir_interval=1,
aux_seg = dict(
use_aux_seg=False,
bev_seg=False,
pv_seg=False,
seg_classes=1,
feat_down_sample=32,
),
z_cfg = dict(
pred_z_flag=False,
gt_z_flag=False,
),
loss_pts=dict(type='ChamferDistance',
loss_src_weight=1.0,
loss_dst_weight=1.0),
loss_seg=dict(type='SimpleLoss',
pos_weight=2.13,
loss_weight=1.0),
loss_pv_seg=dict(type='SimpleLoss',
pos_weight=2.13,
loss_weight=1.0),
loss_dir=dict(type='PtsDirCosLoss', loss_weight=2.0),
**kwargs):
self.bev_h = bev_h
self.bev_w = bev_w
self.fp16_enabled = False
self.with_box_refine = with_box_refine
self.as_two_stage = as_two_stage
self.bev_encoder_type = transformer.encoder.type
if self.as_two_stage:
transformer['as_two_stage'] = self.as_two_stage
if 'code_size' in kwargs:
self.code_size = 2 if not z_cfg['pred_z_flag'] else 3
else:
self.code_size = 2
if code_weights is not None:
self.code_weights = code_weights
else:
self.code_weights = [1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2]
self.bbox_coder = build_bbox_coder(bbox_coder)
self.pc_range = self.bbox_coder.pc_range
self.real_w = self.pc_range[3] - self.pc_range[0]
self.real_h = self.pc_range[4] - self.pc_range[1]
self.num_cls_fcs = num_cls_fcs - 1
self.query_embed_type = query_embed_type
self.transform_method = transform_method
self.gt_shift_pts_pattern = gt_shift_pts_pattern
num_vec = num_vec_one2one + num_vec_one2many
num_query = num_vec * num_pts_per_vec
self.num_query = num_query
self.num_vec = num_vec
self.num_pts_per_vec = num_pts_per_vec
self.num_pts_per_gt_vec = num_pts_per_gt_vec
self.dir_interval = dir_interval
self.aux_seg = aux_seg
self.z_cfg = z_cfg
super(MapTRv2Head, self).__init__(
*args, transformer=transformer, **kwargs)
self.code_weights = nn.Parameter(torch.tensor(
self.code_weights, requires_grad=False), requires_grad=False)
self.loss_pts = build_loss(loss_pts)
self.loss_dir = build_loss(loss_dir)
num_query = num_vec * num_pts_per_vec
self.num_query = num_query
self.num_vec = num_vec
self.num_pts_per_vec = num_pts_per_vec
self.num_pts_per_gt_vec = num_pts_per_gt_vec
self.num_vec_one2one = num_vec_one2one
self.num_vec_one2many = num_vec_one2many
self.k_one2many = k_one2many
self.lambda_one2many=lambda_one2many
self.loss_seg = build_loss(loss_seg)
self.loss_pv_seg = build_loss(loss_pv_seg)
self._init_layers()
def _init_layers(self):
"""Initialize classification branch and regression branch of head."""
cls_branch = []
# cls_branch.append(Linear(self.embed_dims * 2, self.embed_dims))
# cls_branch.append(nn.LayerNorm(self.embed_dims))
# cls_branch.append(nn.ReLU(inplace=True))
for _ in range(self.num_reg_fcs):
cls_branch.append(Linear(self.embed_dims, self.embed_dims))
cls_branch.append(nn.LayerNorm(self.embed_dims))
cls_branch.append(nn.ReLU(inplace=True))
cls_branch.append(Linear(self.embed_dims, self.cls_out_channels))
fc_cls = nn.Sequential(*cls_branch)
reg_branch = []
for _ in range(self.num_reg_fcs):
reg_branch.append(Linear(self.embed_dims, self.embed_dims))
reg_branch.append(nn.ReLU())
reg_branch.append(Linear(self.embed_dims, self.code_size))
reg_branch = nn.Sequential(*reg_branch)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
# last reg_branch is used to generate proposal from
# encode feature map when as_two_stage is True.
num_pred = (self.transformer.decoder.num_layers + 1) if \
self.as_two_stage else self.transformer.decoder.num_layers
if self.with_box_refine:
self.cls_branches = _get_clones(fc_cls, num_pred)
self.reg_branches = _get_clones(reg_branch, num_pred)
else:
self.cls_branches = nn.ModuleList(
[fc_cls for _ in range(num_pred)])
self.reg_branches = nn.ModuleList(
[reg_branch for _ in range(num_pred)])
if self.aux_seg['use_aux_seg']:
if not (self.aux_seg['bev_seg'] or self.aux_seg['pv_seg']):
raise ValueError('aux_seg must have bev_seg or pv_seg')
if self.aux_seg['bev_seg']:
self.seg_head = nn.Sequential(
nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=3, padding=1, bias=False),
# nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(self.embed_dims, self.aux_seg['seg_classes'], kernel_size=1, padding=0)
)
if self.aux_seg['pv_seg']:
self.pv_seg_head = nn.Sequential(
nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=3, padding=1, bias=False),
# nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(self.embed_dims, self.aux_seg['seg_classes'], kernel_size=1, padding=0)
)
if not self.as_two_stage:
if 'BEVFormerEncoder' in self.bev_encoder_type:
self.bev_embedding = nn.Embedding(
self.bev_h * self.bev_w, self.embed_dims)
else:
self.bev_embedding = None
if self.query_embed_type == 'all_pts':
self.query_embedding = nn.Embedding(self.num_query,
self.embed_dims * 2)
elif self.query_embed_type == 'instance_pts':
self.query_embedding = None
self.instance_embedding = nn.Embedding(self.num_vec, self.embed_dims * 2)
self.pts_embedding = nn.Embedding(self.num_pts_per_vec, self.embed_dims * 2)
def init_weights(self):
"""Initialize weights of the DeformDETR head."""
self.transformer.init_weights()
if self.loss_cls.use_sigmoid:
bias_init = bias_init_with_prob(0.01)
for m in self.cls_branches:
nn.init.constant_(m[-1].bias, bias_init)
# for m in self.reg_branches:
# constant_init(m[-1], 0, bias=0)
# nn.init.constant_(self.reg_branches[0][-1].bias.data[2:], 0.)
#@torch.compile(mode="max-autotune-no-cudagraphs") ####
def compute_decoder_predictions(self, outputs, bs, num_vec, mlvl_feats):
bev_embed,depth, hs, init_reference, inter_references = outputs
hs = hs.permute(0, 2, 1, 3)
outputs_classes_one2one = []
outputs_coords_one2one = []
outputs_pts_coords_one2one = []
outputs_classes_one2many = []
outputs_coords_one2many = []
outputs_pts_coords_one2many = []
for lvl in range(hs.shape[0]):
if lvl == 0:
# import pdb;pdb.set_trace()
reference = init_reference[...,0:2] if not self.z_cfg['gt_z_flag'] else init_reference[...,0:3]
else:
reference = inter_references[lvl - 1][...,0:2] if not self.z_cfg['gt_z_flag'] else inter_references[lvl - 1][...,0:3]
reference = inverse_sigmoid(reference)
# import pdb;pdb.set_trace()
# vec_embedding = hs[lvl].reshape(bs, self.num_vec, -1)
outputs_class = self.cls_branches[lvl](hs[lvl]
.view(bs,num_vec, self.num_pts_per_vec,-1)
.mean(2))
tmp = self.reg_branches[lvl](hs[lvl])
tmp = tmp[..., 0:2] if not self.z_cfg['gt_z_flag'] else tmp[..., 0:3]
# TODO: check the shape of reference
# assert reference.shape[-1] == 2
# tmp[..., 0:2] += reference[..., 0:2]
# assert reference.shape[-1] == 2
tmp += reference
tmp = tmp.sigmoid() # cx,cy,w,h
# if not self.z_cfg['gt_z_flag']:
# tmp = tmp[..., 0:2] if not self.z_cfg['gt_z_flag'] else tmp[..., 0:3]
# TODO: check if using sigmoid
outputs_coord, outputs_pts_coord = self.transform_box(tmp,num_vec=num_vec)
outputs_classes_one2one.append(outputs_class[:, 0:self.num_vec_one2one])
outputs_coords_one2one.append(outputs_coord[:, 0:self.num_vec_one2one])
outputs_pts_coords_one2one.append(outputs_pts_coord[:, 0:self.num_vec_one2one])
outputs_classes_one2many.append(outputs_class[:, self.num_vec_one2one:])
outputs_coords_one2many.append(outputs_coord[:, self.num_vec_one2one:])
outputs_pts_coords_one2many.append(outputs_pts_coord[:, self.num_vec_one2one:])
outputs_seg = None
outputs_pv_seg = None
if self.aux_seg['use_aux_seg']:
seg_bev_embed = bev_embed.permute(1,0,2).view(bs,self.bev_h, self.bev_w, -1).permute(0,3,1,2).contiguous()
if self.aux_seg['bev_seg']:
outputs_seg = self.seg_head(seg_bev_embed)
bs, num_cam, embed_dims, feat_h, feat_w = mlvl_feats[-1].shape
if self.aux_seg['pv_seg']:
outputs_pv_seg = self.pv_seg_head(mlvl_feats[-1].flatten(0,1))
outputs_pv_seg = outputs_pv_seg.view(bs, num_cam, -1, feat_h, feat_w)
return bev_embed, outputs_classes_one2one, outputs_coords_one2one, outputs_pts_coords_one2one, depth, outputs_seg, outputs_pv_seg, outputs_classes_one2many, outputs_coords_one2many, outputs_pts_coords_one2many
#@torch.compile(mode="max-autotune-no-cudagraphs")
def prepare_transformer_inputs(self, mlvl_feats):
if self.training:
num_vec = self.num_vec
else:
num_vec = self.num_vec_one2one
bs, num_cam, _, _, _ = mlvl_feats[0].shape
dtype = mlvl_feats[0].dtype
if self.query_embed_type == 'all_pts':
object_query_embeds = self.query_embedding.weight.to(dtype)
elif self.query_embed_type == 'instance_pts':
pts_embeds = self.pts_embedding.weight.unsqueeze(0)
instance_embeds = self.instance_embedding.weight[0:num_vec].unsqueeze(1)
object_query_embeds = (pts_embeds + instance_embeds).flatten(0, 1).to(dtype)
if self.bev_embedding is not None:
bev_queries = self.bev_embedding.weight.to(dtype)
bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
device=bev_queries.device).to(dtype)
bev_pos = self.positional_encoding(bev_mask).to(dtype)
else:
bev_queries = None
bev_mask = None
bev_pos = None
# make attn mask
""" attention mask to prevent information leakage
"""
self_attn_mask = (
torch.zeros([num_vec, num_vec,]).bool().to(mlvl_feats[0].device)
)
self_attn_mask[self.num_vec_one2one :, 0 : self.num_vec_one2one,] = True
self_attn_mask[0 : self.num_vec_one2one, self.num_vec_one2one :,] = True
return num_vec, object_query_embeds, bev_queries, bev_pos, self_attn_mask, bs
# @auto_fp16(apply_to=('mlvl_feats'))
@force_fp32(apply_to=('mlvl_feats', 'prev_bev'))
def forward(self, mlvl_feats, lidar_feat, img_metas, prev_bev=None, only_bev=False):
"""Forward function.
Args:
mlvl_feats (tuple[Tensor]): Features from the upstream
network, each is a 5D-tensor with shape
(B, N, C, H, W).
prev_bev: previous bev featues
only_bev: only compute BEV features with encoder.
Returns:
all_cls_scores (Tensor): Outputs from the classification head, \
shape [nb_dec, bs, num_query, cls_out_channels]. Note \
cls_out_channels should includes background.
all_bbox_preds (Tensor): Sigmoid outputs from the regression \
head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy). \
Shape [nb_dec, bs, num_query, 9].
"""
# if self.training:
# num_vec = self.num_vec
# else:
# num_vec = self.num_vec_one2one
# # import ipdb;ipdb.set_trace()
# bs, num_cam, _, _, _ = mlvl_feats[0].shape
# dtype = mlvl_feats[0].dtype
# # import ipdb;ipdb.set_trace()
# if self.query_embed_type == 'all_pts':
# object_query_embeds = self.query_embedding.weight.to(dtype)
# elif self.query_embed_type == 'instance_pts':
# pts_embeds = self.pts_embedding.weight.unsqueeze(0)
# instance_embeds = self.instance_embedding.weight[0:num_vec].unsqueeze(1)
# object_query_embeds = (pts_embeds + instance_embeds).flatten(0, 1).to(dtype)
# if self.bev_embedding is not None:
# bev_queries = self.bev_embedding.weight.to(dtype)
# bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
# device=bev_queries.device).to(dtype)
# bev_pos = self.positional_encoding(bev_mask).to(dtype)
# else:
# bev_queries = None
# bev_mask = None
# bev_pos = None
# # make attn mask
# """ attention mask to prevent information leakage
# """
# self_attn_mask = (
# torch.zeros([num_vec, num_vec,]).bool().to(mlvl_feats[0].device)
# )
# self_attn_mask[self.num_vec_one2one :, 0 : self.num_vec_one2one,] = True
# self_attn_mask[0 : self.num_vec_one2one, self.num_vec_one2one :,] = True
num_vec, object_query_embeds, bev_queries, bev_pos, self_attn_mask, bs = self.prepare_transformer_inputs(mlvl_feats)
if only_bev: # only use encoder to obtain BEV features, TODO: refine the workaround
return self.transformer.get_bev_features(
mlvl_feats,
lidar_feat,
bev_queries,
self.bev_h,
self.bev_w,
grid_length=(self.real_h / self.bev_h,
self.real_w / self.bev_w),
bev_pos=bev_pos,
img_metas=img_metas,
prev_bev=prev_bev,
)['bev']
else:
outputs = self.transformer(
mlvl_feats,
lidar_feat,
bev_queries,
object_query_embeds,
self.bev_h,
self.bev_w,
grid_length=(self.real_h / self.bev_h,
self.real_w / self.bev_w),
bev_pos=bev_pos,
reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501
cls_branches=self.cls_branches if self.as_two_stage else None,
img_metas=img_metas,
prev_bev=prev_bev,
self_attn_mask=self_attn_mask,
num_vec=num_vec,
num_pts_per_vec=self.num_pts_per_vec,
)
bev_embed, outputs_classes_one2one, outputs_coords_one2one, outputs_pts_coords_one2one, depth, outputs_seg, outputs_pv_seg, outputs_classes_one2many, outputs_coords_one2many, outputs_pts_coords_one2many = self.compute_decoder_predictions(outputs, bs, num_vec, mlvl_feats)
outputs_classes_one2one = torch.stack(outputs_classes_one2one)
outputs_coords_one2one = torch.stack(outputs_coords_one2one)
outputs_pts_coords_one2one = torch.stack(outputs_pts_coords_one2one)
outputs_classes_one2many = torch.stack(outputs_classes_one2many)
outputs_coords_one2many = torch.stack(outputs_coords_one2many)
outputs_pts_coords_one2many = torch.stack(outputs_pts_coords_one2many)
outs = {
'bev_embed': bev_embed,
'all_cls_scores': outputs_classes_one2one,
'all_bbox_preds': outputs_coords_one2one,
'all_pts_preds': outputs_pts_coords_one2one,
'enc_cls_scores': None,
'enc_bbox_preds': None,
'enc_pts_preds': None,
'depth': depth,
'seg': outputs_seg,
'pv_seg': outputs_pv_seg,
"one2many_outs": dict(
all_cls_scores=outputs_classes_one2many,
all_bbox_preds=outputs_coords_one2many,
all_pts_preds=outputs_pts_coords_one2many,
enc_cls_scores=None,
enc_bbox_preds=None,
enc_pts_preds=None,
seg=None,
pv_seg=None,
)
}
return outs
def transform_box(self, pts, num_vec=50, y_first=False):
"""
Converting the points set into bounding box.
Args:
pts: the input points sets (fields), each points
set (fields) is represented as 2n scalar.
y_first: if y_fisrt=True, the point set is represented as
[y1, x1, y2, x2 ... yn, xn], otherwise the point set is
represented as [x1, y1, x2, y2 ... xn, yn].
Returns:
The bbox [cx, cy, w, h] transformed from points.
"""
if self.z_cfg['gt_z_flag']:
pts_reshape = pts.view(pts.shape[0], num_vec,
self.num_pts_per_vec,3)
else:
pts_reshape = pts.view(pts.shape[0], num_vec,
self.num_pts_per_vec,2)
pts_y = pts_reshape[:, :, :, 0] if y_first else pts_reshape[:, :, :, 1]
pts_x = pts_reshape[:, :, :, 1] if y_first else pts_reshape[:, :, :, 0]
if self.transform_method == 'minmax':
# import pdb;pdb.set_trace()
xmin = pts_x.min(dim=2, keepdim=True)[0]
xmax = pts_x.max(dim=2, keepdim=True)[0]
ymin = pts_y.min(dim=2, keepdim=True)[0]
ymax = pts_y.max(dim=2, keepdim=True)[0]
bbox = torch.cat([xmin, ymin, xmax, ymax], dim=2)
bbox = bbox_xyxy_to_cxcywh(bbox)
else:
raise NotImplementedError
return bbox, pts_reshape
def get_label_result(self, sampling_result, gt_bboxes, gt_labels, bbox_pred, order_index, pts_pred, gt_shifts_pts):
num_bboxes = bbox_pred.size(0)
gt_c = gt_bboxes.shape[-1]
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
pos_assigned_gt_inds = sampling_result.pos_assigned_gt_inds
pos_gt_bboxes = sampling_result.pos_gt_bboxes
# label targets
labels = gt_bboxes.new_full((num_bboxes,), self.num_classes, dtype=torch.long)
labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
label_weights = gt_bboxes.new_ones(num_bboxes)
# bbox targets
bbox_targets = torch.zeros_like(bbox_pred)[..., :gt_c]
bbox_weights = torch.zeros_like(bbox_pred)
bbox_weights[pos_inds] = 1.0
if order_index is None:
assigned_shift = gt_labels[pos_assigned_gt_inds]
else:
assigned_shift = order_index[pos_inds, pos_assigned_gt_inds]
pts_targets = pts_pred.new_zeros((pts_pred.size(0),
pts_pred.size(1), pts_pred.size(2)))
pts_weights = torch.zeros_like(pts_targets)
pts_weights[pos_inds] = 1.0
# DETR
bbox_targets[pos_inds] = pos_gt_bboxes
pts_targets[pos_inds] = gt_shifts_pts[pos_assigned_gt_inds, assigned_shift,:,:]
return labels, label_weights, bbox_targets, bbox_weights, pts_targets, pts_weights, pos_inds, neg_inds
def _get_target_single(self,
cls_score,
bbox_pred,
pts_pred,
gt_labels,
gt_bboxes,
gt_shifts_pts,
gt_bboxes_ignore=None):
""""Compute regression and classification targets for one image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_score (Tensor): Box score logits from a single decoder layer
for one image. Shape [num_query, cls_out_channels].
bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
for one image, with normalized coordinate (cx, cy, w, h) and
shape [num_query, 4].
gt_bboxes (Tensor): Ground truth bboxes for one image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (Tensor): Ground truth class indices for one image
with shape (num_gts, ).
gt_bboxes_ignore (Tensor, optional): Bounding boxes
which can be ignored. Default None.
Returns:
tuple[Tensor]: a tuple containing the following for one image.
- labels (Tensor): Labels of each image.
- label_weights (Tensor]): Label weights of each image.
- bbox_targets (Tensor): BBox targets of each image.
- bbox_weights (Tensor): BBox weights of each image.
- pos_inds (Tensor): Sampled positive indices for each image.
- neg_inds (Tensor): Sampled negative indices for each image.
"""
# num_bboxes = bbox_pred.size(0)
# # assigner and sampler
# gt_c = gt_bboxes.shape[-1]
assign_result, order_index = self.assigner.assign(bbox_pred, cls_score, pts_pred, gt_bboxes, gt_labels, gt_shifts_pts, gt_bboxes_ignore)
sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes[0])
labels, label_weights, bbox_targets, bbox_weights, pts_targets, pts_weights, pos_inds, neg_inds = self.get_label_result(sampling_result, gt_bboxes[0], gt_labels[0], bbox_pred, order_index, pts_pred, gt_shifts_pts[0])
# sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes)
# labels, label_weights, bbox_targets, bbox_weights, pts_targets, pts_weights, pos_inds, neg_inds = self.get_label_result(sampling_result, gt_bboxes, gt_labels, bbox_pred, order_index, pts_pred, gt_shifts_pts)
return (labels, label_weights, bbox_targets, bbox_weights,
pts_targets, pts_weights,
pos_inds, neg_inds)
def get_targets(self,
cls_scores_list,
bbox_preds_list,
pts_preds_list,
gt_bboxes_list,
gt_labels_list,
gt_shifts_pts_list,
gt_bboxes_ignore_list=None):
""""Compute regression and classification targets for a batch image.
Outputs from a single decoder layer of a single feature level are used.
Args:
cls_scores_list (list[Tensor]): Box score logits from a single
decoder layer for each image with shape [num_query,
cls_out_channels].
bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
decoder layer for each image, with normalized coordinate
(cx, cy, w, h) and shape [num_query, 4].
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
tuple: a tuple containing the following targets.
- labels_list (list[Tensor]): Labels for all images.
- label_weights_list (list[Tensor]): Label weights for all \
images.
- bbox_targets_list (list[Tensor]): BBox targets for all \
images.
- bbox_weights_list (list[Tensor]): BBox weights for all \
images.
- num_total_pos (int): Number of positive samples in all \
images.
- num_total_neg (int): Number of negative samples in all \
images.
"""
assert gt_bboxes_ignore_list is None, \
'Only supports for gt_bboxes_ignore setting to None.'
num_imgs = len(cls_scores_list)
gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)]
(labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, pts_targets_list, pts_weights_list,
pos_inds_list, neg_inds_list) = multi_apply(
self._get_target_single, cls_scores_list, bbox_preds_list,pts_preds_list,
gt_labels_list, gt_bboxes_list, gt_shifts_pts_list, gt_bboxes_ignore_list)
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
return (labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, pts_targets_list, pts_weights_list,
num_total_pos, num_total_neg)
def loss_single(self,
cls_scores,
bbox_preds,
pts_preds,
gt_bboxes_list,
gt_labels_list,
gt_shifts_pts_list,
gt_bboxes_ignore_list=None):
""""Loss function for outputs from a single decoder layer of a single
feature level.
Args:
cls_scores (Tensor): Box score logits from a single decoder layer
for all images. Shape [bs, num_query, cls_out_channels].
bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
for all images, with normalized coordinate (cx, cy, w, h) and
shape [bs, num_query, 4].
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
gt_pts_list (list[Tensor]): Ground truth pts for each image
with shape (num_gts, fixed_num, 2) in [x,y] format.
gt_bboxes_ignore_list (list[Tensor], optional): Bounding
boxes which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components for outputs from
a single decoder layer.
"""
num_imgs = cls_scores.size(0)
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
pts_preds_list = [pts_preds[i] for i in range(num_imgs)]
# import pdb;pdb.set_trace()
cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list, pts_preds_list,
gt_bboxes_list, gt_labels_list, gt_shifts_pts_list,
gt_bboxes_ignore_list)
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
pts_targets_list, pts_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
# import pdb;pdb.set_trace()
labels = torch.cat(labels_list, 0)
label_weights = torch.cat(label_weights_list, 0)
bbox_targets = torch.cat(bbox_targets_list, 0)
bbox_weights = torch.cat(bbox_weights_list, 0)
pts_targets = torch.cat(pts_targets_list, 0)
pts_weights = torch.cat(pts_weights_list, 0)
# classification loss
cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
# construct weighted avg_factor to match with the official DETR repo
cls_avg_factor = num_total_pos * 1.0 + \
num_total_neg * self.bg_cls_weight
if self.sync_cls_avg_factor:
cls_avg_factor = reduce_mean(
cls_scores.new_tensor([cls_avg_factor]))
cls_avg_factor = max(cls_avg_factor, 1)
loss_cls = self.loss_cls(
cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
# Compute the average number of gt boxes accross all gpus, for
# normalization purposes
num_total_pos = loss_cls.new_tensor([num_total_pos])
num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
# import pdb;pdb.set_trace()
# regression L1 loss
bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
normalized_bbox_targets = normalize_2d_bbox(bbox_targets, self.pc_range)
# normalized_bbox_targets = bbox_targets
isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
bbox_weights = bbox_weights * self.code_weights
loss_bbox = self.loss_bbox(
bbox_preds[isnotnan, :4], normalized_bbox_targets[isnotnan,
:4], bbox_weights[isnotnan, :4],
avg_factor=num_total_pos)
# regression pts CD loss
# pts_preds = pts_preds
# import pdb;pdb.set_trace()
# num_samples, num_order, num_pts, num_coords
normalized_pts_targets = normalize_2d_pts(pts_targets, self.pc_range) if not self.z_cfg['gt_z_flag'] \
else normalize_3d_pts(pts_targets, self.pc_range)
# num_samples, num_pts, num_coords
pts_preds = pts_preds.reshape(-1, pts_preds.size(-2),pts_preds.size(-1))
if self.num_pts_per_vec != self.num_pts_per_gt_vec:
pts_preds = pts_preds.permute(0,2,1)
pts_preds = F.interpolate(pts_preds, size=(self.num_pts_per_gt_vec), mode='linear',
align_corners=True)
pts_preds = pts_preds.permute(0,2,1).contiguous()
# import pdb;pdb.set_trace()
loss_pts = self.loss_pts(
pts_preds[isnotnan,:,:], normalized_pts_targets[isnotnan,
:,:],
pts_weights[isnotnan,:,:],
avg_factor=num_total_pos)
dir_weights = pts_weights[:, :-self.dir_interval,0]
denormed_pts_preds = denormalize_2d_pts(pts_preds, self.pc_range) if not self.z_cfg['gt_z_flag'] \
else denormalize_3d_pts(pts_preds, self.pc_range)
denormed_pts_preds_dir = denormed_pts_preds[:,self.dir_interval:,:] - denormed_pts_preds[:,:-self.dir_interval,:]
pts_targets_dir = pts_targets[:, self.dir_interval:,:] - pts_targets[:,:-self.dir_interval,:]
# dir_weights = pts_weights[:, indice,:-1,0]
# import pdb;pdb.set_trace()
loss_dir = self.loss_dir(
denormed_pts_preds_dir[isnotnan,:,:], pts_targets_dir[isnotnan,
:,:],
dir_weights[isnotnan,:],
avg_factor=num_total_pos)
bboxes = denormalize_2d_bbox(bbox_preds, self.pc_range)
# regression IoU loss, defaultly GIoU loss
loss_iou = self.loss_iou(
bboxes[isnotnan, :4], bbox_targets[isnotnan, :4], bbox_weights[isnotnan, :4],
avg_factor=num_total_pos)
if digit_version(TORCH_VERSION) >= digit_version('1.8'):
loss_cls = torch.nan_to_num(loss_cls)
loss_bbox = torch.nan_to_num(loss_bbox)
loss_iou = torch.nan_to_num(loss_iou)
loss_pts = torch.nan_to_num(loss_pts)
loss_dir = torch.nan_to_num(loss_dir)
return loss_cls, loss_bbox, loss_iou, loss_pts, loss_dir
import torch
def pad_to_static_list(self, tensors, pad_value=0, device=None):
# max_len = max(t.size(0) for t in tensors)
max_len = 200
results = []
for t in tensors:
pad_shape = (max_len,) + t.shape[1:]
out = torch.full(pad_shape, pad_value, device=device, dtype=t.dtype)
mask = torch.zeros(max_len, dtype=torch.bool, device=device)
length = t.size(0)
out[:length, ...] = t
mask[:length] = 1
results.append((out, mask, length))
return results
@force_fp32(apply_to=('preds_dicts'))
def loss(self,
gt_bboxes_list,
gt_labels_list,
gt_seg_mask,
gt_pv_seg_mask,
preds_dicts,
gt_bboxes_ignore=None,
img_metas=None):
""""Loss function.
Args:
gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels_list (list[Tensor]): Ground truth class indices for each
image with shape (num_gts, ).
preds_dicts:
all_cls_scores (Tensor): Classification score of all
decoder layers, has shape
[nb_dec, bs, num_query, cls_out_channels].
all_bbox_preds (Tensor): Sigmoid regression
outputs of all decode layers. Each is a 4D-tensor with
normalized coordinate format (cx, cy, w, h) and shape
[nb_dec, bs, num_query, 4].
enc_cls_scores (Tensor): Classification scores of
points on encode feature map , has shape
(N, h*w, num_classes). Only be passed when as_two_stage is
True, otherwise is None.
enc_bbox_preds (Tensor): Regression results of each points
on the encode feature map, has shape (N, h*w, 4). Only be
passed when as_two_stage is True, otherwise is None.
gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
which can be ignored for each image. Default None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
assert gt_bboxes_ignore is None, \
f'{self.__class__.__name__} only supports ' \
f'for gt_bboxes_ignore setting to None.'
gt_vecs_list = copy.deepcopy(gt_bboxes_list)
all_cls_scores = preds_dicts['all_cls_scores']
all_bbox_preds = preds_dicts['all_bbox_preds']
all_pts_preds = preds_dicts['all_pts_preds']
enc_cls_scores = preds_dicts['enc_cls_scores']
enc_bbox_preds = preds_dicts['enc_bbox_preds']
enc_pts_preds = preds_dicts['enc_pts_preds']
num_dec_layers = len(all_cls_scores)
device = gt_labels_list[0].device
gt_bboxes_list = [gt_bboxes.bbox.to(device) for gt_bboxes in gt_vecs_list]
# gt_pts_list = [gt_bboxes.fixed_num_sampled_points.to(device) for gt_bboxes in gt_vecs_list]
# if self.gt_shift_pts_pattern == 'v0':
# gt_shifts_pts_list = [
# gt_bboxes.shift_fixed_num_sampled_points.to(device) for gt_bboxes in gt_vecs_list]
# elif self.gt_shift_pts_pattern == 'v1':
# gt_shifts_pts_list = [
# gt_bboxes.shift_fixed_num_sampled_points_v1.to(device) for gt_bboxes in gt_vecs_list]
# elif self.gt_shift_pts_pattern == 'v2':
# gt_shifts_pts_list = [
# gt_bboxes.shift_fixed_num_sampled_points_v2.to(device) for gt_bboxes in gt_vecs_list]
# elif self.gt_shift_pts_pattern == 'v3':
# gt_shifts_pts_list = [
# gt_bboxes.shift_fixed_num_sampled_points_v3.to(device) for gt_bboxes in gt_vecs_list]
# elif self.gt_shift_pts_pattern == 'v4':
# gt_shifts_pts_list = [
# gt_bboxes.shift_fixed_num_sampled_points_v4.to(device) for gt_bboxes in gt_vecs_list]
# else:
# raise NotImplementedError
# all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
# all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
# all_gt_shifts_pts_list = [gt_shifts_pts_list for _ in range(num_dec_layers)]
# all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)]
# # all_gt_pts_list = [gt_pts_list for _ in range(num_dec_layers)]
if self.gt_shift_pts_pattern == 'v0':
gt_shifts_pts_list = [
gt_bboxes.shift_fixed_num_sampled_points for gt_bboxes in gt_vecs_list]
elif self.gt_shift_pts_pattern == 'v1':
gt_shifts_pts_list = [
gt_bboxes.shift_fixed_num_sampled_points_v1 for gt_bboxes in gt_vecs_list]
elif self.gt_shift_pts_pattern == 'v2':
gt_shifts_pts_list = [
gt_bboxes.shift_fixed_num_sampled_points_v2 for gt_bboxes in gt_vecs_list]
elif self.gt_shift_pts_pattern == 'v3':
gt_shifts_pts_list = [
gt_bboxes.shift_fixed_num_sampled_points_v3 for gt_bboxes in gt_vecs_list]
elif self.gt_shift_pts_pattern == 'v4':
gt_shifts_pts_list = [
gt_bboxes.shift_fixed_num_sampled_points_v4 for gt_bboxes in gt_vecs_list]
else:
raise NotImplementedError
all_gt_bboxes = self.pad_to_static_list(gt_bboxes_list, device=device)
all_shifts_pts = self.pad_to_static_list(gt_shifts_pts_list, device=device)
all_gt_labels = self.pad_to_static_list(gt_labels_list, device=device)
all_gt_bboxes_list = [all_gt_bboxes for _ in range(num_dec_layers)]
all_gt_shifts_pts_list = [all_shifts_pts for _ in range(num_dec_layers)]
all_gt_labels_list = [all_gt_labels for _ in range(num_dec_layers)]
all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)]
# all_gt_pts_list = [gt_pts_list for _ in range(num_dec_layers)]
losses_cls, losses_bbox, losses_iou, losses_pts, losses_dir = multi_apply(
self.loss_single, all_cls_scores, all_bbox_preds, all_pts_preds,
all_gt_bboxes_list, all_gt_labels_list, all_gt_shifts_pts_list,
all_gt_bboxes_ignore_list)
loss_dict = dict()
if self.aux_seg['use_aux_seg']:
# import ipdb;ipdb.set_trace()
if self.aux_seg['bev_seg']:
if preds_dicts['seg'] is not None:
seg_output = preds_dicts['seg']
num_imgs = seg_output.size(0)
seg_gt = torch.stack([gt_seg_mask[i] for i in range(num_imgs)],dim=0)
loss_seg = self.loss_seg(seg_output, seg_gt.float())
loss_dict['loss_seg'] = loss_seg
if self.aux_seg['pv_seg']:
# import ipdb;ipdb.set_trace()
if preds_dicts['pv_seg'] is not None:
pv_seg_output = preds_dicts['pv_seg']
num_imgs = pv_seg_output.size(0)
pv_seg_gt = torch.stack([gt_pv_seg_mask[i] for i in range(num_imgs)],dim=0)
loss_pv_seg = self.loss_pv_seg(pv_seg_output, pv_seg_gt.float())
loss_dict['loss_pv_seg'] = loss_pv_seg
# loss of proposal generated from encode feature map.
if enc_cls_scores is not None:
binary_labels_list = [
torch.zeros_like(gt_labels_list[i])
for i in range(len(all_gt_labels_list))
]
# TODO bug here
enc_loss_cls, enc_losses_bbox, enc_losses_iou, enc_losses_pts, enc_losses_dir = \
self.loss_single(enc_cls_scores, enc_bbox_preds, enc_pts_preds,
gt_bboxes_list, binary_labels_list, gt_pts_list,gt_bboxes_ignore)
loss_dict['enc_loss_cls'] = enc_loss_cls
loss_dict['enc_loss_bbox'] = enc_losses_bbox
loss_dict['enc_losses_iou'] = enc_losses_iou
loss_dict['enc_losses_pts'] = enc_losses_pts
loss_dict['enc_losses_dir'] = enc_losses_dir
# loss from the last decoder layer
loss_dict['loss_cls'] = losses_cls[-1]
loss_dict['loss_bbox'] = losses_bbox[-1]
loss_dict['loss_iou'] = losses_iou[-1]
loss_dict['loss_pts'] = losses_pts[-1]
loss_dict['loss_dir'] = losses_dir[-1]
# loss from other decoder layers
num_dec_layer = 0
for loss_cls_i, loss_bbox_i, loss_iou_i, loss_pts_i, loss_dir_i in zip(losses_cls[:-1],
losses_bbox[:-1],
losses_iou[:-1],
losses_pts[:-1],
losses_dir[:-1]):
loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
loss_dict[f'd{num_dec_layer}.loss_pts'] = loss_pts_i
loss_dict[f'd{num_dec_layer}.loss_dir'] = loss_dir_i
num_dec_layer += 1
return loss_dict
@force_fp32(apply_to=('preds_dicts'))
def get_bboxes(self, preds_dicts, img_metas, rescale=False):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
img_metas (list[dict]): Point cloud and image's meta info.
Returns:
list[dict]: Decoded bbox, scores and labels after nms.
"""
# bboxes: xmin, ymin, xmax, ymax
preds_dicts = self.bbox_coder.decode(preds_dicts)
num_samples = len(preds_dicts)
ret_list = []
for i in range(num_samples):
preds = preds_dicts[i]
bboxes = preds['bboxes']
# bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
# code_size = bboxes.shape[-1]
# bboxes = img_metas[i]['box_type_3d'](bboxes, code_size)
scores = preds['scores']
labels = preds['labels']
pts = preds['pts']
ret_list.append([bboxes, scores, labels, pts])
return ret_list
from .maptr import MapTR
from .maptrv2 import MapTRv2
\ No newline at end of file
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import DETECTORS
from mmdet3d.core import bbox3d2result
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
from mmcv.runner import force_fp32, auto_fp16
from mmdet3d.ops import Voxelization, DynamicScatter
from mmdet3d.models import builder
@DETECTORS.register_module()
class MapTR(MVXTwoStageDetector):
"""MapTR.
Args:
video_test_mode (bool): Decide whether to use temporal information during inference.
"""
def __init__(self,
use_grid_mask=False,
pts_voxel_layer=None,
pts_voxel_encoder=None,
pts_middle_encoder=None,
pts_fusion_layer=None,
img_backbone=None,
pts_backbone=None,
img_neck=None,
pts_neck=None,
pts_bbox_head=None,
img_roi_head=None,
img_rpn_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
video_test_mode=False,
modality='vision',
lidar_encoder=None,
):
super(MapTR,
self).__init__(pts_voxel_layer, pts_voxel_encoder,
pts_middle_encoder, pts_fusion_layer,
img_backbone, pts_backbone, img_neck, pts_neck,
pts_bbox_head, img_roi_head, img_rpn_head,
train_cfg, test_cfg, pretrained)
self.grid_mask = GridMask(
True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
self.use_grid_mask = use_grid_mask
self.fp16_enabled = False
# temporal
self.video_test_mode = video_test_mode
self.prev_frame_info = {
'prev_bev': None,
'scene_token': None,
'prev_pos': 0,
'prev_angle': 0,
}
self.modality = modality
if self.modality == 'fusion' and lidar_encoder is not None :
if lidar_encoder["voxelize"].get("max_num_points", -1) > 0:
voxelize_module = Voxelization(**lidar_encoder["voxelize"])
else:
voxelize_module = DynamicScatter(**lidar_encoder["voxelize"])
self.lidar_modal_extractor = nn.ModuleDict(
{
"voxelize": voxelize_module,
"backbone": builder.build_middle_encoder(lidar_encoder["backbone"]),
}
)
self.voxelize_reduce = lidar_encoder.get("voxelize_reduce", True)
def extract_img_feat(self, img, img_metas, len_queue=None):
"""Extract features of images."""
B = img.size(0)
if img is not None:
# input_shape = img.shape[-2:]
# # update real input shape of each single img
# for img_meta in img_metas:
# img_meta.update(input_shape=input_shape)
if img.dim() == 5 and img.size(0) == 1:
img.squeeze_()
elif img.dim() == 5 and img.size(0) > 1:
B, N, C, H, W = img.size()
img = img.reshape(B * N, C, H, W)
if self.use_grid_mask:
img = self.grid_mask(img)
img_feats = self.img_backbone(img)
if isinstance(img_feats, dict):
img_feats = list(img_feats.values())
else:
return None
if self.with_img_neck:
img_feats = self.img_neck(img_feats)
img_feats_reshaped = []
for img_feat in img_feats:
BN, C, H, W = img_feat.size()
if len_queue is not None:
img_feats_reshaped.append(img_feat.view(int(B/len_queue), len_queue, int(BN / B), C, H, W))
else:
img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
return img_feats_reshaped
@auto_fp16(apply_to=('img'), out_fp32=True)
def extract_feat(self, img, img_metas=None, len_queue=None):
"""Extract features from images and points."""
img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
return img_feats
def forward_pts_train(self,
pts_feats,
lidar_feat,
gt_bboxes_3d,
gt_labels_3d,
img_metas,
gt_bboxes_ignore=None,
prev_bev=None):
"""Forward function'
Args:
pts_feats (list[torch.Tensor]): Features of point cloud branch
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
boxes for each sample.
gt_labels_3d (list[torch.Tensor]): Ground truth labels for
boxes of each sampole
img_metas (list[dict]): Meta information of samples.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
boxes to be ignored. Defaults to None.
prev_bev (torch.Tensor, optional): BEV features of previous frame.
Returns:
dict: Losses of each branch.
"""
outs = self.pts_bbox_head(
pts_feats, lidar_feat, img_metas, prev_bev)
loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
losses = self.pts_bbox_head.loss(*loss_inputs, img_metas=img_metas)
return losses
def forward_dummy(self, img):
dummy_metas = None
return self.forward_test(img=img, img_metas=[[dummy_metas]])
def forward(self, return_loss=True, **kwargs):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
Note this setting will change the expected inputs. When
`return_loss=True`, img and img_metas are single-nested (i.e.
torch.Tensor and list[dict]), and when `resturn_loss=False`, img and
img_metas should be double nested (i.e. list[torch.Tensor],
list[list[dict]]), with the outer list indicating test time
augmentations.
"""
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def obtain_history_bev(self, imgs_queue, img_metas_list):
"""Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
"""
self.eval()
with torch.no_grad():
prev_bev = None
bs, len_queue, num_cams, C, H, W = imgs_queue.shape
imgs_queue = imgs_queue.reshape(bs*len_queue, num_cams, C, H, W)
img_feats_list = self.extract_feat(img=imgs_queue, len_queue=len_queue)
for i in range(len_queue):
img_metas = [each[i] for each in img_metas_list]
# img_feats = self.extract_feat(img=img, img_metas=img_metas)
img_feats = [each_scale[:, i] for each_scale in img_feats_list]
prev_bev = self.pts_bbox_head(
img_feats, img_metas, prev_bev, only_bev=True)
self.train()
return prev_bev
@torch.no_grad()
@force_fp32()
def voxelize(self, points):
feats, coords, sizes = [], [], []
for k, res in enumerate(points):
ret = self.lidar_modal_extractor["voxelize"](res)
if len(ret) == 3:
# hard voxelize
f, c, n = ret
else:
assert len(ret) == 2
f, c = ret
n = None
feats.append(f)
coords.append(F.pad(c, (1, 0), mode="constant", value=k))
if n is not None:
sizes.append(n)
feats = torch.cat(feats, dim=0)
coords = torch.cat(coords, dim=0)
if len(sizes) > 0:
sizes = torch.cat(sizes, dim=0)
if self.voxelize_reduce:
feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
-1, 1
)
feats = feats.contiguous()
return feats, coords, sizes
@auto_fp16(apply_to=('points'), out_fp32=True)
def extract_lidar_feat(self,points):
feats, coords, sizes = self.voxelize(points)
# voxel_features = self.lidar_modal_extractor["voxel_encoder"](feats, sizes, coords)
batch_size = coords[-1, 0] + 1
lidar_feat = self.lidar_modal_extractor["backbone"](feats, coords, batch_size, sizes=sizes)
return lidar_feat
# @auto_fp16(apply_to=('img', 'points'))
@force_fp32(apply_to=('img','points','prev_bev'))
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img=None,
proposals=None,
gt_bboxes_ignore=None,
img_depth=None,
img_mask=None,
):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
lidar_feat = None
if self.modality == 'fusion':
lidar_feat = self.extract_lidar_feat(points)
len_queue = img.size(1)
prev_img = img[:, :-1, ...]
img = img[:, -1, ...]
prev_img_metas = copy.deepcopy(img_metas)
# prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)
# import pdb;pdb.set_trace()
prev_bev = self.obtain_history_bev(prev_img, prev_img_metas) if len_queue>1 else None
img_metas = [each[len_queue-1] for each in img_metas]
img_feats = self.extract_feat(img=img, img_metas=img_metas)
losses = dict()
losses_pts = self.forward_pts_train(img_feats, lidar_feat, gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore, prev_bev)
losses.update(losses_pts)
return losses
def forward_test(self, img_metas, img=None,points=None, **kwargs):
for var, name in [(img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
img = [img] if img is None else img
points = [points] if points is None else points
if img_metas[0][0]['scene_token'] != self.prev_frame_info['scene_token']:
# the first sample of each scene is truncated
self.prev_frame_info['prev_bev'] = None
# update idx
self.prev_frame_info['scene_token'] = img_metas[0][0]['scene_token']
# do not use temporal information
if not self.video_test_mode:
self.prev_frame_info['prev_bev'] = None
# Get the delta of ego position and angle between two timestamps.
tmp_pos = copy.deepcopy(img_metas[0][0]['can_bus'][:3])
tmp_angle = copy.deepcopy(img_metas[0][0]['can_bus'][-1])
if self.prev_frame_info['prev_bev'] is not None:
img_metas[0][0]['can_bus'][:3] -= self.prev_frame_info['prev_pos']
img_metas[0][0]['can_bus'][-1] -= self.prev_frame_info['prev_angle']
else:
img_metas[0][0]['can_bus'][-1] = 0
img_metas[0][0]['can_bus'][:3] = 0
new_prev_bev, bbox_results = self.simple_test(
img_metas[0], img[0], points[0], prev_bev=self.prev_frame_info['prev_bev'], **kwargs)
# During inference, we save the BEV features and ego motion of each timestamp.
self.prev_frame_info['prev_pos'] = tmp_pos
self.prev_frame_info['prev_angle'] = tmp_angle
self.prev_frame_info['prev_bev'] = new_prev_bev
return bbox_results
def pred2result(self, bboxes, scores, labels, pts, attrs=None):
"""Convert detection results to a list of numpy arrays.
Args:
bboxes (torch.Tensor): Bounding boxes with shape of (n, 5).
labels (torch.Tensor): Labels with shape of (n, ).
scores (torch.Tensor): Scores with shape of (n, ).
attrs (torch.Tensor, optional): Attributes with shape of (n, ). \
Defaults to None.
Returns:
dict[str, torch.Tensor]: Bounding box results in cpu mode.
- boxes_3d (torch.Tensor): 3D boxes.
- scores (torch.Tensor): Prediction scores.
- labels_3d (torch.Tensor): Box labels.
- attrs_3d (torch.Tensor, optional): Box attributes.
"""
result_dict = dict(
boxes_3d=bboxes.to('cpu'),
scores_3d=scores.cpu(),
labels_3d=labels.cpu(),
pts_3d=pts.to('cpu'))
if attrs is not None:
result_dict['attrs_3d'] = attrs.cpu()
return result_dict
def simple_test_pts(self, x, lidar_feat, img_metas, prev_bev=None, rescale=False):
"""Test function"""
outs = self.pts_bbox_head(x, lidar_feat, img_metas, prev_bev=prev_bev)
bbox_list = self.pts_bbox_head.get_bboxes(
outs, img_metas, rescale=rescale)
bbox_results = [
self.pred2result(bboxes, scores, labels, pts)
for bboxes, scores, labels, pts in bbox_list
]
# import pdb;pdb.set_trace()
return outs['bev_embed'], bbox_results
def simple_test(self, img_metas, img=None, points=None, prev_bev=None, rescale=False, **kwargs):
"""Test function without augmentaiton."""
lidar_feat = None
if self.modality =='fusion':
lidar_feat = self.extract_lidar_feat(points)
img_feats = self.extract_feat(img=img, img_metas=img_metas)
bbox_list = [dict() for i in range(len(img_metas))]
new_prev_bev, bbox_pts = self.simple_test_pts(
img_feats, lidar_feat, img_metas, prev_bev, rescale=rescale)
for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
result_dict['pts_bbox'] = pts_bbox
return new_prev_bev, bbox_list
@DETECTORS.register_module()
class MapTR_fp16(MapTR):
"""
The default version BEVFormer currently can not support FP16.
We provide this version to resolve this issue.
"""
# @auto_fp16(apply_to=('img', 'prev_bev', 'points'))
@force_fp32(apply_to=('img','points','prev_bev'))
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img=None,
proposals=None,
gt_bboxes_ignore=None,
img_depth=None,
img_mask=None,
prev_bev=None,
):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
img_feats = self.extract_feat(img=img, img_metas=img_metas)
# import pdb;pdb.set_trace()
losses = dict()
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore, prev_bev=prev_bev)
losses.update(losses_pts)
return losses
def val_step(self, data, optimizer):
"""
In BEVFormer_fp16, we use this `val_step` function to inference the `prev_pev`.
This is not the standard function of `val_step`.
"""
img = data['img']
img_metas = data['img_metas']
img_feats = self.extract_feat(img=img, img_metas=img_metas)
prev_bev = data.get('prev_bev', None)
prev_bev = self.pts_bbox_head(img_feats, img_metas, prev_bev=prev_bev, only_bev=True)
return prev_bev
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import DETECTORS
from mmdet3d.core import bbox3d2result
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from projects.mmdet3d_plugin.models.utils.grid_mask import GridMask
from mmcv.runner import force_fp32, auto_fp16
from mmdet3d.ops import Voxelization, DynamicScatter
from mmdet3d.models import builder
from mmcv.utils import TORCH_VERSION, digit_version
@DETECTORS.register_module()
class MapTRv2(MVXTwoStageDetector):
"""MapTR.
Args:
video_test_mode (bool): Decide whether to use temporal information during inference.
"""
def __init__(self,
use_grid_mask=False,
pts_voxel_layer=None,
pts_voxel_encoder=None,
pts_middle_encoder=None,
pts_fusion_layer=None,
img_backbone=None,
pts_backbone=None,
img_neck=None,
pts_neck=None,
pts_bbox_head=None,
img_roi_head=None,
img_rpn_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
video_test_mode=False,
modality='vision',
lidar_encoder=None,
):
super(MapTRv2,
self).__init__(pts_voxel_layer, pts_voxel_encoder,
pts_middle_encoder, pts_fusion_layer,
img_backbone, pts_backbone, img_neck, pts_neck,
pts_bbox_head, img_roi_head, img_rpn_head,
train_cfg, test_cfg, pretrained)
self.grid_mask = GridMask(
True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
self.use_grid_mask = use_grid_mask
self.fp16_enabled = False
# temporal
self.video_test_mode = video_test_mode
self.prev_frame_info = {
'prev_bev': None,
'scene_token': None,
'prev_pos': 0,
'prev_angle': 0,
}
self.modality = modality
if self.modality == 'fusion' and lidar_encoder is not None :
if lidar_encoder["voxelize"].get("max_num_points", -1) > 0:
voxelize_module = Voxelization(**lidar_encoder["voxelize"])
else:
voxelize_module = DynamicScatter(**lidar_encoder["voxelize"])
self.lidar_modal_extractor = nn.ModuleDict(
{
"voxelize": voxelize_module,
"backbone": builder.build_middle_encoder(lidar_encoder["backbone"]),
}
)
self.voxelize_reduce = lidar_encoder.get("voxelize_reduce", True)
#@torch.compile(mode="max-autotune-no-cudagraphs")
def extract_img_feat(self, img, img_metas, len_queue=None):
"""Extract features of images."""
B = img.size(0)
if img is not None:
# input_shape = img.shape[-2:]
# # update real input shape of each single img
# for img_meta in img_metas:
# img_meta.update(input_shape=input_shape)
if img.dim() == 5 and img.size(0) == 1:
img.squeeze_()
elif img.dim() == 5 and img.size(0) > 1:
B, N, C, H, W = img.size()
img = img.reshape(B * N, C, H, W)
if self.use_grid_mask:
img = self.grid_mask(img)
img_feats = self.img_backbone(img)
if isinstance(img_feats, dict):
img_feats = list(img_feats.values())
else:
return None
if self.with_img_neck:
img_feats = self.img_neck(img_feats)
img_feats_reshaped = []
for img_feat in img_feats:
BN, C, H, W = img_feat.size()
if len_queue is not None:
img_feats_reshaped.append(img_feat.view(int(B/len_queue), len_queue, int(BN / B), C, H, W))
else:
img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
return img_feats_reshaped
@auto_fp16(apply_to=('img'), out_fp32=True)
def extract_feat(self, img, img_metas=None, len_queue=None):
"""Extract features from images and points."""
img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
return img_feats
def forward_pts_train(self,
pts_feats,
lidar_feat,
gt_bboxes_3d,
gt_labels_3d,
img_metas,
gt_bboxes_ignore=None,
prev_bev=None,
gt_depth=None,
gt_seg_mask=None,
gt_pv_seg_mask=None,):
"""Forward function'
Args:
pts_feats (list[torch.Tensor]): Features of point cloud branch
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
boxes for each sample.
gt_labels_3d (list[torch.Tensor]): Ground truth labels for
boxes of each sampole
img_metas (list[dict]): Meta information of samples.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
boxes to be ignored. Defaults to None.
prev_bev (torch.Tensor, optional): BEV features of previous frame.
Returns:
dict: Losses of each branch.
"""
outs = self.pts_bbox_head(
pts_feats, lidar_feat, img_metas, prev_bev)
depth = outs.pop('depth')
losses = dict()
# calculate depth loss
if gt_depth is not None:
loss_depth = self.pts_bbox_head.transformer.encoder.get_depth_loss(gt_depth, depth)
if digit_version(TORCH_VERSION) >= digit_version('1.8'):
loss_depth = torch.nan_to_num(loss_depth)
losses.update(loss_depth=loss_depth)
loss_inputs = [gt_bboxes_3d, gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, outs]
losses_pts = self.pts_bbox_head.loss(*loss_inputs, img_metas=img_metas)
losses.update(losses_pts)
# import ipdb;ipdb.set_trace()
k_one2many = self.pts_bbox_head.k_one2many
multi_gt_bboxes_3d = copy.deepcopy(gt_bboxes_3d)
multi_gt_labels_3d = copy.deepcopy(gt_labels_3d)
for i, (each_gt_bboxes_3d, each_gt_labels_3d) in enumerate(zip(multi_gt_bboxes_3d, multi_gt_labels_3d)):
each_gt_bboxes_3d.instance_list = each_gt_bboxes_3d.instance_list * k_one2many
each_gt_bboxes_3d.instance_labels = each_gt_bboxes_3d.instance_labels * k_one2many
multi_gt_labels_3d[i] = each_gt_labels_3d.repeat(k_one2many)
# import ipdb;ipdb.set_trace()
one2many_outs = outs['one2many_outs']
loss_one2many_inputs = [multi_gt_bboxes_3d, multi_gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, one2many_outs]
loss_dict_one2many = self.pts_bbox_head.loss(*loss_one2many_inputs, img_metas=img_metas)
lambda_one2many = self.pts_bbox_head.lambda_one2many
for key, value in loss_dict_one2many.items():
if key + "_one2many" in losses.keys():
losses[key + "_one2many"] += value * lambda_one2many
else:
losses[key + "_one2many"] = value * lambda_one2many
# import ipdb;ipdb.set_trace()
return losses
def forward_dummy(self, img):
dummy_metas = None
return self.forward_test(img=img, img_metas=[[dummy_metas]])
def forward(self, return_loss=True, **kwargs):
"""Calls either forward_train or forward_test depending on whether
return_loss=True.
Note this setting will change the expected inputs. When
`return_loss=True`, img and img_metas are single-nested (i.e.
torch.Tensor and list[dict]), and when `resturn_loss=False`, img and
img_metas should be double nested (i.e. list[torch.Tensor],
list[list[dict]]), with the outer list indicating test time
augmentations.
"""
if return_loss:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def obtain_history_bev(self, imgs_queue, img_metas_list):
"""Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
"""
self.eval()
with torch.no_grad():
prev_bev = None
bs, len_queue, num_cams, C, H, W = imgs_queue.shape
imgs_queue = imgs_queue.reshape(bs*len_queue, num_cams, C, H, W)
img_feats_list = self.extract_feat(img=imgs_queue, len_queue=len_queue)
for i in range(len_queue):
img_metas = [each[i] for each in img_metas_list]
# img_feats = self.extract_feat(img=img, img_metas=img_metas)
img_feats = [each_scale[:, i] for each_scale in img_feats_list]
prev_bev = self.pts_bbox_head(
img_feats, img_metas, prev_bev, only_bev=True)
self.train()
return prev_bev
@torch.no_grad()
@force_fp32()
def voxelize(self, points):
feats, coords, sizes = [], [], []
for k, res in enumerate(points):
ret = self.lidar_modal_extractor["voxelize"](res)
if len(ret) == 3:
# hard voxelize
f, c, n = ret
else:
assert len(ret) == 2
f, c = ret
n = None
feats.append(f)
coords.append(F.pad(c, (1, 0), mode="constant", value=k))
if n is not None:
sizes.append(n)
feats = torch.cat(feats, dim=0)
coords = torch.cat(coords, dim=0)
if len(sizes) > 0:
sizes = torch.cat(sizes, dim=0)
if self.voxelize_reduce:
feats = feats.sum(dim=1, keepdim=False) / sizes.type_as(feats).view(
-1, 1
)
feats = feats.contiguous()
return feats, coords, sizes
@auto_fp16(apply_to=('points'), out_fp32=True)
def extract_lidar_feat(self,points):
feats, coords, sizes = self.voxelize(points)
# voxel_features = self.lidar_modal_extractor["voxel_encoder"](feats, sizes, coords)
batch_size = coords[-1, 0] + 1
lidar_feat = self.lidar_modal_extractor["backbone"](feats, coords, batch_size, sizes=sizes)
return lidar_feat
# @auto_fp16(apply_to=('img', 'points'))
@force_fp32(apply_to=('img','points','prev_bev'))
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img=None,
proposals=None,
gt_bboxes_ignore=None,
img_depth=None,
img_mask=None,
gt_depth=None,
gt_seg_mask=None,
gt_pv_seg_mask=None,
):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
lidar_feat = None
if self.modality == 'fusion':
lidar_feat = self.extract_lidar_feat(points)
len_queue = img.size(1)
prev_img = img[:, :-1, ...]
img = img[:, -1, ...]
prev_img_metas = copy.deepcopy(img_metas)
# prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)
# import pdb;pdb.set_trace()
prev_bev = self.obtain_history_bev(prev_img, prev_img_metas) if len_queue>1 else None
img_metas = [each[len_queue-1] for each in img_metas]
img_feats = self.extract_feat(img=img, img_metas=img_metas)
losses = dict()
losses_pts = self.forward_pts_train(img_feats, lidar_feat, gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore, prev_bev, gt_depth,gt_seg_mask,gt_pv_seg_mask)
losses.update(losses_pts)
return losses
def forward_test(self, img_metas, img=None,points=None, **kwargs):
for var, name in [(img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
img = [img] if img is None else img
points = [points] if points is None else points
if img_metas[0][0]['scene_token'] != self.prev_frame_info['scene_token']:
# the first sample of each scene is truncated
self.prev_frame_info['prev_bev'] = None
# update idx
self.prev_frame_info['scene_token'] = img_metas[0][0]['scene_token']
# do not use temporal information
if not self.video_test_mode:
self.prev_frame_info['prev_bev'] = None
# Get the delta of ego position and angle between two timestamps.
tmp_pos = copy.deepcopy(img_metas[0][0]['can_bus'][:3])
tmp_angle = copy.deepcopy(img_metas[0][0]['can_bus'][-1])
if self.prev_frame_info['prev_bev'] is not None:
img_metas[0][0]['can_bus'][:3] -= self.prev_frame_info['prev_pos']
img_metas[0][0]['can_bus'][-1] -= self.prev_frame_info['prev_angle']
else:
img_metas[0][0]['can_bus'][-1] = 0
img_metas[0][0]['can_bus'][:3] = 0
new_prev_bev, bbox_results = self.simple_test(
img_metas[0], img[0], points[0], prev_bev=self.prev_frame_info['prev_bev'], **kwargs)
# During inference, we save the BEV features and ego motion of each timestamp.
self.prev_frame_info['prev_pos'] = tmp_pos
self.prev_frame_info['prev_angle'] = tmp_angle
self.prev_frame_info['prev_bev'] = new_prev_bev
return bbox_results
def pred2result(self, bboxes, scores, labels, pts, attrs=None):
"""Convert detection results to a list of numpy arrays.
Args:
bboxes (torch.Tensor): Bounding boxes with shape of (n, 5).
labels (torch.Tensor): Labels with shape of (n, ).
scores (torch.Tensor): Scores with shape of (n, ).
attrs (torch.Tensor, optional): Attributes with shape of (n, ). \
Defaults to None.
Returns:
dict[str, torch.Tensor]: Bounding box results in cpu mode.
- boxes_3d (torch.Tensor): 3D boxes.
- scores (torch.Tensor): Prediction scores.
- labels_3d (torch.Tensor): Box labels.
- attrs_3d (torch.Tensor, optional): Box attributes.
"""
result_dict = dict(
boxes_3d=bboxes.to('cpu'),
scores_3d=scores.cpu(),
labels_3d=labels.cpu(),
pts_3d=pts.to('cpu'))
if attrs is not None:
result_dict['attrs_3d'] = attrs.cpu()
return result_dict
def simple_test_pts(self, x, lidar_feat, img_metas, prev_bev=None, rescale=False):
"""Test function"""
outs = self.pts_bbox_head(x, lidar_feat, img_metas, prev_bev=prev_bev)
bbox_list = self.pts_bbox_head.get_bboxes(
outs, img_metas, rescale=rescale)
bbox_results = [
self.pred2result(bboxes, scores, labels, pts)
for bboxes, scores, labels, pts in bbox_list
]
# import pdb;pdb.set_trace()
return outs['bev_embed'], bbox_results
def simple_test(self, img_metas, img=None, points=None, prev_bev=None, rescale=False, **kwargs):
"""Test function without augmentaiton."""
lidar_feat = None
if self.modality =='fusion':
lidar_feat = self.extract_lidar_feat(points)
img_feats = self.extract_feat(img=img, img_metas=img_metas)
bbox_list = [dict() for i in range(len(img_metas))]
new_prev_bev, bbox_pts = self.simple_test_pts(
img_feats, lidar_feat, img_metas, prev_bev, rescale=rescale)
for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
result_dict['pts_bbox'] = pts_bbox
return new_prev_bev, bbox_list
from .map_loss import MyChamferDistance
from .map_loss import MyChamferDistanceCost
from .map_loss import OrderedPtsL1Cost, PtsL1Cost
from .map_loss import OrderedPtsL1Loss, PtsL1Loss
from .map_loss import OrderedPtsSmoothL1Cost, OrderedPtsL1Loss
from .map_loss import PtsDirCosLoss
from .simple_loss import SimpleLoss
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn as nn
from torch.nn.functional import l1_loss, mse_loss, smooth_l1_loss
from mmdet.models.builder import LOSSES
from mmdet.models import weighted_loss
import mmcv
import torch.nn.functional as F
from mmdet.core.bbox.match_costs.builder import MATCH_COST
import functools
def reduce_loss(loss, reduction):
"""Reduce loss as specified.
Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are "none", "mean" and "sum".
Return:
Tensor: Reduced loss tensor.
"""
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
elif reduction_enum == 2:
return loss.sum()
@mmcv.jit(derivate=True, coderize=True)
def custom_weight_dir_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
"""Apply element-wise weight and reduce loss.
Args:
loss (Tensor): num_sample, num_dir
weight (Tensor): Element-wise weights.
reduction (str): Same as built-in losses of PyTorch.
avg_factor (float): Average factor when computing the mean of losses.
Returns:
Tensor: Processed loss values.
"""
# if weight is specified, apply element-wise weight
if weight is not None:
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
raise ValueError('avg_factor should not be none for OrderedPtsL1Loss')
# loss = reduce_loss(loss, reduction)
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# import pdb;pdb.set_trace()
# loss = loss.permute(1,0,2,3).contiguous()
loss = loss.sum()
loss = loss / avg_factor
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
@mmcv.jit(derivate=True, coderize=True)
def custom_weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
"""Apply element-wise weight and reduce loss.
Args:
loss (Tensor): num_sample, num_order, num_pts, num_coords
weight (Tensor): Element-wise weights.
reduction (str): Same as built-in losses of PyTorch.
avg_factor (float): Average factor when computing the mean of losses.
Returns:
Tensor: Processed loss values.
"""
# if weight is specified, apply element-wise weight
if weight is not None:
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
raise ValueError('avg_factor should not be none for OrderedPtsL1Loss')
# loss = reduce_loss(loss, reduction)
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
# import pdb;pdb.set_trace()
loss = loss.permute(1,0,2,3).contiguous()
loss = loss.sum((1,2,3))
loss = loss / avg_factor
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
def custom_weighted_loss(loss_func):
"""Create a weighted version of a given loss function.
To use this decorator, the loss function must have the signature like
`loss_func(pred, target, **kwargs)`. The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like `loss_func(pred, target, weight=None, reduction='mean',
avg_factor=None, **kwargs)`.
:Example:
>>> import torch
>>> @weighted_loss
>>> def l1_loss(pred, target):
>>> return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, avg_factor=2)
tensor(1.5000)
"""
@functools.wraps(loss_func)
def wrapper(pred,
target,
weight=None,
reduction='mean',
avg_factor=None,
**kwargs):
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = custom_weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
return wrapper
def custom_weighted_dir_loss(loss_func):
"""Create a weighted version of a given loss function.
To use this decorator, the loss function must have the signature like
`loss_func(pred, target, **kwargs)`. The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like `loss_func(pred, target, weight=None, reduction='mean',
avg_factor=None, **kwargs)`.
:Example:
>>> import torch
>>> @weighted_loss
>>> def l1_loss(pred, target):
>>> return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, avg_factor=2)
tensor(1.5000)
"""
@functools.wraps(loss_func)
def wrapper(pred,
target,
weight=None,
reduction='mean',
avg_factor=None,
**kwargs):
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = custom_weight_dir_reduce_loss(loss, weight, reduction, avg_factor)
return loss
return wrapper
@mmcv.jit(derivate=True, coderize=True)
@custom_weighted_loss
def ordered_pts_smooth_l1_loss(pred, target):
"""L1 loss.
Args:
pred (torch.Tensor): shape [num_samples, num_pts, num_coords]
target (torch.Tensor): shape [num_samples, num_order, num_pts, num_coords]
Returns:
torch.Tensor: Calculated loss
"""
if target.numel() == 0:
return pred.sum() * 0
pred = pred.unsqueeze(1).repeat(1, target.size(1),1,1)
assert pred.size() == target.size()
loss =smooth_l1_loss(pred,target, reduction='none')
# import pdb;pdb.set_trace()
return loss
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def pts_l1_loss(pred, target):
"""L1 loss.
Args:
pred (torch.Tensor): shape [num_samples, num_pts, num_coords]
target (torch.Tensor): shape [num_samples, num_pts, num_coords]
Returns:
torch.Tensor: Calculated loss
"""
if target.numel() == 0:
return pred.sum() * 0
assert pred.size() == target.size()
loss = torch.abs(pred - target)
return loss
@mmcv.jit(derivate=True, coderize=True)
@custom_weighted_loss
def ordered_pts_l1_loss(pred, target):
"""L1 loss.
Args:
pred (torch.Tensor): shape [num_samples, num_pts, num_coords]
target (torch.Tensor): shape [num_samples, num_order, num_pts, num_coords]
Returns:
torch.Tensor: Calculated loss
"""
if target.numel() == 0:
return pred.sum() * 0
pred = pred.unsqueeze(1).repeat(1, target.size(1),1,1)
assert pred.size() == target.size()
loss = torch.abs(pred - target)
return loss
@mmcv.jit(derivate=True, coderize=True)
@custom_weighted_dir_loss
def pts_dir_cos_loss(pred, target):
""" Dir cosine similiarity loss
pred (torch.Tensor): shape [num_samples, num_dir, num_coords]
target (torch.Tensor): shape [num_samples, num_dir, num_coords]
"""
if target.numel() == 0:
return pred.sum() * 0
# import pdb;pdb.set_trace()
num_samples, num_dir, num_coords = pred.shape
loss_func = torch.nn.CosineEmbeddingLoss(reduction='none')
tgt_param = target.new_ones((num_samples, num_dir))
tgt_param = tgt_param.flatten(0)
loss = loss_func(pred.flatten(0,1), target.flatten(0,1), tgt_param)
loss = loss.view(num_samples, num_dir)
return loss
@LOSSES.register_module()
class OrderedPtsSmoothL1Loss(nn.Module):
"""L1 loss.
Args:
reduction (str, optional): The method to reduce the loss.
Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of loss.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
super(OrderedPtsSmoothL1Loss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
# import pdb;pdb.set_trace()
loss_bbox = self.loss_weight * ordered_pts_smooth_l1_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
return loss_bbox
@LOSSES.register_module()
class PtsDirCosLoss(nn.Module):
"""L1 loss.
Args:
reduction (str, optional): The method to reduce the loss.
Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of loss.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
super(PtsDirCosLoss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
# import pdb;pdb.set_trace()
loss_dir = self.loss_weight * pts_dir_cos_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
return loss_dir
@LOSSES.register_module()
class PtsL1Loss(nn.Module):
"""L1 loss.
Args:
reduction (str, optional): The method to reduce the loss.
Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of loss.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
super(PtsL1Loss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
# import pdb;pdb.set_trace()
loss_bbox = self.loss_weight * pts_l1_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
return loss_bbox
@LOSSES.register_module()
class OrderedPtsL1Loss(nn.Module):
"""L1 loss.
Args:
reduction (str, optional): The method to reduce the loss.
Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of loss.
"""
def __init__(self, reduction='mean', loss_weight=1.0):
super(OrderedPtsL1Loss, self).__init__()
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
# import pdb;pdb.set_trace()
loss_bbox = self.loss_weight * ordered_pts_l1_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
return loss_bbox
@MATCH_COST.register_module()
class OrderedPtsSmoothL1Cost(object):
"""OrderedPtsL1Cost.
Args:
weight (int | float, optional): loss_weight
"""
def __init__(self, weight=1.):
self.weight = weight
def __call__(self, bbox_pred, gt_bboxes):
"""
Args:
bbox_pred (Tensor): Predicted boxes with normalized coordinates
(x, y), which are all in range [0, 1]. Shape
[num_query, num_pts, 2].
gt_bboxes (Tensor): Ground truth boxes with normalized
coordinates (x,y).
Shape [num_gt, num_ordered, num_pts, 2].
Returns:
torch.Tensor: bbox_cost value with weight
"""
num_gts, num_orders, num_pts, num_coords = gt_bboxes.shape
# import pdb;pdb.set_trace()
bbox_pred = bbox_pred.view(bbox_pred.size(0),-1).unsqueeze(1).repeat(1,num_gts*num_orders,1)
gt_bboxes = gt_bboxes.flatten(2).view(num_gts*num_orders,-1).unsqueeze(0).repeat(bbox_pred.size(0),1,1)
# import pdb;pdb.set_trace()
bbox_cost = smooth_l1_loss(bbox_pred, gt_bboxes, reduction='none').sum(-1)
# bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
return bbox_cost * self.weight
@MATCH_COST.register_module()
class PtsL1Cost(object):
"""OrderedPtsL1Cost.
Args:
weight (int | float, optional): loss_weight
"""
def __init__(self, weight=1.):
self.weight = weight
def __call__(self, bbox_pred, gt_bboxes):
"""
Args:
bbox_pred (Tensor): Predicted boxes with normalized coordinates
(x, y), which are all in range [0, 1]. Shape
[num_query, num_pts, 2].
gt_bboxes (Tensor): Ground truth boxes with normalized
coordinates (x,y).
Shape [num_gt, num_ordered, num_pts, 2].
Returns:
torch.Tensor: bbox_cost value with weight
"""
num_gts, num_pts, num_coords = gt_bboxes.shape
# import pdb;pdb.set_trace()
bbox_pred = bbox_pred.view(bbox_pred.size(0),-1)
gt_bboxes = gt_bboxes.view(num_gts,-1)
bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
return bbox_cost * self.weight
@MATCH_COST.register_module()
class OrderedPtsL1Cost(object):
"""OrderedPtsL1Cost.
Args:
weight (int | float, optional): loss_weight
"""
def __init__(self, weight=1.):
self.weight = weight
def __call__(self, bbox_pred, gt_bboxes):
"""
Args:
bbox_pred (Tensor): Predicted boxes with normalized coordinates
(x, y), which are all in range [0, 1]. Shape
[num_query, num_pts, 2].
gt_bboxes (Tensor): Ground truth boxes with normalized
coordinates (x,y).
Shape [num_gt, num_ordered, num_pts, 2].
Returns:
torch.Tensor: bbox_cost value with weight
"""
num_gts, num_orders, num_pts, num_coords = gt_bboxes.shape
# import pdb;pdb.set_trace()
bbox_pred = bbox_pred.view(bbox_pred.size(0),-1)
gt_bboxes = gt_bboxes.flatten(2).view(num_gts*num_orders,-1)
#bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
bbox_cost = (bbox_pred[:,None,:] - gt_bboxes[None,:,:]).abs().sum(dim=-1)
return bbox_cost * self.weight
@MATCH_COST.register_module()
class MyChamferDistanceCost:
def __init__(self, loss_src_weight=1., loss_dst_weight=1.):
# assert mode in ['smooth_l1', 'l1', 'l2']
# self.mode = mode
self.loss_src_weight = loss_src_weight
self.loss_dst_weight = loss_dst_weight
def __call__(self, src, dst,src_weight=1.0,dst_weight=1.0,):
"""
pred_pts (Tensor): normed coordinate(x,y), shape (num_q, num_pts_M, 2)
gt_pts (Tensor): normed coordinate(x,y), shape (num_gt, num_pts_N, 2)
"""
# criterion_mode = self.mode
# if criterion_mode == 'smooth_l1':
# criterion = smooth_l1_loss
# elif criterion_mode == 'l1':
# criterion = l1_loss
# elif criterion_mode == 'l2':
# criterion = mse_loss
# else:
# raise NotImplementedError
# import pdb;pdb.set_trace()
src_expand = src.unsqueeze(1).repeat(1,dst.shape[0],1,1)
dst_expand = dst.unsqueeze(0).repeat(src.shape[0],1,1,1)
# src_expand = src.unsqueeze(2).unsqueeze(1).repeat(1,dst.shape[0], 1, dst.shape[1], 1)
# dst_expand = dst.unsqueeze(1).unsqueeze(0).repeat(src.shape[0],1, src.shape[1], 1, 1)
distance = torch.cdist(src_expand, dst_expand)
src2dst_distance = torch.min(distance, dim=3)[0] # (num_q, num_gt, num_pts_N)
dst2src_distance = torch.min(distance, dim=2)[0] # (num_q, num_gt, num_pts_M)
loss_src = (src2dst_distance * src_weight).mean(-1)
loss_dst = (dst2src_distance * dst_weight).mean(-1)
loss = loss_src*self.loss_src_weight + loss_dst * self.loss_dst_weight
return loss
@mmcv.jit(derivate=True, coderize=True)
def chamfer_distance(src,
dst,
src_weight=1.0,
dst_weight=1.0,
# criterion_mode='l1',
reduction='mean',
avg_factor=None):
"""Calculate Chamfer Distance of two sets.
Args:
src (torch.Tensor): Source set with shape [B, N, C] to
calculate Chamfer Distance.
dst (torch.Tensor): Destination set with shape [B, M, C] to
calculate Chamfer Distance.
src_weight (torch.Tensor or float): Weight of source loss.
dst_weight (torch.Tensor or float): Weight of destination loss.
criterion_mode (str): Criterion mode to calculate distance.
The valid modes are smooth_l1, l1 or l2.
reduction (str): Method to reduce losses.
The valid reduction method are 'none', 'sum' or 'mean'.
Returns:
tuple: Source and Destination loss with the corresponding indices.
- loss_src (torch.Tensor): The min distance \
from source to destination.
- loss_dst (torch.Tensor): The min distance \
from destination to source.
- indices1 (torch.Tensor): Index the min distance point \
for each point in source to destination.
- indices2 (torch.Tensor): Index the min distance point \
for each point in destination to source.
"""
# if criterion_mode == 'smooth_l1':
# criterion = smooth_l1_loss
# elif criterion_mode == 'l1':
# criterion = l1_loss
# elif criterion_mode == 'l2':
# criterion = mse_loss
# else:
# raise NotImplementedError
# src_expand = src.unsqueeze(2).repeat(1, 1, dst.shape[1], 1)
# dst_expand = dst.unsqueeze(1).repeat(1, src.shape[1], 1, 1)
# import pdb;pdb.set_trace()
distance = torch.cdist(src, dst)
src2dst_distance, indices1 = torch.min(distance, dim=2) # (B,N)
dst2src_distance, indices2 = torch.min(distance, dim=1) # (B,M)
# import pdb;pdb.set_trace()
#TODO this may be wrong for misaligned src_weight, now[N,fixed_num]
# should be [N], then view
loss_src = (src2dst_distance * src_weight)
loss_dst = (dst2src_distance * dst_weight)
if avg_factor is None:
reduction_enum = F._Reduction.get_enum(reduction)
if reduction_enum == 0:
raise ValueError('MyCDLoss can not be used with reduction=`none`')
elif reduction_enum == 1:
loss_src = loss_src.mean(-1).mean()
loss_dst = loss_dst.mean(-1).mean()
elif reduction_enum == 2:
loss_src = loss_src.mean(-1).sum()
loss_dst = loss_dst.mean(-1).sum()
else:
raise NotImplementedError
else:
if reduction == 'mean':
eps = torch.finfo(torch.float32).eps
loss_src = loss_src.mean(-1).sum() / (avg_factor + eps)
loss_dst = loss_dst.mean(-1).sum() / (avg_factor + eps)
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss_src, loss_dst, indices1, indices2
@LOSSES.register_module()
class MyChamferDistance(nn.Module):
"""Calculate Chamfer Distance of two sets.
Args:
mode (str): Criterion mode to calculate distance.
The valid modes are smooth_l1, l1 or l2.
reduction (str): Method to reduce losses.
The valid reduction method are none, sum or mean.
loss_src_weight (float): Weight of loss_source.
loss_dst_weight (float): Weight of loss_target.
"""
def __init__(self,
# mode='l1',
reduction='mean',
loss_src_weight=1.0,
loss_dst_weight=1.0):
super(MyChamferDistance, self).__init__()
# assert mode in ['smooth_l1', 'l1', 'l2']
assert reduction in ['none', 'sum', 'mean']
# self.mode = mode
self.reduction = reduction
self.loss_src_weight = loss_src_weight
self.loss_dst_weight = loss_dst_weight
def forward(self,
source,
target,
src_weight=1.0,
dst_weight=1.0,
avg_factor=None,
reduction_override=None,
return_indices=False,
**kwargs):
"""Forward function of loss calculation.
Args:
source (torch.Tensor): Source set with shape [B, N, C] to
calculate Chamfer Distance.
target (torch.Tensor): Destination set with shape [B, M, C] to
calculate Chamfer Distance.
src_weight (torch.Tensor | float, optional):
Weight of source loss. Defaults to 1.0.
dst_weight (torch.Tensor | float, optional):
Weight of destination loss. Defaults to 1.0.
reduction_override (str, optional): Method to reduce losses.
The valid reduction method are 'none', 'sum' or 'mean'.
Defaults to None.
return_indices (bool, optional): Whether to return indices.
Defaults to False.
Returns:
tuple[torch.Tensor]: If ``return_indices=True``, return losses of \
source and target with their corresponding indices in the \
order of ``(loss_source, loss_target, indices1, indices2)``. \
If ``return_indices=False``, return \
``(loss_source, loss_target)``.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_source, loss_target, indices1, indices2 = chamfer_distance(
source, target, src_weight, dst_weight, reduction,
avg_factor=avg_factor)
loss_source *= self.loss_src_weight
loss_target *= self.loss_dst_weight
loss_pts = loss_source + loss_target
if return_indices:
return loss_pts, indices1, indices2
else:
return loss_pts
import torch
import torch.nn as nn
from mmdet.models.builder import LOSSES
import torch.nn.functional as F
from mmdet.models.losses import FocalLoss, weight_reduce_loss
def py_sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the
number of classes
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
@LOSSES.register_module(force=True)
class SimpleLoss_v1(nn.Module):
def __init__(self, pos_weight, loss_weight):
super(SimpleLoss_v1, self).__init__()
# self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([pos_weight]))
# self.loss_fn = torch.nn.CrossEntroyLoss(reduction="none")
self.loss_weight = loss_weight
def forward(self, ypred, ytgt):
bs, pred_class_num, bev_h, bev_w = ypred.shape
ypred = ypred.permute(0, 2, 3, 1).reshape(bs*bev_h*bev_w, pred_class_num).contiguous()
ytgt = ytgt.view(-1)
ytgt = F.one_hot(ytgt.long(), num_classes=pred_class_num+1).view(-1, pred_class_num+1)[:, 1:]
fg_mask = torch.max(ytgt, dim=1).values > 0.0
ypred = ypred[fg_mask]
ytgt = ytgt[fg_mask]
loss = F.binary_cross_entropy_with_logits(ypred, ytgt.float(), reduction='none',).sum() / max(1.0, fg_mask.sum())
return loss*self.loss_weight
@LOSSES.register_module()
class SimpleLoss(torch.nn.Module):
def __init__(self, pos_weight, loss_weight):
super(SimpleLoss, self).__init__()
self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([pos_weight]))
self.loss_weight = loss_weight
def forward(self, ypred, ytgt):
# import ipdb;ipdb.set_trace()
loss = self.loss_fn(ypred, ytgt)
return loss*self.loss_weight
@LOSSES.register_module()
class MaskFocalLoss(FocalLoss):
def __init__(self,**kwargs):
super(MaskFocalLoss, self).__init__(**kwargs)
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None):
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if not self.use_sigmoid:
raise NotImplementedError
num_classes = pred.size(1)
loss = 0
for index in range(num_classes):
loss += self.loss_weight * py_sigmoid_focal_loss(
pred[:,index],
target[:,index],
weight,
gamma=self.gamma,
alpha=self.alpha,
reduction=reduction,
avg_factor=avg_factor)
# import ipdb; ipdb.set_trace()
loss /= num_classes
return loss
\ No newline at end of file
from .transformer import MapTRPerceptionTransformer
from .decoder import MapTRDecoder, DecoupledDetrTransformerDecoderLayer
from .geometry_kernel_attention import GeometrySptialCrossAttention, GeometryKernelAttention
from .builder import build_fuser
from .encoder import LSSTransform
\ No newline at end of file
import torch.nn as nn
from mmcv.utils import Registry, build_from_cfg
FUSERS = Registry("fusers")
def build_fuser(cfg):
return FUSERS.build(cfg)
\ No newline at end of file
import torch
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER,
POSITIONAL_ENCODING,
TRANSFORMER_LAYER_SEQUENCE)
from mmdet.models.utils.transformer import inverse_sigmoid
from mmcv.cnn.bricks.transformer import TransformerLayerSequence, BaseTransformerLayer
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class MapTRDecoder(TransformerLayerSequence):
"""Implements the decoder in DETR3D transformer.
Args:
return_intermediate (bool): Whether to return intermediate outputs.
coder_norm_cfg (dict): Config of last normalization layer. Default:
`LN`.
"""
def __init__(self, *args, return_intermediate=False, **kwargs):
super(MapTRDecoder, self).__init__(*args, **kwargs)
self.return_intermediate = return_intermediate
self.fp16_enabled = False
#@torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self,
query,
*args,
reference_points=None,
reg_branches=None,
key_padding_mask=None,
**kwargs):
"""Forward function for `Detr3DTransformerDecoder`.
Args:
query (Tensor): Input query with shape
`(num_query, bs, embed_dims)`.
reference_points (Tensor): The reference
points of offset. has shape
(bs, num_query, 4) when as_two_stage,
otherwise has shape ((bs, num_query, 2).
reg_branch: (obj:`nn.ModuleList`): Used for
refining the regression results. Only would
be passed when with_box_refine is True,
otherwise would be passed a `None`.
Returns:
Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
"""
output = query
intermediate = []
intermediate_reference_points = []
for lid, layer in enumerate(self.layers):
reference_points_input = reference_points[..., :2].unsqueeze(
2) # BS NUM_QUERY NUM_LEVEL 2
output = layer(
output,
*args,
reference_points=reference_points_input,
key_padding_mask=key_padding_mask,
**kwargs)
output = output.permute(1, 0, 2)
if reg_branches is not None:
tmp = reg_branches[lid](output)
# assert reference_points.shape[-1] == 2
new_reference_points = torch.zeros_like(reference_points)
new_reference_points = tmp + inverse_sigmoid(reference_points)
# new_reference_points[..., 2:3] = tmp[
# ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
output = output.permute(1, 0, 2)
if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(
intermediate_reference_points)
return output, reference_points
@TRANSFORMER_LAYER.register_module()
class DecoupledDetrTransformerDecoderLayer(BaseTransformerLayer):
"""Implements decoder layer in DETR transformer.
Args:
attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
Configs for self_attention or cross_attention, the order
should be consistent with it in `operation_order`. If it is
a dict, it would be expand to the number of attention in
`operation_order`.
feedforward_channels (int): The hidden dimension for FFNs.
ffn_dropout (float): Probability of an element to be zeroed
in ffn. Default 0.0.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
Default:None
act_cfg (dict): The activation config for FFNs. Default: `LN`
norm_cfg (dict): Config dict for normalization layer.
Default: `LN`.
ffn_num_fcs (int): The number of fully-connected layers in FFNs.
Default:2.
"""
def __init__(self,
attn_cfgs,
feedforward_channels,
num_vec=50,
num_pts_per_vec=20,
ffn_dropout=0.0,
operation_order=None,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'),
ffn_num_fcs=2,
**kwargs):
super(DecoupledDetrTransformerDecoderLayer, self).__init__(
attn_cfgs=attn_cfgs,
feedforward_channels=feedforward_channels,
ffn_dropout=ffn_dropout,
operation_order=operation_order,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
ffn_num_fcs=ffn_num_fcs,
**kwargs)
assert len(operation_order) == 8
assert set(operation_order) == set(
['self_attn', 'norm', 'cross_attn', 'ffn'])
self.num_vec = num_vec
self.num_pts_per_vec = num_pts_per_vec
def forward(self,
query,
key=None,
value=None,
query_pos=None,
key_pos=None,
attn_masks=None,
query_key_padding_mask=None,
key_padding_mask=None,
**kwargs):
"""Forward function for `TransformerDecoderLayer`.
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): The input query with shape
[num_queries, bs, embed_dims] if
self.batch_first is False, else
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
value (Tensor): The value tensor with same shape as `key`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
attn_masks (List[Tensor] | None): 2D Tensor used in
calculation of corresponding attention. The length of
it should equal to the number of `attention` in
`operation_order`. Default: None.
query_key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_queries]. Only used in `self_attn` layer.
Defaults to None.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_keys]. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
norm_index = 0
attn_index = 0
ffn_index = 0
identity = query
if attn_masks is None:
attn_masks = [None for _ in range(self.num_attn)]
elif isinstance(attn_masks, torch.Tensor):
attn_masks = [
copy.deepcopy(attn_masks) for _ in range(self.num_attn)
]
warnings.warn(f'Use same attn_mask in all attentions in '
f'{self.__class__.__name__} ')
else:
assert len(attn_masks) == self.num_attn, f'The length of ' \
f'attn_masks {len(attn_masks)} must be equal ' \
f'to the number of attention in ' \
f'operation_order {self.num_attn}'
#
num_vec = kwargs['num_vec']
num_pts_per_vec = kwargs['num_pts_per_vec']
for layer in self.operation_order:
if layer == 'self_attn':
# import ipdb;ipdb.set_trace()
if attn_index == 0:
n_pts, n_batch, n_dim = query.shape
query = query.view(num_vec, num_pts_per_vec,n_batch,n_dim).flatten(1,2)
query_pos = query_pos.view(num_vec, num_pts_per_vec,n_batch,n_dim).flatten(1,2)
temp_key = temp_value = query
query = self.attentions[attn_index](
query,
temp_key,
temp_value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=kwargs['self_attn_mask'],
key_padding_mask=query_key_padding_mask,
**kwargs)
# import ipdb;ipdb.set_trace()
query = query.view(num_vec, num_pts_per_vec, n_batch, n_dim).flatten(0,1)
query_pos = query_pos.view(num_vec, num_pts_per_vec, n_batch, n_dim).flatten(0,1)
attn_index += 1
identity = query
else:
# import ipdb;ipdb.set_trace()
n_pts, n_batch, n_dim = query.shape
query = query.view(num_vec, num_pts_per_vec,n_batch,n_dim).permute(1,0,2,3).contiguous().flatten(1,2)
query_pos = query_pos.view(num_vec, num_pts_per_vec,n_batch,n_dim).permute(1,0,2,3).contiguous().flatten(1,2)
temp_key = temp_value = query
query = self.attentions[attn_index](
query,
temp_key,
temp_value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=query_key_padding_mask,
**kwargs)
# import ipdb;ipdb.set_trace()
query = query.view(num_pts_per_vec, num_vec, n_batch, n_dim).permute(1,0,2,3).contiguous().flatten(0,1)
query_pos = query_pos.view(num_pts_per_vec, num_vec, n_batch, n_dim).permute(1,0,2,3).contiguous().flatten(0,1)
attn_index += 1
identity = query
elif layer == 'norm':
query = self.norms[norm_index](query)
norm_index += 1
elif layer == 'cross_attn':
query = self.attentions[attn_index](
query,
key,
value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=key_padding_mask,
**kwargs)
attn_index += 1
identity = query
elif layer == 'ffn':
query = self.ffns[ffn_index](
query, identity if self.pre_norm else None)
ffn_index += 1
return query
import torch
import numpy as np
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
import torch.nn as nn
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmdet3d.ops import bev_pool
from mmdet3d.ops.bev_pool_v2.bev_pool import bev_pool_v2
from mmcv.runner import force_fp32, auto_fp16
from torch.cuda.amp.autocast_mode import autocast
from mmcv.cnn import build_conv_layer
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
import torch.nn.functional as F
from projects.mmdet3d_plugin.bevformer.modules.encoder import BEVFormerEncoder
torch.set_float32_matmul_precision('high')
def gen_dx_bx(xbound, ybound, zbound):
dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]])
bx = torch.Tensor([row[0] + row[2] / 2.0 for row in [xbound, ybound, zbound]])
nx = torch.Tensor(
[int((row[1] - row[0]) / row[2]) for row in [xbound, ybound, zbound]]
)
return dx, bx, nx
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class BaseTransform(BaseModule):
def __init__(
self,
in_channels,
out_channels,
feat_down_sample,
pc_range,
voxel_size,
dbound,
):
super(BaseTransform, self).__init__()
self.in_channels = in_channels
self.feat_down_sample = feat_down_sample
# self.image_size = image_size
# self.feature_size = feature_size
self.xbound = [pc_range[0],pc_range[3], voxel_size[0]]
self.ybound = [pc_range[1],pc_range[4], voxel_size[1]]
self.zbound = [pc_range[2],pc_range[5], voxel_size[2]]
self.dbound = dbound
dx, bx, nx = gen_dx_bx(self.xbound, self.ybound, self.zbound)
self.dx = nn.Parameter(dx, requires_grad=False)
self.bx = nn.Parameter(bx, requires_grad=False)
self.nx = nn.Parameter(nx, requires_grad=False)
self.C = out_channels
self.frustum = None
self.D = int((dbound[1] - dbound[0]) / dbound[2])
# self.frustum = self.create_frustum()
# self.D = self.frustum.shape[0]
self.fp16_enabled = False
@force_fp32()
def create_frustum(self,fH,fW,img_metas):
# iH, iW = self.image_size
# fH, fW = self.feature_size
iH = img_metas[0]['img_shape'][0][0]
iW = img_metas[0]['img_shape'][0][1]
assert iH // self.feat_down_sample == fH
# import pdb;pdb.set_trace()
ds = (
torch.arange(*self.dbound, dtype=torch.float)
.view(-1, 1, 1)
.expand(-1, fH, fW)
)
D, _, _ = ds.shape
xs = (
torch.linspace(0, iW - 1, fW, dtype=torch.float)
.view(1, 1, fW)
.expand(D, fH, fW)
)
ys = (
torch.linspace(0, iH - 1, fH, dtype=torch.float)
.view(1, fH, 1)
.expand(D, fH, fW)
)
frustum = torch.stack((xs, ys, ds), -1)
# return nn.Parameter(frustum, requires_grad=False)
return frustum
#@torch.compile(mode="max-autotune-no-cudagraphs")
def matmul_1(self, x, y, trans):
B, N, _ = trans.shape
points = torch.matmul(x.view(B, N, 1, 1, 1, 3, 3), y.unsqueeze(-1))
points = torch.cat(
(
points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
points[:, :, :, :, :, 2:3],
),
5,
)
return points
#@torch.compile(mode="max-autotune-no-cudagraphs")
def matmul_2(self, x, y, trans, lidar2ego_trans, points):
B, N, _ = trans.shape
combine = torch.matmul(x, y)
points = torch.matmul(combine.view(B, N, 1, 1, 1, 3, 3), points).squeeze(-1)
points += trans.view(B, N, 1, 1, 1, 3)
points -= lidar2ego_trans.view(B, 1, 1, 1, 1, 3)
return points
#@torch.compile(mode="max-autotune-no-cudagraphs")
def matmul_3(self, x, y, trans):
B, N, _ = trans.shape
points = torch.matmul(x.view(B, 1, 1, 1, 1, 3, 3), y.unsqueeze(-1)).squeeze(-1)
return points
@force_fp32()
def get_geometry_v1(
self,
fH,
fW,
rots,
trans,
intrins,
post_rots,
post_trans,
lidar2ego_rots,
lidar2ego_trans,
img_metas,
**kwargs,
):
B, N, _ = trans.shape
device = trans.device
if self.frustum == None:
self.frustum = self.create_frustum(fH,fW,img_metas)
self.frustum = self.frustum.to(device)
# self.D = self.frustum.shape[0]
# undo post-transformation
# B x N x D x H x W x 3
points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)
post_rots = torch.inverse(post_rots)
points = self.matmul_1(post_rots, points, trans)
# points = torch.matmul(post_rots.view(B, N, 1, 1, 1, 3, 3), points.unsqueeze(-1))
# # cam_to_ego
# points = torch.cat(
# (
# points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
# points[:, :, :, :, :, 2:3],
# ),
# 5,
# )
intrins = torch.inverse(intrins)
# combine = torch.matmul(rots, intrins)
# points = torch.matmul(combine.view(B, N, 1, 1, 1, 3, 3), points).squeeze(-1)
# points += trans.view(B, N, 1, 1, 1, 3)
# # ego_to_lidar
# points -= lidar2ego_trans.view(B, 1, 1, 1, 1, 3)
points = self.matmul_2(rots, intrins, trans, lidar2ego_trans, points)
lidar2ego_rots = torch.inverse(lidar2ego_rots)
points = self.matmul_3(lidar2ego_rots, points, trans)
if "extra_rots" in kwargs:
extra_rots = kwargs["extra_rots"]
points = torch.matmul(extra_rots.view(B, 1, 1, 1, 1, 3, 3).repeat(1, N, 1, 1, 1, 1, 1), points.unsqueeze(-1)).squeeze(-1)
if "extra_trans" in kwargs:
extra_trans = kwargs["extra_trans"]
points += extra_trans.view(B, 1, 1, 1, 1, 3).repeat(1, N, 1, 1, 1, 1)
return points
@force_fp32()
def get_geometry(
self,
fH,
fW,
lidar2img,
img_metas,
):
B, N, _, _ = lidar2img.shape
device = lidar2img.device
# import pdb;pdb.set_trace()
if self.frustum == None:
self.frustum = self.create_frustum(fH,fW,img_metas)
self.frustum = self.frustum.to(device)
# self.D = self.frustum.shape[0]
points = self.frustum.view(1,1,self.D, fH, fW, 3) \
.repeat(B,N,1,1,1,1)
lidar2img = lidar2img.view(B,N,1,1,1,4,4)
# img2lidar = torch.inverse(lidar2img)
points = torch.cat(
(points, torch.ones_like(points[..., :1])), -1)
points = torch.linalg.solve(lidar2img.to(torch.float32),
points.unsqueeze(-1).to(torch.float32)).squeeze(-1)
# points = torch.matmul(img2lidar.to(torch.float32),
# points.unsqueeze(-1).to(torch.float32)).squeeze(-1)
# import pdb;pdb.set_trace()
eps = 1e-5
points = points[..., 0:3] / torch.maximum(
points[..., 3:4], torch.ones_like(points[..., 3:4]) * eps)
return points
def get_cam_feats(self, x):
raise NotImplementedError
def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran, bda):
raise NotImplementedError
@force_fp32()
def bev_pool(self, geom_feats, x):
B, N, D, H, W, C = x.shape
Nprime = B * N * D * H * W
# flatten x
x = x.reshape(Nprime, C)
# flatten indices
geom_feats = ((geom_feats - (self.bx - self.dx / 2.0)) / self.dx).long()
geom_feats = geom_feats.view(Nprime, 3)
batch_ix = torch.cat(
[
torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long)
for ix in range(B)
]
)
geom_feats = torch.cat((geom_feats, batch_ix), 1)
# filter out points that are outside box
kept = (
(geom_feats[:, 0] >= 0)
& (geom_feats[:, 0] < self.nx[0])
& (geom_feats[:, 1] >= 0)
& (geom_feats[:, 1] < self.nx[1])
& (geom_feats[:, 2] >= 0)
& (geom_feats[:, 2] < self.nx[2])
)
x = x[kept]
geom_feats = geom_feats[kept]
# idx = torch.where(kept)[0]
# x = x.index_select(0, idx)
# geom_feats = geom_feats.index_select(0, idx)
x = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0], self.nx[1])
# # collapse Z
# x = x.permute(0, 4, 1, 2, 3).contiguous()
# final = torch.cat(x.unbind(dim=2), 1)
return x
def stack_metas(self, metas, key, device, dtype):
tensors = []
for meta in metas:
val = meta[key]
if isinstance(val, np.ndarray):
val = torch.from_numpy(val)
elif isinstance(val, list):
val = torch.stack([torch.from_numpy(v) if isinstance(v, np.ndarray) else v for v in val], dim=0)
tensors.append(val)
return torch.stack(tensors, dim=0).to(device=device, dtype=dtype)
#@torch.compile(mode="max-autotune-no-cudagraphs")
def extract_metas(self, images, img_metas):
device = images.device
dtype = images.dtype
lidar2img = self.stack_metas(img_metas, 'lidar2img', device, dtype)
camera2ego = self.stack_metas(img_metas, 'camera2ego', device, dtype)
camera_intrinsics = self.stack_metas(img_metas, 'camera_intrinsics', device, dtype)
img_aug_matrix = self.stack_metas(img_metas, 'img_aug_matrix', device, dtype)
lidar2ego = self.stack_metas(img_metas, 'lidar2ego', device, dtype)
rots = camera2ego[..., :3, :3]
trans = camera2ego[..., :3, 3]
intrins = camera_intrinsics[..., :3, :3]
post_rots = img_aug_matrix[..., :3, :3]
post_trans = img_aug_matrix[..., :3, 3]
lidar2ego_rots = lidar2ego[..., :3, :3]
lidar2ego_trans = lidar2ego[..., :3, 3]
return rots, trans, intrins, post_rots, post_trans, lidar2ego_rots, lidar2ego_trans, camera2ego, camera_intrinsics
@force_fp32()
def forward(
self,
images,
img_metas
):
B, N, C, fH, fW = images.shape
rots, trans, intrins, post_rots, post_trans, lidar2ego_rots, lidar2ego_trans, camera2ego, camera_intrinsics = self.extract_metas(images, img_metas)
geom = self.get_geometry_v1(
fH,
fW,
rots,
trans,
intrins,
post_rots,
post_trans,
lidar2ego_rots,
lidar2ego_trans,
img_metas
)
mlp_input = self.get_mlp_input(camera2ego, camera_intrinsics, post_rots, post_trans)
x, depth = self.get_cam_feats(images, mlp_input)
x = self.bev_pool(geom, x)
# x = x.permute(0,1,3,2).contiguous()
return x, depth
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class BaseTransformV2(BaseModule):
def __init__(
self,
input_size,
in_channels,
out_channels,
feat_down_sample,
pc_range,
voxel_size,
dbound,
sid=False,
):
super(BaseTransformV2, self).__init__()
self.in_channels = in_channels
self.feat_down_sample = feat_down_sample
# self.image_size = image_size
# self.feature_size = feature_size
xbound = [pc_range[0],pc_range[3], voxel_size[0]]
ybound = [pc_range[1],pc_range[4], voxel_size[1]]
zbound = [pc_range[2],pc_range[5], voxel_size[2]]
grid_config = [xbound, ybound, zbound]
self.create_grid_infos(*grid_config)
self.dbound = dbound
self.sid = sid
self.frustum = self.create_frustum(dbound,
input_size, feat_down_sample)
self.C = out_channels
self.D = round((dbound[1] - dbound[0]) / dbound[2])
self.fp16_enabled = False
def create_grid_infos(self, x, y, z, **kwargs):
"""Generate the grid information including the lower bound, interval,
and size.
Args:
x (tuple(float)): Config of grid alone x axis in format of
(lower_bound, upper_bound, interval).
y (tuple(float)): Config of grid alone y axis in format of
(lower_bound, upper_bound, interval).
z (tuple(float)): Config of grid alone z axis in format of
(lower_bound, upper_bound, interval).
**kwargs: Container for other potential parameters
"""
self.grid_lower_bound = torch.Tensor([cfg[0] for cfg in [x, y, z]])
self.grid_interval = torch.Tensor([cfg[2] for cfg in [x, y, z]])
self.grid_size = torch.Tensor([(cfg[1] - cfg[0]) / cfg[2]
for cfg in [x, y, z]])
# @force_fp32()
def create_frustum(self, depth_cfg, input_size, downsample):
"""Generate the frustum template for each image.
Args:
depth_cfg (tuple(float)): Config of grid alone depth axis in format
of (lower_bound, upper_bound, interval).
input_size (tuple(int)): Size of input images in format of (height,
width).
downsample (int): Down sample scale factor from the input size to
the feature size.
"""
H_in, W_in = input_size
H_feat, W_feat = H_in // downsample, W_in // downsample
d = torch.arange(*depth_cfg, dtype=torch.float)\
.view(-1, 1, 1).expand(-1, H_feat, W_feat)
self.D = d.shape[0]
if self.sid:
d_sid = torch.arange(self.D).float()
depth_cfg_t = torch.tensor(depth_cfg).float()
d_sid = torch.exp(torch.log(depth_cfg_t[0]) + d_sid / (self.D-1) *
torch.log((depth_cfg_t[1]-1) / depth_cfg_t[0]))
d = d_sid.view(-1, 1, 1).expand(-1, H_feat, W_feat)
x = torch.linspace(0, W_in - 1, W_feat, dtype=torch.float)\
.view(1, 1, W_feat).expand(self.D, H_feat, W_feat)
y = torch.linspace(0, H_in - 1, H_feat, dtype=torch.float)\
.view(1, H_feat, 1).expand(self.D, H_feat, W_feat)
# D x H x W x 3
return torch.stack((x, y, d), -1)
def get_lidar_coor(self,
fH,
fW,
rots,
trans,
intrins,
post_rots,
post_trans,
lidar2ego_rots,
lidar2ego_trans,
img_metas):
B, N, _, _ = sensor2ego.shape
# post-transformation
# B x N x D x H x W x 3
points = self.frustum.to(sensor2ego) - post_trans.view(B, N, 1, 1, 1, 3)
points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3)\
.matmul(points.unsqueeze(-1))
# cam_to_ego
points = torch.cat(
(points[..., :2, :] * points[..., 2:3, :], points[..., 2:3, :]), 5)
combine = rots.matmul(torch.inverse(intrins))
points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
points += trans.view(B, N, 1, 1, 1, 3)
# ego_to_lidar
points -= lidar2ego_trans.view(B, 1, 1, 1, 1, 3)
points = (
torch.inverse(lidar2ego_rots)
.view(B, 1, 1, 1, 1, 3, 3)
.matmul(points.unsqueeze(-1))
.squeeze(-1)
)
return points
@force_fp32()
def get_geometry_v1(
self,
fH,
fW,
rots,
trans,
intrins,
post_rots,
post_trans,
lidar2ego_rots,
lidar2ego_trans,
img_metas,
**kwargs,
):
B, N, _ = trans.shape
device = trans.device
# if self.frustum == None:
# self.frustum = self.create_frustum(fH,fW,img_metas)
# self.frustum = self.frustum.to(device)
# # self.D = self.frustum.shape[0]
# undo post-transformation
# B x N x D x H x W x 3
points = self.frustum.to(device)- post_trans.view(B, N, 1, 1, 1, 3)
points = (
torch.inverse(post_rots)
.view(B, N, 1, 1, 1, 3, 3)
.matmul(points.unsqueeze(-1))
)
# cam_to_ego
points = torch.cat(
(
points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
points[:, :, :, :, :, 2:3],
),
5,
)
combine = rots.matmul(torch.inverse(intrins))
points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
points += trans.view(B, N, 1, 1, 1, 3)
# ego_to_lidar
points -= lidar2ego_trans.view(B, 1, 1, 1, 1, 3)
points = (
torch.inverse(lidar2ego_rots)
.view(B, 1, 1, 1, 1, 3, 3)
.matmul(points.unsqueeze(-1))
.squeeze(-1)
)
if "extra_rots" in kwargs:
extra_rots = kwargs["extra_rots"]
points = (
extra_rots.view(B, 1, 1, 1, 1, 3, 3)
.repeat(1, N, 1, 1, 1, 1, 1)
.matmul(points.unsqueeze(-1))
.squeeze(-1)
)
if "extra_trans" in kwargs:
extra_trans = kwargs["extra_trans"]
points += extra_trans.view(B, 1, 1, 1, 1, 3).repeat(1, N, 1, 1, 1, 1)
return points
@force_fp32()
def get_geometry(
self,
fH,
fW,
lidar2img,
img_metas,
):
B, N, _, _ = lidar2img.shape
device = lidar2img.device
if self.frustum == None:
self.frustum = self.create_frustum(fH,fW,img_metas)
self.frustum = self.frustum.to(device)
# self.D = self.frustum.shape[0]
points = self.frustum.view(1,1,self.D, fH, fW, 3) \
.repeat(B,N,1,1,1,1)
lidar2img = lidar2img.view(B,N,1,1,1,4,4)
# img2lidar = torch.inverse(lidar2img)
points = torch.cat(
(points, torch.ones_like(points[..., :1])), -1)
points = torch.linalg.solve(lidar2img.to(torch.float32),
points.unsqueeze(-1).to(torch.float32)).squeeze(-1)
# points = torch.matmul(img2lidar.to(torch.float32),
# points.unsqueeze(-1).to(torch.float32)).squeeze(-1)
eps = 1e-5
points = points[..., 0:3] / torch.maximum(
points[..., 3:4], torch.ones_like(points[..., 3:4]) * eps)
return points
def get_cam_feats(self, x):
raise NotImplementedError
def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran, bda):
raise NotImplementedError
def voxel_pooling_prepare_v2(self, coor):
"""Data preparation for voxel pooling.
Args:
coor (torch.tensor): Coordinate of points in the lidar space in
shape (B, N, D, H, W, 3).
Returns:
tuple[torch.tensor]: Rank of the voxel that a point is belong to
in shape (N_Points); Reserved index of points in the depth
space in shape (N_Points). Reserved index of points in the
feature space in shape (N_Points).
"""
B, N, D, H, W, _ = coor.shape
num_points = B * N * D * H * W
# record the index of selected points for acceleration purpose
ranks_depth = torch.range(
0, num_points - 1, dtype=torch.int, device=coor.device)
ranks_feat = torch.range(
0, num_points // D - 1, dtype=torch.int, device=coor.device)
ranks_feat = ranks_feat.reshape(B, N, 1, H, W)
ranks_feat = ranks_feat.expand(B, N, D, H, W).flatten()
# convert coordinate into the voxel space
coor = ((coor - self.grid_lower_bound.to(coor)) /
self.grid_interval.to(coor))
coor = coor.long().view(num_points, 3)
batch_idx = torch.range(0, B - 1).reshape(B, 1). \
expand(B, num_points // B).reshape(num_points, 1).to(coor)
coor = torch.cat((coor, batch_idx), 1)
# filter out points that are outside box
kept = (coor[:, 0] >= 0) & (coor[:, 0] < self.grid_size[0]) & \
(coor[:, 1] >= 0) & (coor[:, 1] < self.grid_size[1]) & \
(coor[:, 2] >= 0) & (coor[:, 2] < self.grid_size[2])
if len(kept) == 0:
return None, None, None, None, None
coor, ranks_depth, ranks_feat = \
coor[kept], ranks_depth[kept], ranks_feat[kept]
# get tensors from the same voxel next to each other
ranks_bev = coor[:, 3] * (
self.grid_size[2] * self.grid_size[1] * self.grid_size[0])
ranks_bev += coor[:, 2] * (self.grid_size[1] * self.grid_size[0])
ranks_bev += coor[:, 1] * self.grid_size[0] + coor[:, 0]
order = ranks_bev.argsort()
ranks_bev, ranks_depth, ranks_feat = \
ranks_bev[order], ranks_depth[order], ranks_feat[order]
kept = torch.ones(
ranks_bev.shape[0], device=ranks_bev.device, dtype=torch.bool)
kept[1:] = ranks_bev[1:] != ranks_bev[:-1]
interval_starts = torch.where(kept)[0].int()
if len(interval_starts) == 0:
return None, None, None, None, None
interval_lengths = torch.zeros_like(interval_starts)
interval_lengths[:-1] = interval_starts[1:] - interval_starts[:-1]
interval_lengths[-1] = ranks_bev.shape[0] - interval_starts[-1]
return ranks_bev.int().contiguous(), ranks_depth.int().contiguous(
), ranks_feat.int().contiguous(), interval_starts.int().contiguous(
), interval_lengths.int().contiguous()
@force_fp32()
def voxel_pooling_v2(self, coor, depth, feat):
ranks_bev, ranks_depth, ranks_feat, \
interval_starts, interval_lengths = \
self.voxel_pooling_prepare_v2(coor)
if ranks_feat is None:
print('warning ---> no points within the predefined '
'bev receptive field')
dummy = torch.zeros(size=[
feat.shape[0], feat.shape[2],
int(self.grid_size[2]),
int(self.grid_size[0]),
int(self.grid_size[1])
]).to(feat)
dummy = torch.cat(dummy.unbind(dim=2), 1)
return dummy
feat = feat.permute(0, 1, 3, 4, 2)
bev_feat_shape = (depth.shape[0], int(self.grid_size[2]),
int(self.grid_size[1]), int(self.grid_size[0]),
feat.shape[-1]) # (B, Z, Y, X, C)
bev_feat = bev_pool_v2(depth, feat, ranks_depth, ranks_feat, ranks_bev,
bev_feat_shape, interval_starts,
interval_lengths)
# collapse Z
# if self.collapse_z:
bev_feat = torch.cat(bev_feat.unbind(dim=2), 1)
return bev_feat
@force_fp32()
def bev_pool(self, geom_feats, x):
B, N, D, H, W, C = x.shape
Nprime = B * N * D * H * W
# flatten x
x = x.reshape(Nprime, C)
# flatten indices
geom_feats = ((geom_feats - (self.bx - self.dx / 2.0)) / self.dx).long()
geom_feats = geom_feats.view(Nprime, 3)
batch_ix = torch.cat(
[
torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long)
for ix in range(B)
]
)
geom_feats = torch.cat((geom_feats, batch_ix), 1)
# filter out points that are outside box
kept = (
(geom_feats[:, 0] >= 0)
& (geom_feats[:, 0] < self.nx[0])
& (geom_feats[:, 1] >= 0)
& (geom_feats[:, 1] < self.nx[1])
& (geom_feats[:, 2] >= 0)
& (geom_feats[:, 2] < self.nx[2])
)
x = x[kept]
geom_feats = geom_feats[kept]
x = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0], self.nx[1])
# collapse Z
final = torch.cat(x.unbind(dim=2), 1)
return final
@force_fp32()
def forward(
self,
images,
img_metas
):
B, N, C, fH, fW = images.shape
lidar2img = []
camera2ego = []
camera_intrinsics = []
img_aug_matrix = []
lidar2ego = []
for img_meta in img_metas:
lidar2img.append(img_meta['lidar2img'])
camera2ego.append(img_meta['camera2ego'])
camera_intrinsics.append(img_meta['camera_intrinsics'])
img_aug_matrix.append(img_meta['img_aug_matrix'])
lidar2ego.append(img_meta['lidar2ego'])
lidar2img = np.asarray(lidar2img)
lidar2img = images.new_tensor(lidar2img) # (B, N, 4, 4)
camera2ego = np.asarray(camera2ego)
camera2ego = images.new_tensor(camera2ego) # (B, N, 4, 4)
camera_intrinsics = np.asarray(camera_intrinsics)
camera_intrinsics = images.new_tensor(camera_intrinsics) # (B, N, 4, 4)
img_aug_matrix = np.asarray(img_aug_matrix)
img_aug_matrix = images.new_tensor(img_aug_matrix) # (B, N, 4, 4)
lidar2ego = np.asarray(lidar2ego)
lidar2ego = images.new_tensor(lidar2ego) # (B, N, 4, 4)
# lidar2cam = torch.linalg.solve(camera2ego, lidar2ego.view(B,1,4,4).repeat(1,N,1,1))
# lidar2oriimg = torch.matmul(camera_intrinsics,lidar2cam)
# mylidar2img = torch.matmul(img_aug_matrix,lidar2oriimg)
rots = camera2ego[..., :3, :3]
trans = camera2ego[..., :3, 3]
intrins = camera_intrinsics[..., :3, :3]
post_rots = img_aug_matrix[..., :3, :3]
post_trans = img_aug_matrix[..., :3, 3]
lidar2ego_rots = lidar2ego[..., :3, :3]
lidar2ego_trans = lidar2ego[..., :3, 3]
sensor_config = [fH, fW, rots, trans, intrins, post_rots, post_trans, lidar2ego_rots, lidar2ego_trans, img_metas]
# coor = self.get_lidar_coor(*sensor_config)
# # tmpgeom = self.get_geometry(
# # fH,
# # fW,
# # mylidar2img,
# # img_metas,
# # )
coor = self.get_geometry_v1(
fH,
fW,
rots,
trans,
intrins,
post_rots,
post_trans,
lidar2ego_rots,
lidar2ego_trans,
img_metas
)
mlp_input = self.get_mlp_input(camera2ego, camera_intrinsics, post_rots, post_trans)
tran_feat, depth = self.get_cam_feats(images, mlp_input)
bev_feat = self.voxel_pooling_v2(
coor, depth,
tran_feat)
# x = self.bev_pool(geom, x)
# import ipdb;ipdb.set_trace()
# bev_feat = bev_feat.permute(0,1,3,2).contiguous()
return bev_feat, depth
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.ReLU,
drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class SELayer(nn.Module):
def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
super().__init__()
self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True)
self.act1 = act_layer()
self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True)
self.gate = gate_layer()
def forward(self, x, x_se):
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
return x * self.gate(x_se)
class DepthNet(nn.Module):
def __init__(self,
in_channels,
mid_channels,
context_channels,
depth_channels,
use_dcn=True,
use_aspp=True,
with_cp=False,
aspp_mid_channels=-1,
only_depth=False):
super(DepthNet, self).__init__()
self.reduce_conv = nn.Sequential(
nn.Conv2d(
in_channels, mid_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
)
self.only_depth = only_depth or context_channels == 0
if not self.only_depth:
self.context_conv = nn.Conv2d(
mid_channels, context_channels, kernel_size=1, stride=1, padding=0)
self.context_mlp = Mlp(22, mid_channels, mid_channels)
self.context_se = SELayer(mid_channels) # NOTE: add camera-aware
self.bn = nn.BatchNorm1d(22)
self.depth_mlp = Mlp(22, mid_channels, mid_channels)
self.depth_se = SELayer(mid_channels) # NOTE: add camera-aware
depth_conv_list = [
BasicBlock(mid_channels, mid_channels),
BasicBlock(mid_channels, mid_channels),
BasicBlock(mid_channels, mid_channels),
]
if use_aspp:
if aspp_mid_channels<0:
aspp_mid_channels = mid_channels
depth_conv_list.append(ASPP(mid_channels, aspp_mid_channels))
if use_dcn:
depth_conv_list.append(
build_conv_layer(
cfg=dict(
type='DCN',
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
padding=1,
groups=4,
im2col_step=128,
)))
depth_conv_list.append(
nn.Conv2d(
mid_channels,
depth_channels,
kernel_size=1,
stride=1,
padding=0))
self.depth_conv = nn.Sequential(*depth_conv_list)
self.with_cp = with_cp
def forward(self, x, mlp_input):
mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))
x = self.reduce_conv(x)
if not self.only_depth:
context_se = self.context_mlp(mlp_input)[..., None, None]
context = self.context_se(x, context_se)
context = self.context_conv(context)
depth_se = self.depth_mlp(mlp_input)[..., None, None]
depth = self.depth_se(x, depth_se)
if self.with_cp:
depth = checkpoint(self.depth_conv, depth)
else:
depth = self.depth_conv(depth)
if not self.only_depth:
return torch.cat([depth, context], dim=1)
else:
return depth
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class BEVFormerEncoderDepth(BEVFormerEncoder):
def __init__(self, *args, in_channels=256, out_channels=256, feat_down_sample=32, loss_depth_weight = 3.0,
depthnet_cfg=dict(),grid_config=None,**kwargs):
super(BEVFormerEncoderDepth, self).__init__(*args, **kwargs)
self.fp16_enabled = False
self.loss_depth_weight = loss_depth_weight
self.feat_down_sample = feat_down_sample
self.grid_config = grid_config
self.D = int((grid_config['depth'][1] - grid_config['depth'][0]) / grid_config['depth'][2])
self.depth_net = DepthNet(in_channels, in_channels,
0, self.D, **depthnet_cfg)
@auto_fp16()
def forward(self,
bev_query,
key,
value,
*args,
mlvl_feats=None,
bev_h=None,
bev_w=None,
bev_pos=None,
spatial_shapes=None,
level_start_index=None,
valid_ratios=None,
prev_bev=None,
shift=0.,
**kwargs):
"""Forward function for `TransformerDecoder`.
Args:
bev_query (Tensor): Input BEV query with shape
`(num_query, bs, embed_dims)`.
key & value (Tensor): Input multi-cameta features with shape
(num_cam, num_value, bs, embed_dims)
reference_points (Tensor): The reference
points of offset. has shape
(bs, num_query, 4) when as_two_stage,
otherwise has shape ((bs, num_query, 2).
valid_ratios (Tensor): The radios of valid
points on the feature map, has shape
(bs, num_levels, 2)
Returns:
Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape
[num_layers, num_query, bs, embed_dims].
"""
bev_embed = super().forward(
bev_query,
key,
value,
bev_h=bev_h,
bev_w=bev_w,
bev_pos=bev_pos,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
prev_bev=prev_bev,
shift=shift,
**kwargs)
# import ipdb; ipdb.set_trace()
images = mlvl_feats[0]
img_metas = kwargs['img_metas']
B, N, C, fH, fW = images.shape
lidar2img = []
camera2ego = []
camera_intrinsics = []
img_aug_matrix = []
lidar2ego = []
for img_meta in img_metas:
lidar2img.append(img_meta['lidar2img'])
camera2ego.append(img_meta['camera2ego'])
camera_intrinsics.append(img_meta['camera_intrinsics'])
img_aug_matrix.append(img_meta['img_aug_matrix'])
lidar2ego.append(img_meta['lidar2ego'])
lidar2img = np.asarray(lidar2img)
lidar2img = images.new_tensor(lidar2img) # (B, N, 4, 4)
camera2ego = np.asarray(camera2ego)
camera2ego = images.new_tensor(camera2ego) # (B, N, 4, 4)
camera_intrinsics = np.asarray(camera_intrinsics)
camera_intrinsics = images.new_tensor(camera_intrinsics) # (B, N, 4, 4)
img_aug_matrix = np.asarray(img_aug_matrix)
img_aug_matrix = images.new_tensor(img_aug_matrix) # (B, N, 4, 4)
lidar2ego = np.asarray(lidar2ego)
lidar2ego = images.new_tensor(lidar2ego) # (B, N, 4, 4)
rots = camera2ego[..., :3, :3]
trans = camera2ego[..., :3, 3]
intrins = camera_intrinsics[..., :3, :3]
post_rots = img_aug_matrix[..., :3, :3]
post_trans = img_aug_matrix[..., :3, 3]
lidar2ego_rots = lidar2ego[..., :3, :3]
lidar2ego_trans = lidar2ego[..., :3, 3]
mlp_input = self.get_mlp_input(camera2ego, camera_intrinsics, post_rots, post_trans)
depth = self.get_cam_feats(images, mlp_input)
ret_dict = dict(
bev=bev_embed['bev'],
depth=depth,
)
# import ipdb; ipdb.set_trace()
return ret_dict
@force_fp32()
def get_cam_feats(self, x, mlp_input):
B, N, C, fH, fW = x.shape
x = x.view(B * N, C, fH, fW)
x = self.depth_net(x, mlp_input)
depth = x[:, : self.D].softmax(dim=1)
depth = depth.view(B, N, self.D, fH, fW)
return depth
def get_downsampled_gt_depth(self, gt_depths):
"""
Input:
gt_depths: [B, N, H, W]
Output:
gt_depths: [B*N*h*w, d]
"""
B, N, H, W = gt_depths.shape
gt_depths = gt_depths.view(B * N, H // self.feat_down_sample,
self.feat_down_sample, W // self.feat_down_sample,
self.feat_down_sample, 1)
gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
gt_depths = gt_depths.view(-1, self.feat_down_sample * self.feat_down_sample)
# 把gt_depth做feat_down_sample倍数的采样
gt_depths_tmp = torch.where(gt_depths == 0.0,
1e5 * torch.ones_like(gt_depths),
gt_depths)
# 因为深度很稀疏,大部分的点都是0,所以把0变成10000,下一步取-1维度上的最小就是深度的值
gt_depths = torch.min(gt_depths_tmp, dim=-1).values
gt_depths = gt_depths.view(B * N, H // self.feat_down_sample,
W // self.feat_down_sample)
gt_depths = (
gt_depths -
(self.grid_config['depth'][0] -
self.grid_config['depth'][2])) / self.grid_config['depth'][2]
gt_depths = torch.where((gt_depths < self.D + 1) & (gt_depths >= 0.0),
gt_depths, torch.zeros_like(gt_depths))
gt_depths = F.one_hot(
gt_depths.long(), num_classes=self.D + 1).view(-1, self.D + 1)[:,
1:]
return gt_depths.float()
@force_fp32()
def get_depth_loss(self, depth_labels, depth_preds):
# import pdb;pdb.set_trace()
if depth_preds is None:
return 0
depth_labels = self.get_downsampled_gt_depth(depth_labels)
depth_preds = depth_preds.permute(0, 1, 3, 4, 2).contiguous().view(-1, self.D)
# fg_mask = torch.max(depth_labels, dim=1).values > 0.0 # 只计算有深度的前景的深度loss
# import pdb;pdb.set_trace()
fg_mask = depth_labels > 0.0 # 只计算有深度的前景的深度loss
depth_labels = depth_labels[fg_mask]
depth_preds = depth_preds[fg_mask]
with autocast(enabled=False):
depth_loss = F.binary_cross_entropy(
depth_preds,
depth_labels,
reduction='none',
).sum() / max(1.0, fg_mask.sum())
# if depth_loss <= 0.:
# import pdb;pdb.set_trace()
return self.loss_depth_weight * depth_loss
def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran):
B, N, _, _ = sensor2ego.shape
mlp_input = torch.stack([
intrin[:, :, 0, 0],
intrin[:, :, 1, 1],
intrin[:, :, 0, 2],
intrin[:, :, 1, 2],
post_rot[:, :, 0, 0],
post_rot[:, :, 0, 1],
post_tran[:, :, 0],
post_rot[:, :, 1, 0],
post_rot[:, :, 1, 1],
post_tran[:, :, 1],
], dim=-1)
sensor2ego = sensor2ego[:,:,:3,:].reshape(B, N, -1)
mlp_input = torch.cat([mlp_input, sensor2ego], dim=-1)
return mlp_input
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class LSSTransform(BaseTransform):
def __init__(
self,
in_channels,
out_channels,
feat_down_sample,
pc_range,
voxel_size,
dbound,
downsample=1,
loss_depth_weight = 3.0,
depthnet_cfg=dict(),
grid_config=None,
):
super(LSSTransform, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
feat_down_sample=feat_down_sample,
pc_range=pc_range,
voxel_size=voxel_size,
dbound=dbound,
)
# import pdb;pdb.set_trace()
self.loss_depth_weight = loss_depth_weight
self.grid_config = grid_config
self.depth_net = DepthNet(in_channels, in_channels,
self.C, self.D, **depthnet_cfg)
if downsample > 1:
assert downsample == 2, downsample
self.downsample = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
nn.Conv2d(
out_channels,
out_channels,
3,
stride=downsample,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
)
else:
self.downsample = nn.Identity()
#@torch.compile(mode="max-autotune-no-cudagraphs")
@force_fp32()
def get_cam_feats(self, x, mlp_input):
B, N, C, fH, fW = x.shape
x = x.view(B * N, C, fH, fW)
x = self.depth_net(x, mlp_input)
depth = x[:, : self.D].softmax(dim=1)
x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)
x = x.view(B, N, self.C, self.D, fH, fW)
x = x.permute(0, 1, 3, 4, 5, 2)
depth = depth.view(B, N, self.D, fH, fW)
return x, depth
#@torch.compile(mode="max-autotune-no-cudagraphs")
def down_sample(self, x):
input = x
B, N, H, W, C = x.shape
x = x.permute(0, 4, 1, 2, 3).contiguous()
x = torch.cat(x.unbind(dim=2), 1)
x = x.permute(0,1,3,2).contiguous()
return self.downsample(x)
def forward(self, images, img_metas):
x, depth = super().forward(images, img_metas)
# x = self.downsample(x)
x = self.down_sample(x)
ret_dict = dict(
bev=x,
depth=depth,
)
return ret_dict
def get_downsampled_gt_depth(self, gt_depths):
"""
Input:
gt_depths: [B, N, H, W]
Output:
gt_depths: [B*N*h*w, d]
"""
B, N, H, W = gt_depths.shape
gt_depths = gt_depths.view(B * N, H // self.feat_down_sample,
self.feat_down_sample, W // self.feat_down_sample,
self.feat_down_sample, 1)
gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
gt_depths = gt_depths.view(-1, self.feat_down_sample * self.feat_down_sample)
# 把gt_depth做feat_down_sample倍数的采样
gt_depths_tmp = torch.where(gt_depths == 0.0,
1e5 * torch.ones_like(gt_depths),
gt_depths)
# 因为深度很稀疏,大部分的点都是0,所以把0变成10000,下一步取-1维度上的最小就是深度的值
gt_depths = torch.min(gt_depths_tmp, dim=-1).values
gt_depths = gt_depths.view(B * N, H // self.feat_down_sample,
W // self.feat_down_sample)
gt_depths = (
gt_depths -
(self.grid_config['depth'][0] -
self.grid_config['depth'][2])) / self.grid_config['depth'][2]
gt_depths = torch.where((gt_depths < self.D + 1) & (gt_depths >= 0.0),
gt_depths, torch.zeros_like(gt_depths))
gt_depths = F.one_hot(
gt_depths.long(), num_classes=self.D + 1).view(-1, self.D + 1)[:,
1:]
return gt_depths.float()
@force_fp32()
def get_depth_loss(self, depth_labels, depth_preds):
# import pdb;pdb.set_trace()
if depth_preds is None:
return 0
depth_labels = self.get_downsampled_gt_depth(depth_labels)
depth_preds = depth_preds.permute(0, 1, 3, 4, 2).contiguous().view(-1, self.D)
# fg_mask = torch.max(depth_labels, dim=1).values > 0.0 # 只计算有深度的前景的深度loss
# import pdb;pdb.set_trace()
fg_mask = depth_labels > 0.0 # 只计算有深度的前景的深度loss
depth_labels = depth_labels[fg_mask]
depth_preds = depth_preds[fg_mask]
with autocast(enabled=False):
depth_loss = F.binary_cross_entropy(
depth_preds,
depth_labels,
reduction='none',
).sum() / max(1.0, fg_mask.sum())
# if depth_loss <= 0.:
# import pdb;pdb.set_trace()
return self.loss_depth_weight * depth_loss
#@torch.compile(mode="max-autotune-no-cudagraphs")
def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran):
B, N, _, _ = sensor2ego.shape
mlp_input = torch.stack([
intrin[:, :, 0, 0],
intrin[:, :, 1, 1],
intrin[:, :, 0, 2],
intrin[:, :, 1, 2],
post_rot[:, :, 0, 0],
post_rot[:, :, 0, 1],
post_tran[:, :, 0],
post_rot[:, :, 1, 0],
post_rot[:, :, 1, 1],
post_tran[:, :, 1],
], dim=-1)
sensor2ego = sensor2ego[:,:,:3,:].reshape(B, N, -1)
mlp_input = torch.cat([mlp_input, sensor2ego], dim=-1)
return mlp_input
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class LSSTransformV2(BaseTransformV2):
def __init__(
self,
input_size,
in_channels,
out_channels,
feat_down_sample,
pc_range,
voxel_size,
dbound,
downsample=1,
loss_depth_weight = 3.0,
depthnet_cfg=dict(),
grid_config = None,
sid=False,
):
super(LSSTransformV2, self).__init__(
input_size=input_size,
in_channels=in_channels,
out_channels=out_channels,
feat_down_sample=feat_down_sample,
pc_range=pc_range,
voxel_size=voxel_size,
dbound=dbound,
sid=sid,
)
self.loss_depth_weight = loss_depth_weight
self.grid_config = grid_config
self.depth_net = DepthNet(self.in_channels, self.in_channels,
self.C, self.D, **depthnet_cfg)
if downsample > 1:
assert downsample == 2, downsample
self.downsample = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
nn.Conv2d(
out_channels,
out_channels,
3,
stride=downsample,
padding=1,
bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(True),
)
else:
self.downsample = nn.Identity()
@force_fp32()
def get_cam_feats(self, x, mlp_input):
B, N, C, fH, fW = x.shape
x = x.view(B * N, C, fH, fW)
x = self.depth_net(x, mlp_input)
depth = x[:, : self.D].softmax(dim=1)
tran_feat = x[:, self.D : (self.D + self.C)]
tran_feat = tran_feat.view(B, N, self.C, fH, fW)
# x = x.permute(0, 1, 3, 4, 5, 2)
depth = depth.view(B, N, self.D, fH, fW)
return tran_feat, depth
def forward(self, images, img_metas):
x, depth = super().forward(images, img_metas)
x = self.downsample(x)
ret_dict = dict(
bev=x,
depth=depth,
)
return ret_dict
def get_downsampled_gt_depth(self, gt_depths):
"""
Input:
gt_depths: [B, N, H, W]
Output:
gt_depths: [B*N*h*w, d]
"""
B, N, H, W = gt_depths.shape
gt_depths = gt_depths.view(B * N, H // self.feat_down_sample,
self.feat_down_sample, W // self.feat_down_sample,
self.feat_down_sample, 1)
gt_depths = gt_depths.permute(0, 1, 3, 5, 2, 4).contiguous()
gt_depths = gt_depths.view(-1, self.feat_down_sample * self.feat_down_sample)
# 把gt_depth做feat_down_sample倍数的采样
gt_depths_tmp = torch.where(gt_depths == 0.0,
1e5 * torch.ones_like(gt_depths),
gt_depths)
# 因为深度很稀疏,大部分的点都是0,所以把0变成10000,下一步取-1维度上的最小就是深度的值
gt_depths = torch.min(gt_depths_tmp, dim=-1).values
gt_depths = gt_depths.view(B * N, H // self.feat_down_sample,
W // self.feat_down_sample)
gt_depths = (
gt_depths -
(self.grid_config['depth'][0] -
self.grid_config['depth'][2])) / self.grid_config['depth'][2]
gt_depths = torch.where((gt_depths < self.D + 1) & (gt_depths >= 0.0),
gt_depths, torch.zeros_like(gt_depths))
gt_depths = F.one_hot(
gt_depths.long(), num_classes=self.D + 1).view(-1, self.D + 1)[:,
1:]
return gt_depths.float()
@force_fp32()
def get_depth_loss(self, depth_labels, depth_preds):
# import pdb;pdb.set_trace()
if depth_preds is None:
return 0
depth_labels = self.get_downsampled_gt_depth(depth_labels)
depth_preds = depth_preds.permute(0, 1, 3, 4, 2).contiguous().view(-1, self.D)
# fg_mask = torch.max(depth_labels, dim=1).values > 0.0 # 只计算有深度的前景的深度loss
# import pdb;pdb.set_trace()
fg_mask = depth_labels > 0.0 # 只计算有深度的前景的深度loss
depth_labels = depth_labels[fg_mask]
depth_preds = depth_preds[fg_mask]
with autocast(enabled=False):
depth_loss = F.binary_cross_entropy(
depth_preds,
depth_labels,
reduction='none',
).sum() / max(1.0, fg_mask.sum())
# if depth_loss <= 0.:
# import pdb;pdb.set_trace()
return self.loss_depth_weight * depth_loss
def get_mlp_input(self, sensor2ego, intrin, post_rot, post_tran):
B, N, _, _ = sensor2ego.shape
mlp_input = torch.stack([
intrin[:, :, 0, 0],
intrin[:, :, 1, 1],
intrin[:, :, 0, 2],
intrin[:, :, 1, 2],
post_rot[:, :, 0, 0],
post_rot[:, :, 0, 1],
post_tran[:, :, 0],
post_rot[:, :, 1, 0],
post_rot[:, :, 1, 1],
post_tran[:, :, 1],
], dim=-1)
sensor2ego = sensor2ego[:,:,:3,:].reshape(B, N, -1)
mlp_input = torch.cat([mlp_input, sensor2ego], dim=-1)
return mlp_input
class _ASPPModule(nn.Module):
def __init__(self, inplanes, planes, kernel_size, padding, dilation,
BatchNorm):
super(_ASPPModule, self).__init__()
self.atrous_conv = nn.Conv2d(
inplanes,
planes,
kernel_size=kernel_size,
stride=1,
padding=padding,
dilation=dilation,
bias=False)
self.bn = BatchNorm(planes)
self.relu = nn.ReLU()
self._init_weight()
def forward(self, x):
x = self.atrous_conv(x)
x = self.bn(x)
return self.relu(x)
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class ASPP(nn.Module):
def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
super(ASPP, self).__init__()
dilations = [1, 6, 12, 18]
self.aspp1 = _ASPPModule(
inplanes,
mid_channels,
1,
padding=0,
dilation=dilations[0],
BatchNorm=BatchNorm)
self.aspp2 = _ASPPModule(
inplanes,
mid_channels,
3,
padding=dilations[1],
dilation=dilations[1],
BatchNorm=BatchNorm)
self.aspp3 = _ASPPModule(
inplanes,
mid_channels,
3,
padding=dilations[2],
dilation=dilations[2],
BatchNorm=BatchNorm)
self.aspp4 = _ASPPModule(
inplanes,
mid_channels,
3,
padding=dilations[3],
dilation=dilations[3],
BatchNorm=BatchNorm)
self.global_avg_pool = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False),
BatchNorm(mid_channels),
nn.ReLU(),
)
self.conv1 = nn.Conv2d(
int(mid_channels * 5), inplanes, 1, bias=False)
self.bn1 = BatchNorm(inplanes)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
self._init_weight()
def forward(self, x):
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x5 = self.global_avg_pool(x)
x5 = F.interpolate(
x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x1, x2, x3, x4, x5), dim=1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
return self.dropout(x)
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
import warnings
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init, constant_init
from mmcv.cnn.bricks.registry import (ATTENTION,
TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from mmcv.cnn.bricks.transformer import build_attention
import math
from mmcv.runner import force_fp32, auto_fp16
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from projects.mmdet3d_plugin.models.utils.bricks import run_time
from .ops.geometric_kernel_attn import GeometricKernelAttentionFunc
@ATTENTION.register_module()
class GeometrySptialCrossAttention(BaseModule):
"""An attention module used in BEVFormer.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_cams (int): The number of cameras
dropout (float): A Dropout layer on `inp_residual`.
Default: 0..
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
deformable_attention: (dict): The config for the deformable attention used in SCA.
"""
def __init__(self,
embed_dims=256,
num_cams=6,
pc_range=None,
dropout=0.1,
init_cfg=None,
batch_first=False,
attention=dict(
type='MSDeformableAttention3D',
embed_dims=256,
num_levels=4),
**kwargs
):
super(GeometrySptialCrossAttention, self).__init__(init_cfg)
self.init_cfg = init_cfg
self.dropout = nn.Dropout(dropout)
self.pc_range = pc_range
self.fp16_enabled = False
self.attention = build_attention(attention)
self.embed_dims = embed_dims
self.num_cams = num_cams
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.batch_first = batch_first
self.init_weight()
def init_weight(self):
"""Default initialization for Parameters of Module."""
xavier_init(self.output_proj, distribution='uniform', bias=0.)
@force_fp32(apply_to=('query', 'key', 'value', 'query_pos', 'reference_points_cam'))
def forward(self,
query,
key,
value,
residual=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
reference_points_cam=None,
bev_mask=None,
level_start_index=None,
flag='encoder',
**kwargs):
"""Forward Function of Detr3DCrossAtten.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`. (B, N, C, H, W)
residual (Tensor): The tensor used for addition, with the
same shape as `x`. Default None. If None, `x` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, 4),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different level. With shape (num_levels, 2),
last dimension represent (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape (num_levels) and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if residual is None:
inp_residual = query
slots = torch.zeros_like(query)
if query_pos is not None:
query = query + query_pos
bs, num_query, _ = query.size()
D = reference_points_cam.size(3)
indexes = []
for i, mask_per_img in enumerate(bev_mask):
index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
indexes.append(index_query_per_img)
max_len = max([len(each) for each in indexes])
# each camera only interacts with its corresponding BEV queries. This step can greatly save GPU memory.
queries_rebatch = query.new_zeros(
[bs, self.num_cams, max_len, self.embed_dims])
reference_points_rebatch = reference_points_cam.new_zeros(
[bs, self.num_cams, max_len, D, 2])
for j in range(bs):
for i, reference_points_per_img in enumerate(reference_points_cam):
index_query_per_img = indexes[i]
queries_rebatch[j, i, :len(
index_query_per_img)] = query[j, index_query_per_img]
reference_points_rebatch[j, i, :len(
index_query_per_img)] = reference_points_per_img[j, index_query_per_img]
num_cams, l, bs, embed_dims = key.shape
key = key.permute(2, 0, 1, 3).reshape(
bs * self.num_cams, l, self.embed_dims)
value = value.permute(2, 0, 1, 3).reshape(
bs * self.num_cams, l, self.embed_dims)
queries = self.attention(query=queries_rebatch.view(bs*self.num_cams, max_len, self.embed_dims), key=key, value=value,
reference_points=reference_points_rebatch.view(bs*self.num_cams, max_len, D, 2), spatial_shapes=spatial_shapes,
level_start_index=level_start_index).view(bs, self.num_cams, max_len, self.embed_dims)
for j in range(bs):
for i, index_query_per_img in enumerate(indexes):
slots[j, index_query_per_img] += queries[j,
i, :len(index_query_per_img)]
count = bev_mask.sum(-1) > 0
count = count.permute(1, 2, 0).sum(-1)
count = torch.clamp(count, min=1.0)
slots = slots / count[..., None]
slots = self.output_proj(slots)
return self.dropout(slots) + inp_residual
@ATTENTION.register_module()
class GeometryKernelAttention(BaseModule):
"""An attention module used in BEVFormer based on Deformable-Detr.
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
kernel_size=(3, 3),
dilation=1,
im2col_step=64,
dropout=0.1,
batch_first=True,
norm_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.batch_first = batch_first
self.output_proj = None
self.fp16_enabled = False
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0
if not _is_power_of_2(dim_per_head):
warnings.warn(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')
self.im2col_step = im2col_step
self.embed_dims = embed_dims
# 4
self.num_levels = num_levels
# 4 num_heads -> num_z_anchors
self.num_heads = num_heads
self.kernel_size = kernel_size
self.num_points = kernel_size[0] * kernel_size[1]
# self.sampling_offsets = nn.Linear(
# embed_dims, num_heads * num_levels * self.num_points * 2)
self.attention_weights = nn.Linear(
embed_dims, num_levels * self.num_points * self.num_heads)
self.value_proj = nn.Linear(embed_dims, embed_dims)
grid_h, grid_w = kernel_size
y = (torch.arange(grid_h) - grid_h // 2) * dilation
x = (torch.arange(grid_w) - grid_w // 2) * dilation
offsets = torch.stack(
torch.meshgrid(x, y)).permute(1, 2, 0).reshape(grid_h * grid_w, 2)
self.register_buffer("grid_offsets", offsets, persistent=False)
self.init_weights()
def init_weights(self):
"""Default initialization for Parameters of Module."""
# constant_init(self.sampling_offsets, 0.)
# thetas = torch.arange(
# self.num_heads,
# dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
# grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
# grid_init = (grid_init /
# grid_init.abs().max(-1, keepdim=True)[0]).view(
# self.num_heads, 1, 1,
# 2).repeat(1, self.num_levels, self.num_points, 1)
# for i in range(self.num_points):
# grid_init[:, :, i, :] *= i + 1
# self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
self._is_init = True
def forward_kernel_multihead_attention(self, value, spatial_shapes, sampling_locations, attention_weights):
# value: (bs, n, d)
"""CPU version of multi-scale deformable attention.
Args:
value (Tensor): The value has shape
(bs, num_keys, dim)
spatial_shapes (Tensor): Spatial shape of
each feature map, has shape (num_levels, 2),
last dimension 2 represent (h, w)
sampling_locations (Tensor): The location of sampling points,
has shape
(bs ,num_queries, num_levels, num_points, 2),
the last dimension 2 represent (x, y).
attention_weights (Tensor): The weight of sampling points used
when calculate the attention, has shape
(bs ,num_queries, num_levels, num_points),
Returns:
Tensor: has shape (bs, num_queries, embed_dims)
"""
# print(value.shape, sampling_locations.shape, attention_weights.shape)
# print(value.shape)
bs, num_keys, num_heads, dim = value.shape
# (bs * num_heads * num_keys, d)
# torch.cuda.synchronize()
# start2 = time.perf_counter()
value = value.transpose(1, 2).contiguous().view(
bs * num_heads * num_keys, dim)
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
with torch.no_grad():
sampling_index = sampling_locations.new_zeros(
(bs, num_queries, num_heads, num_levels, num_points)).to(value.device)
start_index = 0
for level, (H_, W_) in enumerate(spatial_shapes):
# xy or yx?
sampling_locations[:, :, :, level,
:, 0].clamp_(min=0, max=W_-1)
sampling_locations[:, :, :, level,
:, 1].clamp_(min=0, max=H_-1)
sampling_index[:, :, :, level] = start_index + sampling_locations[:, :, :, level, :, 0] \
+ sampling_locations[:, :, :, level, :, 1] * W_
start_index += H_ * W_
# print(start_index)
# head index, (bs, head, num_quries,)
sampling_index = sampling_index.transpose(
1, 2).reshape(bs, num_heads, -1)
sampling_index = sampling_index + \
(torch.arange(num_heads).to(sampling_index)
* num_keys).view(1, num_heads, 1)
# batch index
sampling_index = sampling_index.reshape(
bs, -1) + (torch.arange(bs).to(sampling_index) * num_keys * num_heads).view(bs, 1)
# torch.cuda.synchronize()
# end = time.perf_counter()
# print("geometric kernel attention (index): {:.3f} ms".format(
# (end-start)*1000))
# torch.cuda.synchronize()
# start = time.perf_counter()
sampling_value = value[sampling_index].view(
bs, num_heads, num_queries, num_levels * num_points, dim)
# print(sampling_value.shape)
attention_weights = attention_weights.transpose(1, 2).contiguous().view(
bs, num_heads, num_queries, num_levels * num_points, 1)
# torch.cuda.synchronize()
# end = time.perf_counter()
# print("geometric kernel attention (sample): {:.3f} ms".format(
# (end-start)*1000))
# # (bs*head, num_queries, num_levels * num_points, d) -> (bs, head, num_queries, d)
# torch.cuda.synchronize()
# start = time.perf_counter()
output = (sampling_value *
attention_weights).sum(-2).transpose(1, 2).contiguous()
# torch.cuda.synchronize()
# end = time.perf_counter()
# print("geometric kernel attention (matmul): {:.3f} ms".format(
# (end-start)*1000))
# print('x;', output.shape)
return output.view(bs, num_queries, -1)
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
( bs, num_query, embed_dims).
key (Tensor): The key tensor with shape
`(bs, num_key, embed_dims)`.
value (Tensor): The value tensor with shape
`(bs, num_key, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if value is None:
value = query
if identity is None:
identity = query
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.view(bs, num_value, self.num_heads, -1)
# sampling_offsets = self.sampling_offsets(query).view(
# bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
# bs, num_query, num_heads, num_levels, num_points
# bs, q, 4, 4, K^2
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_levels,
self.num_points)
if reference_points.shape[-1] == 2:
"""
For each BEV query, it owns `num_Z_anchors` in 3D space that having different heights.
After proejcting, each BEV query has `num_Z_anchors` reference points in each 2D image.
For each referent point, we sample `num_points` sampling points.
For `num_Z_anchors` reference points, it has overall `num_points * num_Z_anchors` sampling points.
"""
with torch.no_grad():
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
bs, num_query, num_Z_anchors, xy = reference_points.shape
# from IPython import embed; embed()
# (K,2) -> (1, 1, 1, 1, k, 2) -> (bs, q, nz, l, k, 2)
offsets = self.grid_offsets[None, None, None, None]
# (bs, q, nz, 1, xy) -> (bs, q, z, l, 2)
reference_points = reference_points[:,
:, :, None, :] * offset_normalizer
# from IPython import embed;embed()
# (bs, q, nz, l, k, xy)
sampling_locations = (
reference_points[:, :, :, :, None, :] + offsets).round().long()
# sampling_offsets = sampling_offsets / \
# offset_normalizer[None, None, None, :, None, :]
# (bs, q, 4(z), 4, K^2, 2)
bs, num_query, num_heads, num_levels, num_all_points, xy = sampling_locations.shape
# sampling_offsets = sampling_offsets.view(
# bs, num_query, num_heads, num_levels, num_all_points // num_Z_anchors, num_Z_anchors, xy)
# sampling_locations = reference_points + sampling_offsets
# bs, num_query, num_heads, num_levels, num_points, num_Z_anchors, xy = sampling_locations.shape
# assert num_all_points == num_points * num_Z_anchors
# sampling_locations = sampling_locations.view(
# bs, num_query, num_heads, num_levels, num_all_points, xy)
elif reference_points.shape[-1] == 4:
assert False
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
# sampling_locations.shape: bs, num_query, num_heads, num_levels, num_all_points, 2
# attention_weights.shape: bs, num_query, num_heads, num_levels, num_all_points
# import pdb;pdb.set_trace()
# output = self.forward_kernel_multihead_attention(
# value, spatial_shapes, sampling_locations, attention_weights)
# torch.cuda.synchronize()
# start = time.perf_counter()
output = GeometricKernelAttentionFunc.apply(
value, spatial_shapes, level_start_index, sampling_locations.contiguous(), attention_weights, self.im2col_step
)
# if torch.cuda.is_available() and value.is_cuda:
# if value.dtype == torch.float16:
# MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
# else:
# MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
# output = MultiScaleDeformableAttnFunction.apply(
# value, spatial_shapes, level_start_index, sampling_locations,
# attention_weights, self.im2col_step)
# else:
# output = multi_scale_deformable_attn_pytorch(
# value, spatial_shapes, sampling_locations, attention_weights)
if not self.batch_first:
output = output.permute(1, 0, 2)
# torch.cuda.synchronize()
# end = time.perf_counter()
# print("geometric kernel attention: {:.3f} ms".format((end-start)*1000))
return output
from .function import GeometricKernelAttentionFunc
\ No newline at end of file
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import GeometricKernelAttention as GKA
class GeometricKernelAttentionFunc(Function):
@staticmethod
def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
ctx.im2col_step = im2col_step
output = GKA.geometric_kernel_attn_cuda_forward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
ctx.save_for_backward(value, value_spatial_shapes,
value_level_start_index, sampling_locations, attention_weights)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
grad_value, grad_attn_weight = \
GKA.geometric_kernel_attn_cuda_backward(
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
return grad_value, None, None, None, grad_attn_weight, None
import os
import glob
import torch
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension
from setuptools import find_packages
from setuptools import setup
requirements = ["torch", "torchvision"]
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src")
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
# source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "*.cu"))
sources = main_file
extension = CppExtension
extra_compile_args = {"cxx": []}
define_macros = []
if 1:
# if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
else:
raise NotImplementedError('Cuda is not availabel')
sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]
ext_modules = [
extension(
"GeometricKernelAttention",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]
return ext_modules
setup(
name="GeometricKernelAttention",
version="1.0",
author="Tianheng Cheng",
url="https://github.com/hustvl",
description="PyTorch Wrapper for CUDA Functions of Multi-Scale Geometric Kernel Attention",
packages=find_packages(exclude=("configs", "tests",)),
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
#pragma once
// #include "cpu/ms_deform_attn_cpu.h"
// #ifdef WITH_CUDA
#include "geometric_kernel_attn_cuda.h"
at::Tensor
geometric_kernel_attn_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step)
{
if (value.type().is_cuda())
{
return geometric_kernel_attn_cuda_forward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
}
AT_ERROR("Not implemented on the CPU");
}
std::vector<at::Tensor>
geometric_kernel_attn_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step)
{
if (value.type().is_cuda())
{
return geometric_kernel_attn_cuda_backward(
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
}
AT_ERROR("Not implemented on the CPU");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <THC/THCAtomics.cuh>
#include <vector>
#include "geometric_kernel_attn_cuda_kernel.cuh"
at::Tensor geometric_kernel_attn_cuda_forward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "multiscale_kernel_attn_forward_cuda", ([&] {
multiscale_kernel_attn_forward_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<int64_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads*channels});
return output;
}
std::vector<at::Tensor> geometric_kernel_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const at::Tensor &grad_output,
const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
auto grad_value = at::zeros_like(value);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "multiscale_kernel_attn_backward_cuda", ([&] {
multiscale_kernel_attn_backward_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<int64_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
}));
}
return {
grad_value, grad_attn_weight
};
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment