# Copyright (c) Phigent Robotics. All rights reserved.
import torch.utils.checkpoint as checkpoint
from torch import nn
import torch
from mmcv.cnn.bricks.conv_module import ConvModule
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet3d.models import BACKBONES
@BACKBONES.register_module()
class CustomResNet(nn.Module):
def __init__(
self,
numC_input,
num_layer=[2, 2, 2],
num_channels=None,
stride=[2, 2, 2],
backbone_output_ids=None,
norm_cfg=dict(type='BN'),
with_cp=False,
block_type='Basic',
):
super(CustomResNet, self).__init__()
# build backbone
assert len(num_layer) == len(stride)
num_channels = [numC_input*2**(i+1) for i in range(len(num_layer))] \
if num_channels is None else num_channels
self.backbone_output_ids = range(len(num_layer)) \
if backbone_output_ids is None else backbone_output_ids
layers = []
if block_type == 'BottleNeck':
curr_numC = numC_input
for i in range(len(num_layer)):
# downsample the input in the first block of each stage
layer = [Bottleneck(inplanes=curr_numC, planes=num_channels[i]//4, stride=stride[i],
downsample=nn.Conv2d(curr_numC, num_channels[i], 3, stride[i], 1),
norm_cfg=norm_cfg)]
curr_numC = num_channels[i]
layer.extend([Bottleneck(inplanes=curr_numC, planes=num_channels[i]//4, stride=1,
downsample=None, norm_cfg=norm_cfg) for _ in range(num_layer[i] - 1)])
layers.append(nn.Sequential(*layer))
elif block_type == 'Basic':
curr_numC = numC_input
for i in range(len(num_layer)):
# downsample the input in the first block of each stage
layer = [BasicBlock(inplanes=curr_numC, planes=num_channels[i], stride=stride[i],
downsample=nn.Conv2d(curr_numC, num_channels[i], 3, stride[i], 1),
norm_cfg=norm_cfg)]
curr_numC = num_channels[i]
layer.extend([BasicBlock(inplanes=curr_numC, planes=num_channels[i], stride=1,
downsample=None, norm_cfg=norm_cfg) for _ in range(num_layer[i] - 1)])
layers.append(nn.Sequential(*layer))
else:
assert False, f"unsupported block_type '{block_type}'"
self.layers = nn.Sequential(*layers)
self.with_cp = with_cp
#@torch.compile
def forward(self, x):
"""
Args:
x: (B, C=64, Dy, Dx)
Returns:
feats: List[
(B, 2*C, Dy/2, Dx/2),
(B, 4*C, Dy/4, Dx/4),
(B, 8*C, Dy/8, Dx/8),
]
"""
feats = []
x_tmp = x
for lid, layer in enumerate(self.layers):
if self.with_cp:
x_tmp = checkpoint.checkpoint(layer, x_tmp)
else:
x_tmp = layer(x_tmp)
if lid in self.backbone_output_ids:
feats.append(x_tmp)
return feats
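# A minimal shape-check sketch (ours, not part of the original file): it assumes
# CustomResNet's default arguments and mirrors the shape flow documented in
# CustomResNet.forward; the helper name below is hypothetical.
def _custom_resnet_shape_demo():
    net = CustomResNet(numC_input=64)
    feats = net(torch.randn(1, 64, 128, 128))
    # expected: [(1, 128, 64, 64), (1, 256, 32, 32), (1, 512, 16, 16)]
    return [tuple(f.shape) for f in feats]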
class BasicBlock3D(nn.Module):
def __init__(self,
channels_in, channels_out, stride=1, downsample=None):
super(BasicBlock3D, self).__init__()
self.conv1 = ConvModule(
channels_in,
channels_out,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d', ),
act_cfg=dict(type='ReLU',inplace=True))
self.conv2 = ConvModule(
channels_out,
channels_out,
kernel_size=3,
stride=1,
padding=1,
bias=False,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d', ),
act_cfg=None)
self.downsample = downsample
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
if self.downsample is not None:
identity = self.downsample(x)
else:
identity = x
x = self.conv1(x)
x = self.conv2(x)
x = x + identity
return self.relu(x)
@BACKBONES.register_module()
class CustomResNet3D(nn.Module):
def __init__(
self,
numC_input,
num_layer=[2, 2, 2],
num_channels=None,
stride=[2, 2, 2],
backbone_output_ids=None,
with_cp=False,
):
super(CustomResNet3D, self).__init__()
# build backbone
assert len(num_layer) == len(stride)
num_channels = [numC_input * 2 ** (i + 1) for i in range(len(num_layer))] \
if num_channels is None else num_channels
self.backbone_output_ids = range(len(num_layer)) \
if backbone_output_ids is None else backbone_output_ids
layers = []
curr_numC = numC_input
for i in range(len(num_layer)):
layer = [
BasicBlock3D(
curr_numC,
num_channels[i],
stride=stride[i],
downsample=ConvModule(
curr_numC,
num_channels[i],
kernel_size=3,
stride=stride[i],
padding=1,
bias=False,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d', ),
act_cfg=None))
]
curr_numC = num_channels[i]
layer.extend([
BasicBlock3D(curr_numC, curr_numC)
for _ in range(num_layer[i] - 1)
])
layers.append(nn.Sequential(*layer))
self.layers = nn.Sequential(*layers)
self.with_cp = with_cp
def forward(self, x):
"""
Args:
x: (B, C, Dz, Dy, Dx)
Returns:
feats: List[
(B, C, Dz, Dy, Dx),
(B, 2C, Dz/2, Dy/2, Dx/2),
(B, 4C, Dz/4, Dy/4, Dx/4),
]
"""
feats = []
x_tmp = x
for lid, layer in enumerate(self.layers):
if self.with_cp:
x_tmp = checkpoint.checkpoint(layer, x_tmp)
else:
x_tmp = layer(x_tmp)
if lid in self.backbone_output_ids:
feats.append(x_tmp)
return feats
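# A minimal shape-check sketch (ours, not part of the original file): with the
# default stride=[2, 2, 2] every stage halves (Dz, Dy, Dx) while doubling the
# channels; the shape list in the docstring above corresponds to a non-default
# config whose first stage keeps the input resolution.
def _custom_resnet3d_shape_demo():
    net = CustomResNet3D(numC_input=32)
    feats = net(torch.randn(1, 32, 16, 64, 64))
    # expected: [(1, 64, 8, 32, 32), (1, 128, 4, 16, 16), (1, 256, 2, 8, 8)]
    return [tuple(f.shape) for f in feats]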
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer, trunc_normal_init, build_conv_layer
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import constant_init
from mmcv.runner import _load_checkpoint
from mmcv.runner.base_module import BaseModule, ModuleList
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
import torch.utils.checkpoint as checkpoint
from mmseg.ops import resize
from mmdet3d.utils import get_root_logger
from mmdet3d.models.builder import BACKBONES
from mmcv.cnn.bricks.registry import ATTENTION
from torch.nn.modules.utils import _pair as to_2tuple
from collections import OrderedDict
def swin_convert(ckpt):
new_ckpt = OrderedDict()
def correct_unfold_reduction_order(x):
out_channel, in_channel = x.shape
x = x.reshape(out_channel, 4, in_channel // 4)
x = x[:, [0, 2, 1, 3], :].transpose(1,
2).reshape(out_channel, in_channel)
return x
def correct_unfold_norm_order(x):
in_channel = x.shape[0]
x = x.reshape(4, in_channel // 4)
x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
return x
for k, v in ckpt.items():
if k.startswith('head'):
continue
elif k.startswith('layers'):
new_v = v
if 'attn.' in k:
new_k = k.replace('attn.', 'attn.w_msa.')
elif 'mlp.' in k:
if 'mlp.fc1.' in k:
new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
elif 'mlp.fc2.' in k:
new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
else:
new_k = k.replace('mlp.', 'ffn.')
elif 'downsample' in k:
new_k = k
if 'reduction.' in k:
new_v = correct_unfold_reduction_order(v)
elif 'norm.' in k:
new_v = correct_unfold_norm_order(v)
else:
new_k = k
new_k = new_k.replace('layers', 'stages', 1)
elif k.startswith('patch_embed'):
new_v = v
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
else:
new_v = v
new_k = k
new_ckpt[new_k] = new_v
return new_ckpt
# Modified from Pytorch-Image-Models
class PatchEmbed(BaseModule):
"""Image to Patch Embedding V2.
We use a conv layer to implement PatchEmbed.
Args:
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_type (dict, optional): The config dict for conv layers type
selection. Default: None.
kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int): The slide stride of embedding conv.
Default: None (Default to be equal with kernel_size).
padding (int): The padding length of embedding conv. Default: 0.
dilation (int): The dilation rate of embedding conv. Default: 1.
pad_to_patch_size (bool, optional): Whether to pad feature map shape
to multiple patch size. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=768,
conv_type=None,
kernel_size=16,
stride=16,
padding=0,
dilation=1,
pad_to_patch_size=True,
norm_cfg=None,
init_cfg=None):
super(PatchEmbed, self).__init__()
self.embed_dims = embed_dims
self.init_cfg = init_cfg
if stride is None:
stride = kernel_size
self.pad_to_patch_size = pad_to_patch_size
# The default setting of patch size is equal to kernel size.
patch_size = kernel_size
if isinstance(patch_size, int):
patch_size = to_2tuple(patch_size)
elif isinstance(patch_size, tuple):
if len(patch_size) == 1:
patch_size = to_2tuple(patch_size[0])
assert len(patch_size) == 2, \
f'The size of patch should have length 1 or 2, ' \
f'but got {len(patch_size)}'
self.patch_size = patch_size
# Use conv layer to embed
conv_type = conv_type or 'Conv2d'
self.projection = build_conv_layer(
dict(type=conv_type),
in_channels=in_channels,
out_channels=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
def forward(self, x):
H, W = x.shape[2], x.shape[3]
# TODO: Process overlapping op
if self.pad_to_patch_size:
# Modify H, W to multiple of patch size.
if H % self.patch_size[0] != 0:
x = F.pad(
x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
if W % self.patch_size[1] != 0:
x = F.pad(
x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
x = self.projection(x)
self.DH, self.DW = x.shape[2], x.shape[3]
x = x.flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x
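# A small sketch (ours): with kernel_size=stride=4 and pad_to_patch_size=True,
# PatchEmbed pads H and W up to multiples of the patch size, projects, then
# flattens to (B, num_patches, embed_dims); DH/DW record the patch-grid shape.
def _patch_embed_shape_demo():
    pe = PatchEmbed(in_channels=3, embed_dims=96, kernel_size=4, stride=4)
    tokens = pe(torch.randn(1, 3, 30, 30))  # padded to 32x32 -> 8x8 patches
    # expected: tokens.shape == (1, 64, 96), (pe.DH, pe.DW) == (8, 8)
    return tokens.shape, (pe.DH, pe.DW)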
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer use nn.Unfold to group feature map by kernel_size, and use norm
and linear layer to embed grouped feature map.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.
stride (int | tuple): the stride of the sliding length in the
unfold layer. Defaults: 2. (Default to be equal with kernel_size).
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Defaults: None.
"""
def __init__(self,
in_channels,
out_channels,
stride=2,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.sampler = nn.Unfold(
kernel_size=stride, dilation=1, padding=0, stride=stride)
sample_dim = stride**2 * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, hw_shape):
"""
x: x.shape -> [B, H*W, C]
hw_shape: (H, W)
"""
B, L, C = x.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
# stride is fixed to be equal to kernel_size.
if (H % self.stride != 0) or (W % self.stride != 0):
x = F.pad(x, (0, W % self.stride, 0, H % self.stride))
# Use nn.Unfold to merge patch. About 25% faster than original method,
# but need to modify pretrained model for compatibility
x = self.sampler(x) # B, 4*C, H/2*W/2
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
down_hw_shape = (H + 1) // 2, (W + 1) // 2
return x, down_hw_shape
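# A small sketch (ours): PatchMerging groups each 2x2 neighbourhood with
# nn.Unfold, so (B, H*W, C) becomes (B, H/2*W/2, 4*C) before the linear
# reduction to out_channels.
def _patch_merging_shape_demo():
    pm = PatchMerging(in_channels=4, out_channels=8)
    x = torch.randn(1, 4 * 4, 4)                # H = W = 4, C = 4
    out, down_hw_shape = pm(x, hw_shape=(4, 4))
    # expected: out.shape == (1, 4, 8), down_hw_shape == (2, 2)
    return out.shape, down_hw_shape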
@ATTENTION.register_module()
class WindowMSA(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Args:
embed_dims (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0
init_cfg (dict | None, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
init_cfg=None):
super().__init__()
self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
self.init_cfg = init_cfg
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# About 2x faster than original impl
Wh, Ww = self.window_size
rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
rel_position_index = rel_index_coords + rel_index_coords.T
rel_position_index = rel_position_index.flip(1).contiguous()
self.register_buffer('relative_position_index', rel_position_index)
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
self.softmax = nn.Softmax(dim=-1)
def init_weights(self):
trunc_normal_init(self.relative_position_bias_table, std=0.02)
def forward(self, x, mask=None):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C)
mask (tensor | None, Optional): mask with shape of (num_windows,
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@staticmethod
def double_step_seq(step1, len1, step2, len2):
seq1 = torch.arange(0, step1 * len1, step1)
seq2 = torch.arange(0, step2 * len2, step2)
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
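# A small sketch (ours): for a 2x2 window, double_step_seq plus the transpose
# and flip above yields the (Wh*Ww, Wh*Ww) index table used to gather entries
# from relative_position_bias_table in forward().
def _relative_position_index_demo():
    attn = WindowMSA(embed_dims=8, num_heads=2, window_size=(2, 2))
    # expected: shape (4, 4), values in [0, (2*2-1)*(2*2-1) - 1] = [0, 8]
    return attn.relative_position_index.shape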
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
"""Shift Window Multihead Self-Attention Module.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window.
shift_size (int, optional): The shift step of each window towards
right-bottom. If zero, act as regular window-msa. Defaults to 0.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Defaults: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Defaults: 0.
proj_drop_rate (float, optional): Dropout ratio of output.
Defaults: 0.
dropout_layer (dict, optional): The dropout_layer used before output.
Defaults: dict(type='DropPath', drop_prob=0.).
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
shift_size=0,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0,
proj_drop_rate=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
init_cfg=None):
super().__init__(init_cfg)
self.window_size = window_size
self.shift_size = shift_size
assert 0 <= self.shift_size < self.window_size
self.w_msa = WindowMSA(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=to_2tuple(window_size),
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=proj_drop_rate,
init_cfg=None)
self.drop = build_dropout(dropout_layer)
def forward(self, query, hw_shape):
B, L, C = query.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
query = query.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
H_pad, W_pad = query.shape[1], query.shape[2]
# cyclic shift
if self.shift_size > 0:
shifted_query = torch.roll(
query,
shifts=(-self.shift_size, -self.shift_size),
dims=(1, 2))
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, H_pad, W_pad, 1),
device=query.device) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = self.window_partition(img_mask)
mask_windows = mask_windows.view(
-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))
else:
shifted_query = query
attn_mask = None
# nW*B, window_size, window_size, C
query_windows = self.window_partition(shifted_query)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, self.window_size**2, C)
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size,
self.window_size, C)
# B H' W' C
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = self.drop(x)
return x
def window_reverse(self, windows, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
window_size = self.window_size
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
def window_partition(self, x):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
window_size = self.window_size
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
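# A small consistency sketch (ours): window_partition followed by
# window_reverse is the identity when H and W are multiples of window_size.
def _window_partition_roundtrip_demo():
    msa = ShiftWindowMSA(embed_dims=8, num_heads=2, window_size=4)
    x = torch.randn(2, 8, 8, 8)                 # (B, H, W, C)
    windows = msa.window_partition(x)           # (B*4, 4, 4, C)
    x_back = msa.window_reverse(windows, 8, 8)  # (B, H, W, C)
    return torch.allclose(x, x_back)            # expected: True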
class SwinBlock(BaseModule):
""""
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
window_size (int, optional): The local window scale. Default: 7.
shift (bool): whether to shift window or not. Default False.
qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN').
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
window_size=7,
shift=False,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(SwinBlock, self).__init__()
self.init_cfg = init_cfg
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ShiftWindowMSA(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=window_size,
shift_size=window_size // 2 if shift else 0,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
init_cfg=None)
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=2,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=True,
init_cfg=None)
def forward(self, x, hw_shape):
identity = x
x = self.norm1(x)
x = self.attn(x, hw_shape)
x = x + identity
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
class SwinBlockSequence(BaseModule):
"""Implements one stage in Swin Transformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
depth (int): The number of blocks in this stage.
window_size (int): The local window scale. Default: 7.
qkv_bias (bool): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
downsample (BaseModule | None, optional): The downsample operation
module. Default: None.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of normalization.
Default: dict(type='LN').
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
depth,
window_size=7,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
downsample=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None,
with_cp=True):
super().__init__()
self.init_cfg = init_cfg
drop_path_rate = drop_path_rate if isinstance(
drop_path_rate,
list) else [deepcopy(drop_path_rate) for _ in range(depth)]
self.blocks = ModuleList()
for i in range(depth):
block = SwinBlock(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=feedforward_channels,
window_size=window_size,
shift=False if i % 2 == 0 else True,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate[i],
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=None)
self.blocks.append(block)
self.downsample = downsample
self.with_cp = with_cp
def forward(self, x, hw_shape):
for block in self.blocks:
if self.with_cp:
x = checkpoint.checkpoint(block, x, hw_shape)
else:
x = block(x, hw_shape)
if self.downsample:
x_down, down_hw_shape = self.downsample(x, hw_shape)
return x_down, down_hw_shape, x, hw_shape
else:
return x, hw_shape, x, hw_shape
@BACKBONES.register_module()
class SwinTransformer(BaseModule):
""" Swin Transformer
A PyTorch implement of : `Swin Transformer:
Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/abs/2103.14030
Inspiration from
https://github.com/microsoft/Swin-Transformer
Args:
pretrain_img_size (int | tuple[int]): The size of input image when
pretrain. Defaults: 224.
in_channels (int): The num of input channels.
Defaults: 3.
embed_dims (int): The feature dimension. Default: 96.
patch_size (int | tuple[int]): Patch size. Default: 4.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
Default: 4.
depths (tuple[int]): Depths of each Swin Transformer stage.
Default: (2, 2, 6, 2).
num_heads (tuple[int]): Parallel attention heads of each Swin
Transformer stage. Default: (3, 6, 12, 24).
strides (tuple[int]): The patch merging or patch embedding stride of
each Swin Transformer stage. (In swin, we set kernel size equal to
stride.) Default: (4, 2, 2, 2).
out_indices (tuple[int]): Output from which stages.
Default: (0, 1, 2, 3).
qkv_bias (bool, optional): If True, add a learnable bias to query, key,
value. Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
patch_norm (bool): If add a norm layer for patch embed and patch
merging. Default: True.
drop_rate (float): Dropout rate. Defaults: 0.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults: False.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer at
output of backbone. Defaults: dict(type='LN').
pretrain_style (str): Choose to use official or mmcls pretrain weights.
Default: official.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
pretrain_img_size=224,
in_channels=3,
embed_dims=96,
patch_size=4,
window_size=7,
mlp_ratio=4,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
strides=(4, 2, 2, 2),
out_indices=(0, 1, 2, 3),
qkv_bias=True,
qk_scale=None,
patch_norm=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
pretrain_style='official',
pretrained=None,
init_cfg=None,
with_cp=True,
return_stereo_feat=False,
output_missing_index_as_none=False,
frozen_stages=-1):
super(SwinTransformer, self).__init__()
if isinstance(pretrain_img_size, int):
pretrain_img_size = to_2tuple(pretrain_img_size)
elif isinstance(pretrain_img_size, tuple):
if len(pretrain_img_size) == 1:
pretrain_img_size = to_2tuple(pretrain_img_size[0])
assert len(pretrain_img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(pretrain_img_size)}'
assert pretrain_style in ['official', 'mmcls'], \
'We only support loading official ckpt and mmcls ckpt.'
if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
else:
raise TypeError('pretrained must be a str or None')
num_layers = len(depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.pretrain_style = pretrain_style
self.pretrained = pretrained
self.init_cfg = init_cfg
self.frozen_stages = frozen_stages
assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=strides[0],
pad_to_patch_size=True,
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
if self.use_abs_pos_embed:
patch_row = pretrain_img_size[0] // patch_size
patch_col = pretrain_img_size[1] // patch_size
num_patches = patch_row * patch_col
self.absolute_pos_embed = nn.Parameter(
torch.zeros((1, num_patches, embed_dims)))
self.drop_after_pos = nn.Dropout(p=drop_rate)
# stochastic depth
total_depth = sum(depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
self.stages = ModuleList()
in_channels = embed_dims
for i in range(num_layers):
if i < num_layers - 1:
downsample = PatchMerging(
in_channels=in_channels,
out_channels=2 * in_channels,
stride=strides[i + 1],
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
else:
downsample = None
stage = SwinBlockSequence(
embed_dims=in_channels,
num_heads=num_heads[i],
feedforward_channels=mlp_ratio * in_channels,
depth=depths[i],
window_size=window_size,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[:depths[i]],
downsample=downsample,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=None,
with_cp=with_cp)
self.stages.append(stage)
dpr = dpr[depths[i]:]
if downsample:
in_channels = downsample.out_channels
self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
# Add a norm layer for each output
for i in out_indices:
layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
layer_name = f'norm{i}'
self.add_module(layer_name, layer)
self.output_missing_index_as_none = output_missing_index_as_none
self._freeze_stages()
self.return_stereo_feat = return_stereo_feat
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.use_abs_pos_embed:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.drop_after_pos.eval()
for i in range(0, self.frozen_stages - 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self):
if self.pretrained is None:
super().init_weights()
if self.use_abs_pos_embed:
trunc_normal_init(self.absolute_pos_embed, std=0.02)
for m in self.modules():
if isinstance(m, Linear):
trunc_normal_init(m.weight, std=.02)
if m.bias is not None:
constant_init(m.bias, 0)
elif isinstance(m, LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
elif isinstance(self.pretrained, str):
logger = get_root_logger()
ckpt = _load_checkpoint(
self.pretrained, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
state_dict = ckpt['state_dict']
elif 'model' in ckpt:
state_dict = ckpt['model']
else:
state_dict = ckpt
if self.pretrain_style == 'official':
state_dict = swin_convert(state_dict)
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
# if list(state_dict.keys())[0].startswith('backbone.'):
# state_dict = {k[9:]: v for k, v in state_dict.items()}
# reshape absolute position embedding
if state_dict.get('absolute_pos_embed') is not None:
absolute_pos_embed = state_dict['absolute_pos_embed']
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = self.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H * W:
logger.warning('Error in loading absolute_pos_embed, pass')
else:
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
# interpolate position bias table if needed
relative_position_bias_table_keys = [
k for k in state_dict.keys()
if 'relative_position_bias_table' in k
]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
table_current = self.state_dict()[table_key]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
logger.warning(f'Error in loading {table_key}, pass')
else:
if L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = resize(
table_pretrained.permute(1, 0).reshape(
1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view(
nH2, L2).permute(1, 0).contiguous()
# load state_dict
self.load_state_dict(state_dict, False)
def forward(self, x):
x = self.patch_embed(x)
hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
if i == 0 and self.return_stereo_feat:
out = out.view(-1, *out_hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(out)
out = out.view(-1, *out_hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
elif self.output_missing_index_as_none:
outs.append(None)
return outs
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
from .bev_centerpoint_head import BEV_CenterHead, Centerness_Head
from .bev_occ_head import BEVOCCHead2D, BEVOCCHead3D, BEVOCCHead2D_V2
__all__ = ['Centerness_Head', 'BEV_CenterHead', 'BEVOCCHead2D', 'BEVOCCHead3D', 'BEVOCCHead2D_V2']
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule
from torch import nn
from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
xywhr2xyxyr)
from ...core.post_processing import nms_bev
from mmdet3d.models import builder
from mmdet3d.models.utils import clip_sigmoid
from mmdet.core import build_bbox_coder, multi_apply, reduce_mean
from mmdet3d.models.builder import HEADS, build_loss
@HEADS.register_module(force=True)
class SeparateHead(BaseModule):
"""SeparateHead for CenterHead.
Args:
in_channels (int): Input channels for conv_layer.
heads (dict): Conv information.
head_conv (int, optional): Output channels.
Default: 64.
final_kernel (int, optional): Kernel size for the last conv layer.
Default: 1.
init_bias (float, optional): Initial bias. Default: -2.19.
conv_cfg (dict, optional): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict, optional): Config of norm layer.
Default: dict(type='BN2d').
bias (str, optional): Type of bias. Default: 'auto'.
"""
def __init__(self,
in_channels,
heads,
head_conv=64,
final_kernel=1,
init_bias=-2.19,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
init_cfg=None,
**kwargs):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(SeparateHead, self).__init__(init_cfg=init_cfg)
self.heads = heads
self.init_bias = init_bias
for head in self.heads:
# output channels and number of conv layers for this head.
classes, num_conv = self.heads[head]
conv_layers = []
c_in = in_channels
for i in range(num_conv - 1):
conv_layers.append(
ConvModule(
c_in,
head_conv,
kernel_size=final_kernel,
stride=1,
padding=final_kernel // 2,
bias=bias,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg))
c_in = head_conv
conv_layers.append(
build_conv_layer(
conv_cfg,
head_conv,
classes,
kernel_size=final_kernel,
stride=1,
padding=final_kernel // 2,
bias=True))
conv_layers = nn.Sequential(*conv_layers)
self.__setattr__(head, conv_layers)
if init_cfg is None:
self.init_cfg = dict(type='Kaiming', layer='Conv2d')
def init_weights(self):
"""Initialize weights."""
super().init_weights()
for head in self.heads:
if head == 'heatmap':
self.__getattr__(head)[-1].bias.data.fill_(self.init_bias)
def forward(self, x):
"""Forward function for SepHead.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
dict[str: torch.Tensor]: contains the following keys:
-reg (torch.Tensor): 2D regression value with the
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the
shape of [B, 1, H, W].
-dim (torch.Tensor): Size value with the shape
of [B, 3, H, W].
-rot (torch.Tensor): Rotation value with the
shape of [B, 2, H, W].
-vel (torch.Tensor): Velocity value with the
shape of [B, 2, H, W].
-heatmap (torch.Tensor): Heatmap with the shape of
[B, N, H, W].
"""
ret_dict = dict()
for head in self.heads:
ret_dict[head] = self.__getattr__(head)(x)
return ret_dict
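# A small sketch (ours): each key in `heads` gets its own conv branch, so the
# forward pass returns one tensor per head with the requested channel count.
def _separate_head_demo():
    head = SeparateHead(
        in_channels=64,
        heads=dict(reg=(2, 2), height=(1, 2), heatmap=(3, 2)),
        head_conv=64,
        final_kernel=3)
    out = head(torch.randn(1, 64, 16, 16))
    # expected: reg (1, 2, 16, 16), height (1, 1, 16, 16), heatmap (1, 3, 16, 16)
    return {k: tuple(v.shape) for k, v in out.items()}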
@HEADS.register_module(force=True)
class DCNSeparateHead(BaseModule):
r"""DCNSeparateHead for CenterHead.
.. code-block:: none
/-----> DCN for heatmap task -----> heatmap task.
feature
\-----> DCN for regression tasks -----> regression tasks
Args:
in_channels (int): Input channels for conv_layer.
num_cls (int): Number of classes.
heads (dict): Conv information.
dcn_config (dict): Config of dcn layer.
head_conv (int, optional): Output channels.
Default: 64.
final_kernel (int, optional): Kernel size for the last conv
layer. Default: 1.
init_bias (float, optional): Initial bias. Default: -2.19.
conv_cfg (dict, optional): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict, optional): Config of norm layer.
Default: dict(type='BN2d').
bias (str, optional): Type of bias. Default: 'auto'.
""" # noqa: W605
def __init__(self,
in_channels,
num_cls,
heads,
dcn_config,
head_conv=64,
final_kernel=1,
init_bias=-2.19,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
init_cfg=None,
**kwargs):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(DCNSeparateHead, self).__init__(init_cfg=init_cfg)
if 'heatmap' in heads:
heads.pop('heatmap')
# feature adaptation with dcn
# use separate features for classification / regression
self.feature_adapt_cls = build_conv_layer(dcn_config)
self.feature_adapt_reg = build_conv_layer(dcn_config)
# heatmap prediction head
cls_head = [
ConvModule(
in_channels,
head_conv,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
bias=bias,
norm_cfg=norm_cfg),
build_conv_layer(
conv_cfg,
head_conv,
num_cls,
kernel_size=3,
stride=1,
padding=1,
bias=bias)
]
self.cls_head = nn.Sequential(*cls_head)
self.init_bias = init_bias
# other regression target
self.task_head = SeparateHead(
in_channels,
heads,
head_conv=head_conv,
final_kernel=final_kernel,
bias=bias)
if init_cfg is None:
self.init_cfg = dict(type='Kaiming', layer='Conv2d')
def init_weights(self):
"""Initialize weights."""
super().init_weights()
self.cls_head[-1].bias.data.fill_(self.init_bias)
def forward(self, x):
"""Forward function for DCNSepHead.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
dict[str: torch.Tensor]: contains the following keys:
-reg (torch.Tensor): 2D regression value with the
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the
shape of [B, 1, H, W].
-dim (torch.Tensor): Size value with the shape
of [B, 3, H, W].
-rot (torch.Tensor): Rotation value with the
shape of [B, 2, H, W].
-vel (torch.Tensor): Velocity value with the
shape of [B, 2, H, W].
-heatmap (torch.Tensor): Heatmap with the shape of
[B, N, H, W].
"""
center_feat = self.feature_adapt_cls(x)
reg_feat = self.feature_adapt_reg(x)
cls_score = self.cls_head(center_feat)
ret = self.task_head(reg_feat)
ret['heatmap'] = cls_score
return ret
@HEADS.register_module()
class BEV_CenterHead(BaseModule):
"""CenterHead for CenterPoint.
Args:
in_channels (list[int] | int, optional): Channels of the input
feature map. Default: [128].
tasks (list[dict], optional): Task information including class number
and class names. Default: None.
train_cfg (dict, optional): Train-time configs. Default: None.
test_cfg (dict, optional): Test-time configs. Default: None.
bbox_coder (dict, optional): Bbox coder configs. Default: None.
common_heads (dict, optional): Conv information for common heads.
Default: dict().
loss_cls (dict, optional): Config of classification loss function.
Default: dict(type='GaussianFocalLoss', reduction='mean').
loss_bbox (dict, optional): Config of regression loss function.
Default: dict(type='L1Loss', reduction='none').
separate_head (dict, optional): Config of separate head. Default: dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3)
share_conv_channel (int, optional): Output channels for share_conv
layer. Default: 64.
num_heatmap_convs (int, optional): Number of conv layers for heatmap
conv layer. Default: 2.
conv_cfg (dict, optional): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict, optional): Config of norm layer.
Default: dict(type='BN2d').
bias (str, optional): Type of bias. Default: 'auto'.
"""
def __init__(self,
in_channels=[128],
tasks=None,
train_cfg=None,
test_cfg=None,
bbox_coder=None,
common_heads=dict(),
loss_cls=dict(type='GaussianFocalLoss', reduction='mean'),
loss_bbox=dict(
type='L1Loss', reduction='none', loss_weight=0.25),
separate_head=dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3),
share_conv_channel=64,
num_heatmap_convs=2,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
norm_bbox=True,
init_cfg=None,
task_specific=True):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(BEV_CenterHead, self).__init__(init_cfg=init_cfg)
num_classes = [len(t['class_names']) for t in tasks] # number of classes handled by each task (SeparateHead).
self.class_names = [t['class_names'] for t in tasks] # class names handled by each task (SeparateHead).
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.in_channels = in_channels
self.num_classes = num_classes
self.norm_bbox = norm_bbox
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes]
self.fp16_enabled = False
# a shared convolution
self.shared_conv = ConvModule(
in_channels,
share_conv_channel,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=bias)
# build one head per task.
self.task_heads = nn.ModuleList()
for num_cls in num_classes:
# common_heads = dict(
# reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
heads = copy.deepcopy(common_heads)
heads.update(dict(heatmap=(num_cls, num_heatmap_convs)))
separate_head.update(
in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
self.task_heads.append(builder.build_head(separate_head))
self.with_velocity = 'vel' in common_heads.keys()
self.task_specific = task_specific
def forward_single(self, x):
"""Forward function for CenterPoint.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
list[dict]: Output results for tasks.
"""
ret_dicts = []
x = self.shared_conv(x) # (B, C'=share_conv_channel, H, W)
# run each task-specific head.
for task in self.task_heads:
ret_dicts.append(task(x))
# ret_dicts: [dict0, dict1, ...] len = number of SeparateHeads
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
return ret_dicts
def forward(self, feats):
"""Forward pass.
Args:
feats (list[torch.Tensor]): Multi-level features, e.g.,
features produced by FPN.
Returns:
results: Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
), len = number of SeparateHeads, each predicting its assigned classes.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
"""
return multi_apply(self.forward_single, feats)
def _gather_feat(self, feat, ind, mask=None):
"""Gather feature map.
Given feature map and index, return indexed feature map.
Args:
feat (torch.tensor): Feature map with the shape of [B, H*W, 10].
ind (torch.Tensor): Index of the ground truth boxes with the
shape of [B, max_obj].
mask (torch.Tensor, optional): Mask of the feature map with the
shape of [B, max_obj]. Default: None.
Returns:
torch.Tensor: Feature map after gathering with the shape
of [B, max_obj, 10].
"""
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
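# A tiny illustration in comments (ours) of what _gather_feat does:
#   feat = torch.arange(12.).view(1, 6, 2)   # (B=1, H*W=6, dim=2)
#   ind = torch.tensor([[0, 3, 5]])          # (B, max_obj=3)
#   self._gather_feat(feat, ind)             # rows 0, 3 and 5 -> shape (1, 3, 2)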
def get_targets(self, gt_bboxes_3d, gt_labels_3d):
"""Generate targets.
How each output is transformed:
Each nested list is transposed so that all same-index elements in
each sub-list (1, ..., N) become the new sub-lists.
[ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
The new transposed nested list is converted into a list of N
tensors generated by concatenating tensors in the new sub-lists.
[ tensor0, tensor1, tensor2, ... ]
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes. # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
gt_labels_3d (list[torch.Tensor]): Labels of boxes. # List[(N_gt0, ), (N_gt1, ), ...]
Returns:
tuple[list[torch.Tensor]]: (
heatmaps: List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
anno_boxes:
inds:
masks:
)
"""
heatmaps, anno_boxes, inds, masks = multi_apply(
self.get_targets_single, gt_bboxes_3d, gt_labels_3d)
# heatmaps: # Tuple(List[(N_cls0, H, W), (N_cls1, H, W), ...], ...) len = batch_size
# anno_boxes: # Tuple(List[(max_objs, 10), (max_objs, 10), ...], ...) len = batch_size
# inds: # Tuple(List[(max_objs, ), (max_objs, ), ...], ...)
# masks: # Tuple(List[(max_objs, ), (max_objs, ), ...], ...)
# Transpose heatmaps
# List[List[(N_cls0, H, W), (N_cls0, H, W), ...], List[(N_cls1, H, W), (N_cls1, H, W), ...], ...] len = num of SeparateHead
heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps] # List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
# Transpose anno_boxes
anno_boxes = list(map(list, zip(*anno_boxes)))
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes] # List[(B, max_objs, 10), (B, max_objs, 10), ...] len = num of SeparateHead
# Transpose inds
inds = list(map(list, zip(*inds)))
inds = [torch.stack(inds_) for inds_ in inds] # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
# Transpose masks
masks = list(map(list, zip(*masks)))
masks = [torch.stack(masks_) for masks_ in masks] # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
return heatmaps, anno_boxes, inds, masks
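# A tiny illustration in comments (ours) of the transpose-then-stack pattern above:
#   per_sample = [[hm_a0, hm_a1], [hm_b0, hm_b1]]   # len = batch_size, inner len = num tasks
#   per_task = list(map(list, zip(*per_sample)))    # [[hm_a0, hm_b0], [hm_a1, hm_b1]]
#   heatmaps = [torch.stack(h) for h in per_task]   # one (B, N_cls, H, W) tensor per task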
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
"""Generate training targets for a single sample.
Args:
gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes. # (N_gt, 7/9)
gt_labels_3d (torch.Tensor): Labels of boxes. # (N_gt, )
Returns:
tuple[list[torch.Tensor]]: Tuple of target including
the following results in order.
- heatmaps: list[torch.Tensor]: Heatmap scores. # List[(N_cls0, H, W), (N_cls1, H, W), ...]
len = num of tasks
- anno_boxes: list[torch.Tensor]: Ground truth boxes. # List[(max_objs, 10), (max_objs, 10), ...]
- inds: list[torch.Tensor]: Indexes indicating the position
of the valid boxes. # List[(max_objs, ), (max_objs, ), ...]
- masks: list[torch.Tensor]: Masks indicating which boxes
are valid. # List[(max_objs, ), (max_objs, ), ...]
"""
device = gt_labels_3d.device
gt_bboxes_3d = torch.cat(
(gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
dim=1).to(device) # (N_gt, 7/9)
max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
grid_size = torch.tensor(self.train_cfg['grid_size']) # (Dx, Dy, Dz)
pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
voxel_size = torch.tensor(self.train_cfg['voxel_size'])
feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor'] # (W, H)
# reorganize the gt_dict by tasks
task_masks = []
flag = 0
for class_name in self.class_names:
# class_name: class names handled by this task (SeparateHead).
task_masks.append([
torch.where(gt_labels_3d == class_name.index(i) + flag)
for i in class_name
])
flag += len(class_name)
# task_masks: List[task_mask0, task_mask1, ...] len = number of SeparateHeads
# task_mask: List[((N_gt0, ), ), ((N_gt1, ), ), ...] len = number of class
task_boxes = []
task_classes = []
flag2 = 0
for idx, mask in enumerate(task_masks):
# mask: per-task (SeparateHead) mask; each task detects a different group of classes.
# List[((N_gt0, ), ), ((N_gt1, ), ), ...], # N_gt_task = N_gt0 + N_gt1 + ..., the number of gt_boxes handled by the current task.
task_box = []
task_class = []
for m in mask:
task_box.append(gt_bboxes_3d[m])
# 0 is background for each task, so we need to add 1 here.
task_class.append(gt_labels_3d[m] + 1 - flag2)
task_boxes.append(torch.cat(task_box, axis=0).to(device))
task_classes.append(torch.cat(task_class).long().to(device))
flag2 += len(mask)
# record the gt_boxes and gt_classes handled by each task:
# task_boxes: List[(N_gt_task0, 7/9), (N_gt_task1, 7/9), ...]
# task_classes: List[(N_gt_task0, ), (N_gt_task1, ), ...]
draw_gaussian = draw_heatmap_gaussian
heatmaps, anno_boxes, inds, masks = [], [], [], []
for idx, task_head in enumerate(self.task_heads):
heatmap = gt_bboxes_3d.new_zeros(
(len(self.class_names[idx]), feature_map_size[1],
feature_map_size[0])) # (N_cls, H, W); N_cls is the number of classes handled by this task_head.
if self.with_velocity:
anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
dtype=torch.float32) # (max_objs, 10)
else:
anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
dtype=torch.float32)
ind = gt_labels_3d.new_zeros((max_objs, ), dtype=torch.int64) # (max_objs, )
mask = gt_bboxes_3d.new_zeros((max_objs, ), dtype=torch.uint8) # (max_objs, )
num_objs = min(task_boxes[idx].shape[0], max_objs) # number of objects handled by this task_head.
for k in range(num_objs):
cls_id = task_classes[idx][k] - 1 # cls_id of the current object, relative to this task group.
width = task_boxes[idx][k][3] # dx
length = task_boxes[idx][k][4] # dy
# width and length of the current object on the feature map
width = width / voxel_size[0] / self.train_cfg[
'out_size_factor']
length = length / voxel_size[1] / self.train_cfg[
'out_size_factor']
if width > 0 and length > 0:
# compute the gaussian radius
radius = gaussian_radius(
(length, width),
min_overlap=self.train_cfg['gaussian_overlap'])
radius = max(self.train_cfg['min_radius'], int(radius))
# be really careful for the coordinate system of
# your box annotation.
x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][
1], task_boxes[idx][k][2] # center coordinates of the current object.
# compute the gt_box center position on the feature map.
coor_x = (
x - pc_range[0]
) / voxel_size[0] / self.train_cfg['out_size_factor']
coor_y = (
y - pc_range[1]
) / voxel_size[1] / self.train_cfg['out_size_factor']
center = torch.tensor([coor_x, coor_y],
dtype=torch.float32,
device=device)
center_int = center.to(torch.int32)
# throw out not in range objects to avoid out of array
# area when creating the heatmap
if not (0 <= center_int[0] < feature_map_size[0]
and 0 <= center_int[1] < feature_map_size[1]):
continue
# draw the heatmap using the object's feature-map center and the gaussian radius.
draw_gaussian(heatmap[cls_id], center_int, radius)
new_idx = k
x, y = center_int[0], center_int[1]
assert (y * feature_map_size[0] + x <
feature_map_size[0] * feature_map_size[1])
# record the position of the positive sample on the feature map.
ind[new_idx] = y * feature_map_size[0] + x
mask[new_idx] = 1
# TODO: support other outdoor dataset
rot = task_boxes[idx][k][6]
box_dim = task_boxes[idx][k][3:6]
if self.norm_bbox:
box_dim = box_dim.log()
if self.with_velocity:
vx, vy = task_boxes[idx][k][7:]
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device), # tx, ty
z.unsqueeze(0), box_dim, # z, log(dx), log(dy), log(dz)
torch.sin(rot).unsqueeze(0), # sin(rot)
torch.cos(rot).unsqueeze(0), # cos(rot)
vx.unsqueeze(0), # vx
vy.unsqueeze(0) # vy
]) # [tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy]
else:
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0)
])
heatmaps.append(heatmap) # append (N_cls, H, W)
anno_boxes.append(anno_box) # append (max_objs, 10)
masks.append(mask) # append (max_objs, )
inds.append(ind) # append (max_objs, )
return heatmaps, anno_boxes, inds, masks
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead.
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes. # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
gt_labels_3d (list[torch.Tensor]): Labels of boxes. # List[(N_gt0, ), (N_gt1, ), ...]
preds_dicts (dict): Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
), len = number of SeparateHeads, each predicting its assigned classes.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
Returns:
dict[str:torch.Tensor]: Loss of heatmap and bbox of each task.
"""
heatmaps, anno_boxes, inds, masks = self.get_targets(
gt_bboxes_3d, gt_labels_3d)
# heatmaps: # List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
# anno_boxes: # List[(B, max_objs, 10), (B, max_objs, 10), ...] len = num of SeparateHead
# inds: # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
# masks: # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
loss_dict = dict()
if not self.task_specific:
loss_dict['loss'] = 0
for task_id, preds_dict in enumerate(preds_dicts):
# task_id: SeparateHead idx
# preds_dict: List[dict0, ...] len = num levels, for CenterPoint len = 1
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
# heatmap focal loss
preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
num_pos = heatmaps[task_id].eq(1).float().sum().item()
cls_avg_factor = torch.clamp(
reduce_mean(heatmaps[task_id].new_tensor(num_pos)),
min=1).item()
loss_heatmap = self.loss_cls(
preds_dict[0]['heatmap'], # (B, cur_N_cls, H, W)
heatmaps[task_id], # (B, cur_N_cls, H, W)
avg_factor=cls_avg_factor
)
# (B, max_objs, 10) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)
target_box = anno_boxes[task_id]
# reconstruct the anno_box from multiple reg heads
preds_dict[0]['anno_box'] = torch.cat(
(
preds_dict[0]['reg'],
preds_dict[0]['height'],
preds_dict[0]['dim'],
preds_dict[0]['rot'],
preds_dict[0]['vel'],
),
dim=1,
) # (B, 10, H, W) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)
# Regression loss for dimension, offset, height, rotation
num = masks[task_id].float().sum() # number of positive samples
ind = inds[task_id] # (B, max_objs)
pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous() # (B, H, W, 10)
pred = pred.view(pred.size(0), -1, pred.size(3)) # (B, H*W, 10)
pred = self._gather_feat(pred, ind) # (B, max_objs, 10)
# (B, max_objs) --> (B, max_objs, 1) --> (B, max_objs, 10)
mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
num = torch.clamp(
reduce_mean(target_box.new_tensor(num)), min=1e-4).item()
isnotnan = (~torch.isnan(target_box)).float()
mask *= isnotnan # only supervise regression at positions selected by the mask.
code_weights = self.train_cfg['code_weights']
bbox_weights = mask * mask.new_tensor(code_weights) # on top of the mask, weight each box attribute. (B, max_objs, 10)
if self.task_specific:
name_list = ['xy', 'z', 'whl', 'yaw', 'vel']
clip_index = [0, 2, 3, 6, 8, 10]
for reg_task_id in range(len(name_list)):
pred_tmp = pred[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
target_box_tmp = target_box[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
bbox_weights_tmp = bbox_weights[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
loss_bbox_tmp = self.loss_bbox(
pred_tmp,
target_box_tmp,
bbox_weights_tmp,
avg_factor=(num + 1e-4))
loss_dict[f'task{task_id}.loss_%s' %
(name_list[reg_task_id])] = loss_bbox_tmp
loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap
else:
loss_bbox = self.loss_bbox(
pred, target_box, bbox_weights, avg_factor=num)
loss_dict['loss'] += loss_bbox
loss_dict['loss'] += loss_heatmap
return loss_dict
def get_bboxes(self, preds_dicts, img_metas, img=None, rescale=False):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
), len = number of SeparateHeads, each predicting its assigned classes.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
img_metas (list[dict]): Point cloud and image's meta info.
Returns:
list[dict]: Decoded bbox, scores and labels after nms.
ret_list: List[p_list0, p_list1, ...]
p_list: List[(N, 9), (N, ), (N, )]
"""
rets = []
for task_id, preds_dict in enumerate(preds_dicts):
# task_id: SeparateHead idx
# preds_dict: List[dict0, ...] len = num levels, for CenterPoint len = 1
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
batch_size = preds_dict[0]['heatmap'].shape[0]
batch_heatmap = preds_dict[0]['heatmap'].sigmoid() # (B, n_cls, H, W)
batch_reg = preds_dict[0]['reg'] # (B, 2, H, W)
batch_hei = preds_dict[0]['height'] # (B, 1, H, W)
if self.norm_bbox:
batch_dim = torch.exp(preds_dict[0]['dim']) # (B, 3, H, W)
else:
batch_dim = preds_dict[0]['dim']
batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1) # (B, 1, H, W)
batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1) # (B, 1, H, W)
if 'vel' in preds_dict[0]:
batch_vel = preds_dict[0]['vel'] # (B, 2, H, W)
else:
batch_vel = None
temp = self.bbox_coder.decode(
batch_heatmap,
batch_rots,
batch_rotc,
batch_hei,
batch_dim,
batch_vel,
reg=batch_reg,
task_id=task_id)
# temp: List[p_dict0, p_dict1, ...] len=bs
# p_dict = {
# 'bboxes': boxes3d, # (K', 9)
# 'scores': scores, # (K', )
# 'labels': labels # (K', )
# }
batch_reg_preds = [box['bboxes'] for box in temp] # List[(K0, 9), (K1, 9), ...] len = bs
batch_cls_preds = [box['scores'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
batch_cls_labels = [box['labels'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
nms_type = self.test_cfg.get('nms_type')
if isinstance(nms_type, list):
nms_type = nms_type[task_id]
if nms_type == 'circle':
ret_task = []
for i in range(batch_size):
boxes3d = temp[i]['bboxes']
scores = temp[i]['scores']
labels = temp[i]['labels']
centers = boxes3d[:, [0, 1]]
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
keep = torch.tensor(
circle_nms(
boxes.detach().cpu().numpy(),
self.test_cfg['min_radius'][task_id],
post_max_size=self.test_cfg['post_max_size']),
dtype=torch.long,
device=boxes.device)
boxes3d = boxes3d[keep]
scores = scores[keep]
labels = labels[keep]
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
ret_task.append(ret)
rets.append(ret_task)
else:
rets.append(
self.get_task_detections(batch_cls_preds, batch_reg_preds,
batch_cls_labels, img_metas,
task_id))
# rets: List[ret_task0, ret_task1, ...], len = num_tasks
# ret_task: List[p_dict0, p_dict1, ...], len = batch_size
# p_dict: dict{
# bboxes: (K', 9)
# scores: (K', )
# labels: (K', )
# }
# Merge branches results
num_samples = len(rets[0]) # bs
ret_list = []
        # iterate over the batch and merge the predictions of all tasks.
for i in range(num_samples):
for k in rets[0][i].keys():
if k == 'bboxes':
                    bboxes = torch.cat([ret[i][k] for ret in rets])  # bboxes from all tasks can be concatenated directly.
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = img_metas[i]['box_type_3d'](
bboxes, self.bbox_coder.code_size)
elif k == 'scores':
                    scores = torch.cat([ret[i][k] for ret in rets])  # scores from all tasks can be concatenated directly.
elif k == 'labels':
flag = 0
                    for j, num_class in enumerate(self.num_classes):  # labels need an offset since each task predicts labels within its own class group.
rets[j][i][k] += flag
flag += num_class
labels = torch.cat([ret[i][k].int() for ret in rets])
ret_list.append([bboxes, scores, labels])
# ret_list: List[p_list0, p_list1, ...]
# p_list: List[(N, 9), (N, ), (N, )]
return ret_list
def get_task_detections(self, batch_cls_preds,
batch_reg_preds, batch_cls_labels, img_metas,
task_id):
"""Rotate nms for each task.
Args:
batch_cls_preds (list[torch.Tensor]): Prediction score with the
shape of [N]. # List[(K0, ), (K1, ), ...] len = bs
batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
shape of [N, 9]. # List[(K0, 9), (K1, 9), ...] len = bs
batch_cls_labels (list[torch.Tensor]): Prediction label with the
shape of [N]. # List[(K0, ), (K1, ), ...] len = bs
img_metas (list[dict]): Meta information of each sample.
Returns:
list[dict[str: torch.Tensor]]: contains the following keys:
-bboxes (torch.Tensor): Prediction bboxes after nms with the
shape of [N, 9].
-scores (torch.Tensor): Prediction scores after nms with the
shape of [N].
-labels (torch.Tensor): Prediction labels after nms with the
shape of [N].
List[p_dict0, p_dict1, ...] len = batch_size
p_dict: dict{
bboxes: (K', 9)
scores: (K', )
labels: (K', )
}
"""
predictions_dicts = []
        # iterate over the top-K predictions of each sample in the batch.
for i, (box_preds, cls_preds, cls_labels) in enumerate(
zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)):
# box_preds: (K, 9)
# cls_preds: (K, )
# cls_labels: (K, )
default_val = [1.0 for _ in range(len(self.task_heads))]
factor = self.test_cfg.get('nms_rescale_factor',
default_val)[task_id]
if isinstance(factor, list):
                # List[float, float, ..] len = number of classes handled by the current task.
                # scale box_preds by the per-class factor; small objects are usually enlarged and large ones shrunk.
for cid in range(len(factor)):
box_preds[cls_labels == cid, 3:6] = \
box_preds[cls_labels == cid, 3:6] * factor[cid]
else:
box_preds[:, 3:6] = box_preds[:, 3:6] * factor
# Apply NMS in birdeye view
top_labels = cls_labels.long() # (K, )
top_scores = cls_preds.squeeze(-1) if cls_preds.shape[0] > 1 \
else cls_preds # (K, )
if top_scores.shape[0] != 0:
boxes_for_nms = img_metas[i]['box_type_3d'](
box_preds[:, :], self.bbox_coder.code_size).bev # (K, 5) (x, y, dx, dy, yaw)
# the nms in 3d detection just remove overlap boxes.
if isinstance(self.test_cfg['nms_thr'], list):
nms_thresh = self.test_cfg['nms_thr'][task_id]
else:
nms_thresh = self.test_cfg['nms_thr']
selected = nms_bev(
boxes_for_nms,
top_scores,
thresh=nms_thresh,
pre_max_size=self.test_cfg['pre_max_size'],
post_max_size=self.test_cfg['post_max_size'],
xyxyr2xywhr=False,
)
else:
selected = []
            # after NMS, rescale the boxes back to their original sizes.
if isinstance(factor, list):
for cid in range(len(factor)):
box_preds[top_labels == cid, 3:6] = \
box_preds[top_labels == cid, 3:6] / factor[cid]
else:
box_preds[:, 3:6] = box_preds[:, 3:6] / factor
# if selected is not None:
selected_boxes = box_preds[selected] # (K', 9)
selected_labels = top_labels[selected] # (K', )
selected_scores = top_scores[selected] # (K', )
# finally generate predictions.
if selected_boxes.shape[0] != 0:
predictions_dict = dict(
bboxes=selected_boxes,
scores=selected_scores,
labels=selected_labels)
else:
dtype = batch_reg_preds[0].dtype
device = batch_reg_preds[0].device
predictions_dict = dict(
bboxes=torch.zeros([0, self.bbox_coder.code_size],
dtype=dtype,
device=device),
scores=torch.zeros([0], dtype=dtype, device=device),
labels=torch.zeros([0],
dtype=top_labels.dtype,
device=device))
predictions_dicts.append(predictions_dict)
return predictions_dicts
@HEADS.register_module()
class Centerness_Head(BaseModule):
"""CenterHead for CenterPoint.
Args:
in_channels (list[int] | int, optional): Channels of the input
feature map. Default: [128].
tasks (list[dict], optional): Task information including class number
and class names. Default: None.
train_cfg (dict, optional): Train-time configs. Default: None.
test_cfg (dict, optional): Test-time configs. Default: None.
bbox_coder (dict, optional): Bbox coder configs. Default: None.
common_heads (dict, optional): Conv information for common heads.
Default: dict().
loss_cls (dict, optional): Config of classification loss function.
Default: dict(type='GaussianFocalLoss', reduction='mean').
loss_bbox (dict, optional): Config of regression loss function.
Default: dict(type='L1Loss', reduction='none').
separate_head (dict, optional): Config of separate head. Default: dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3)
share_conv_channel (int, optional): Output channels for share_conv
layer. Default: 64.
num_heatmap_convs (int, optional): Number of conv layers for heatmap
conv layer. Default: 2.
conv_cfg (dict, optional): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict, optional): Config of norm layer.
Default: dict(type='BN2d').
bias (str, optional): Type of bias. Default: 'auto'.
"""
def __init__(self,
in_channels=[128],
tasks=None,
train_cfg=None,
test_cfg=None,
bbox_coder=None,
common_heads=dict(),
loss_cls=dict(type='GaussianFocalLoss', reduction='mean'),
loss_bbox=dict(
type='L1Loss', reduction='none', loss_weight=0.25),
separate_head=dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3),
share_conv_channel=64,
num_heatmap_convs=2,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
norm_bbox=True,
init_cfg=None,
task_specific=True,
task_specific_weight=[1, 1, 1, 1, 1]):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(Centerness_Head, self).__init__(init_cfg=init_cfg)
        num_classes = [len(t['class_names']) for t in tasks]  # number of classes handled by each task (SeparateHead).
        self.class_names = [t['class_names'] for t in tasks]  # class names handled by each task (SeparateHead).
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.in_channels = in_channels
self.num_classes = num_classes
self.norm_bbox = norm_bbox
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes]
self.fp16_enabled = False
# a shared convolution
self.shared_conv = ConvModule(
in_channels,
share_conv_channel,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=bias)
        # build one head for each task.
self.task_heads = nn.ModuleList()
for num_cls in num_classes:
# common_heads = dict(
# reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
heads = copy.deepcopy(common_heads)
heads.update(dict(heatmap=(num_cls, num_heatmap_convs)))
separate_head.update(
in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
self.task_heads.append(builder.build_head(separate_head))
self.with_velocity = 'vel' in common_heads.keys()
self.task_specific = task_specific
self.task_specific_weight = task_specific_weight # [1, 1, 0, 0, 0] # 'xy', 'z', 'whl', 'yaw', 'vel'
def forward_single(self, x):
"""Forward function for CenterPoint.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
list[dict]: Output results for tasks.
"""
ret_dicts = []
x = self.shared_conv(x) # (B, C'=share_conv_channel, H, W)
        # run each task_head.
for task in self.task_heads:
ret_dicts.append(task(x))
        # ret_dicts: [dict0, dict1, ...] len = number of SeparateHeads
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
return ret_dicts
def forward(self, feats):
"""Forward pass.
Args:
feats (list[torch.Tensor]): Multi-level features, e.g.,
features produced by FPN.
Returns:
results: Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
            ), len = number of SeparateHeads, each responsible for a group of classes.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
"""
return multi_apply(self.forward_single, feats)
def _gather_feat(self, feat, ind, mask=None):
"""Gather feature map.
Given feature map and index, return indexed feature map.
Args:
feat (torch.tensor): Feature map with the shape of [B, H*W, 10].
ind (torch.Tensor): Index of the ground truth boxes with the
shape of [B, max_obj].
mask (torch.Tensor, optional): Mask of the feature map with the
shape of [B, max_obj]. Default: None.
Returns:
torch.Tensor: Feature map after gathering with the shape
of [B, max_obj, 10].
"""
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
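        # ind[b, k] is a flattened spatial index (y * W + x), so gather picks the
        # prediction vector at that feature-map location for every object slot.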
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def get_targets(self, gt_bboxes_3d, gt_labels_3d):
"""Generate targets.
How each output is transformed:
Each nested list is transposed so that all same-index elements in
each sub-list (1, ..., N) become the new sub-lists.
[ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
The new transposed nested list is converted into a list of N
tensors generated by concatenating tensors in the new sub-lists.
[ tensor0, tensor1, tensor2, ... ]
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes. # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
gt_labels_3d (list[torch.Tensor]): Labels of boxes. # List[(N_gt0, ), (N_gt1, ), ...]
        Returns:
tuple[list[torch.Tensor]]: (
heatmaps: List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
anno_boxes:
inds:
masks:
)
"""
heatmaps, anno_boxes, inds, masks = multi_apply(
self.get_targets_single, gt_bboxes_3d, gt_labels_3d)
# heatmaps: # Tuple(List[(N_cls0, H, W), (N_cls1, H, W), ...], ...) len = batch_size
# anno_boxes: # Tuple(List[(max_objs, 10), (max_objs, 10), ...], ...) len = batch_size
# inds: # Tuple(List[(max_objs, ), (max_objs, ), ...], ...)
# masks: # Tuple(List[(max_objs, ), (max_objs, ), ...], ...)
# Transpose heatmaps
# List[List[(N_cls0, H, W), (N_cls0, H, W), ...], List[(N_cls1, H, W), (N_cls1, H, W), ...], ...] len = num of SeparateHead
heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps] # List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
# Transpose anno_boxes
anno_boxes = list(map(list, zip(*anno_boxes)))
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes] # List[(B, max_objs, 10), (B, max_objs, 10), ...] len = num of SeparateHead
# Transpose inds
inds = list(map(list, zip(*inds)))
inds = [torch.stack(inds_) for inds_ in inds] # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
        # Transpose masks
masks = list(map(list, zip(*masks)))
masks = [torch.stack(masks_) for masks_ in masks] # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
return heatmaps, anno_boxes, inds, masks
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
"""Generate training targets for a single sample.
Args:
gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes. # (N_gt, 7/9)
gt_labels_3d (torch.Tensor): Labels of boxes. # (N_gt, )
Returns:
tuple[list[torch.Tensor]]: Tuple of target including
the following results in order.
- heatmaps: list[torch.Tensor]: Heatmap scores. # List[(N_cls0, H, W), (N_cls1, H, W), ...]
len = num of tasks
- anno_boxes: list[torch.Tensor]: Ground truth boxes. # List[(max_objs, 10), (max_objs, 10), ...]
- inds: list[torch.Tensor]: Indexes indicating the position
of the valid boxes. # List[(max_objs, ), (max_objs, ), ...]
- masks: list[torch.Tensor]: Masks indicating which boxes
are valid. # List[(max_objs, ), (max_objs, ), ...]
"""
device = gt_labels_3d.device
gt_bboxes_3d = torch.cat(
(gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
dim=1).to(device) # (N_gt, 7/9)
max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
grid_size = torch.tensor(self.train_cfg['grid_size']) # (Dx, Dy, Dz)
pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
voxel_size = torch.tensor(self.train_cfg['voxel_size'])
feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor'] # (W, H)
# reorganize the gt_dict by tasks
task_masks = []
flag = 0
for class_name in self.class_names:
            # class_name: class names handled by this task (SeparateHead).
task_masks.append([
torch.where(gt_labels_3d == class_name.index(i) + flag)
for i in class_name
])
flag += len(class_name)
# task_masks: List[task_mask0, task_mask1, ...] len = number of SeparateHeads
# task_mask: List[((N_gt0, ), ), ((N_gt1, ), ), ...] len = number of class
task_boxes = []
task_classes = []
flag2 = 0
for idx, mask in enumerate(task_masks):
            # mask: per-task (SeparateHead) mask; each task detects a different group of classes.
            # List[((N_gt0, ), ), ((N_gt1, ), ), ...], # N_gt_task = N_gt0 + N_gt1 + ... is the number of gt_boxes handled by the current task.
task_box = []
task_class = []
for m in mask:
task_box.append(gt_bboxes_3d[m])
# 0 is background for each task, so we need to add 1 here.
task_class.append(gt_labels_3d[m] + 1 - flag2)
task_boxes.append(torch.cat(task_box, axis=0).to(device))
task_classes.append(torch.cat(task_class).long().to(device))
flag2 += len(mask)
        # gt_boxes and gt_classes handled by each task:
# task_boxes: List[(N_gt_task0, 7/9), (N_gt_task1, 7/9), ...]
# task_classes: List[(N_gt_task0, ), (N_gt_task1, ), ...]
draw_gaussian = draw_heatmap_gaussian
heatmaps, anno_boxes, inds, masks = [], [], [], []
for idx, task_head in enumerate(self.task_heads):
heatmap = gt_bboxes_3d.new_zeros(
(len(self.class_names[idx]), feature_map_size[1],
                 feature_map_size[0]))  # (N_cls, H, W), N_cls = number of classes handled by the current task_head.
if self.with_velocity:
anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
dtype=torch.float32) # (max_objs, 10)
else:
anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
dtype=torch.float32)
ind = gt_labels_3d.new_zeros((max_objs, ), dtype=torch.int64) # (max_objs, )
mask = gt_bboxes_3d.new_zeros((max_objs, ), dtype=torch.uint8) # (max_objs, )
            num_objs = min(task_boxes[idx].shape[0], max_objs)  # number of objects handled by the current task_head.
for k in range(num_objs):
                cls_id = task_classes[idx][k] - 1  # class id of this object, local to the task group.
width = task_boxes[idx][k][3] # dx
length = task_boxes[idx][k][4] # dy
                # width and length of the object on the feature map
width = width / voxel_size[0] / self.train_cfg[
'out_size_factor']
length = length / voxel_size[1] / self.train_cfg[
'out_size_factor']
if width > 0 and length > 0:
                    # compute the gaussian radius
radius = gaussian_radius(
(length, width),
min_overlap=self.train_cfg['gaussian_overlap'])
radius = max(self.train_cfg['min_radius'], int(radius))
# be really careful for the coordinate system of
# your box annotation.
x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][
                        1], task_boxes[idx][k][2]  # center coordinates of the object.
                    # compute the position of the gt_box center on the feature map.
coor_x = (
x - pc_range[0]
) / voxel_size[0] / self.train_cfg['out_size_factor']
coor_y = (
y - pc_range[1]
) / voxel_size[1] / self.train_cfg['out_size_factor']
center = torch.tensor([coor_x, coor_y],
dtype=torch.float32,
device=device)
center_int = center.to(torch.int32)
# throw out not in range objects to avoid out of array
# area when creating the heatmap
if not (0 <= center_int[0] < feature_map_size[0]
and 0 <= center_int[1] < feature_map_size[1]):
continue
                    # draw the heatmap at the object center position with the computed gaussian radius.
draw_gaussian(heatmap[cls_id], center_int, radius)
new_idx = k
x, y = center_int[0], center_int[1]
assert (y * feature_map_size[0] + x <
feature_map_size[0] * feature_map_size[1])
                    # record the position of this positive sample on the feature map.
ind[new_idx] = y * feature_map_size[0] + x
mask[new_idx] = 1
# TODO: support other outdoor dataset
rot = task_boxes[idx][k][6]
box_dim = task_boxes[idx][k][3:6]
if self.norm_bbox:
box_dim = box_dim.log()
if self.with_velocity:
vx, vy = task_boxes[idx][k][7:]
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device), # tx, ty
z.unsqueeze(0), box_dim, # z, log(dx), log(dy), log(dz)
torch.sin(rot).unsqueeze(0), # sin(rot)
torch.cos(rot).unsqueeze(0), # cos(rot)
vx.unsqueeze(0), # vx
vy.unsqueeze(0) # vy
]) # [tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy]
else:
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0)
])
heatmaps.append(heatmap) # append (N_cls, H, W)
anno_boxes.append(anno_box) # append (max_objs, 10)
masks.append(mask) # append (max_objs, )
inds.append(ind) # append (max_objs, )
return heatmaps, anno_boxes, inds, masks
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead.
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes. # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
gt_labels_3d (list[torch.Tensor]): Labels of boxes. # List[(N_gt0, ), (N_gt1, ), ...]
preds_dicts (dict): Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
            ), len = number of SeparateHeads, each responsible for a group of classes.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
Returns:
dict[str:torch.Tensor]: Loss of heatmap and bbox of each task.
"""
heatmaps, anno_boxes, inds, masks = self.get_targets(
gt_bboxes_3d, gt_labels_3d)
# heatmaps: # List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
# anno_boxes: # List[(B, max_objs, 10), (B, max_objs, 10), ...] len = num of SeparateHead
# inds: # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
# masks: # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
loss_dict = dict()
if not self.task_specific:
loss_dict['loss'] = 0
for task_id, preds_dict in enumerate(preds_dicts):
# task_id: SeparateHead idx
            # preds_dict: List[dict0, ...] len = num levels, for CenterPoint len = 1
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
# heatmap focal loss
preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
num_pos = heatmaps[task_id].eq(1).float().sum().item()
cls_avg_factor = torch.clamp(
reduce_mean(heatmaps[task_id].new_tensor(num_pos)),
min=1).item()
loss_heatmap = self.loss_cls(
preds_dict[0]['heatmap'], # (B, cur_N_cls, H, W)
heatmaps[task_id], # (B, cur_N_cls, H, W)
avg_factor=cls_avg_factor
)
# (B, max_objs, 10) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)
target_box = anno_boxes[task_id]
# reconstruct the anno_box from multiple reg heads
preds_dict[0]['anno_box'] = torch.cat(
(
preds_dict[0]['reg'],
preds_dict[0]['height'],
preds_dict[0]['dim'],
preds_dict[0]['rot'],
preds_dict[0]['vel'],
),
dim=1,
) # (B, 10, H, W) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)
# Regression loss for dimension, offset, height, rotation
            num = masks[task_id].float().sum()  # number of positive samples
ind = inds[task_id] # (B, max_objs)
pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous() # (B, H, W, 10)
pred = pred.view(pred.size(0), -1, pred.size(3)) # (B, H*W, 10)
pred = self._gather_feat(pred, ind) # (B, max_objs, 10)
# (B, max_objs) --> (B, max_objs, 1) --> (B, max_objs, 10)
mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
num = torch.clamp(
reduce_mean(target_box.new_tensor(num)), min=1e-4).item()
isnotnan = (~torch.isnan(target_box)).float()
            mask *= isnotnan  # only supervise the reg predictions selected by the mask.
code_weights = self.train_cfg['code_weights']
            bbox_weights = mask * mask.new_tensor(code_weights)  # weight each box attribute on top of the mask. (B, max_objs, 10)
if self.task_specific:
name_list = ['xy', 'z', 'whl', 'yaw', 'vel']
clip_index = [0, 2, 3, 6, 8, 10]
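                    # Slice the regression output per sub-task:
                    # xy -> [0:2], z -> [2:3], whl -> [3:6], yaw (sin, cos) -> [6:8], vel -> [8:10].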
for reg_task_id in range(len(name_list)):
pred_tmp = pred[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
target_box_tmp = target_box[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
bbox_weights_tmp = bbox_weights[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
loss_bbox_tmp = self.loss_bbox(
pred_tmp,
target_box_tmp,
bbox_weights_tmp,
avg_factor=(num + 1e-4))
                    loss_dict[f'task{task_id}.loss_{name_list[reg_task_id]}'] = \
                        loss_bbox_tmp * self.task_specific_weight[reg_task_id]
loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap
else:
loss_bbox = self.loss_bbox(
pred, target_box, bbox_weights, avg_factor=num)
loss_dict['loss'] += loss_bbox
loss_dict['loss'] += loss_heatmap
return loss_dict
def get_bboxes(self, preds_dicts, img_metas, img=None, rescale=False):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
                ), len = number of SeparateHeads, each responsible for a group of classes.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
img_metas (list[dict]): Point cloud and image's meta info.
Returns:
list[dict]: Decoded bbox, scores and labels after nms.
ret_list: List[p_list0, p_list1, ...]
p_list: List[(N, 9), (N, ), (N, )]
"""
rets = []
for task_id, preds_dict in enumerate(preds_dicts):
# task_id: SeparateHead idx
            # preds_dict: List[dict0, ...] len = num levels, for CenterPoint len = 1
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
batch_size = preds_dict[0]['heatmap'].shape[0]
batch_heatmap = preds_dict[0]['heatmap'].sigmoid() # (B, n_cls, H, W)
batch_reg = preds_dict[0]['reg'] # (B, 2, H, W)
batch_hei = preds_dict[0]['height'] # (B, 1, H, W)
if self.norm_bbox:
batch_dim = torch.exp(preds_dict[0]['dim']) # (B, 3, H, W)
else:
batch_dim = preds_dict[0]['dim']
batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1) # (B, 1, H, W)
batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1) # (B, 1, H, W)
if 'vel' in preds_dict[0]:
batch_vel = preds_dict[0]['vel'] # (B, 2, H, W)
else:
batch_vel = None
temp = self.bbox_coder.decode(
batch_heatmap,
batch_rots,
batch_rotc,
batch_hei,
batch_dim,
batch_vel,
reg=batch_reg,
task_id=task_id)
# temp: List[p_dict0, p_dict1, ...] len=bs
# p_dict = {
# 'bboxes': boxes3d, # (K', 9)
# 'scores': scores, # (K', )
# 'labels': labels # (K', )
# }
batch_reg_preds = [box['bboxes'] for box in temp] # List[(K0, 9), (K1, 9), ...] len = bs
batch_cls_preds = [box['scores'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
batch_cls_labels = [box['labels'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
nms_type = self.test_cfg.get('nms_type')
if isinstance(nms_type, list):
nms_type = nms_type[task_id]
if nms_type == 'circle':
ret_task = []
for i in range(batch_size):
boxes3d = temp[i]['bboxes']
scores = temp[i]['scores']
labels = temp[i]['labels']
centers = boxes3d[:, [0, 1]]
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
keep = torch.tensor(
circle_nms(
boxes.detach().cpu().numpy(),
self.test_cfg['min_radius'][task_id],
post_max_size=self.test_cfg['post_max_size']),
dtype=torch.long,
device=boxes.device)
boxes3d = boxes3d[keep]
scores = scores[keep]
labels = labels[keep]
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
ret_task.append(ret)
rets.append(ret_task)
else:
rets.append(
self.get_task_detections(batch_cls_preds, batch_reg_preds,
batch_cls_labels, img_metas,
task_id))
# rets: List[ret_task0, ret_task1, ...], len = num_tasks
# ret_task: List[p_dict0, p_dict1, ...], len = batch_size
# p_dict: dict{
# bboxes: (K', 9)
# scores: (K', )
# labels: (K', )
# }
# Merge branches results
num_samples = len(rets[0]) # bs
ret_list = []
        # iterate over the batch and merge the predictions of all tasks.
for i in range(num_samples):
for k in rets[0][i].keys():
if k == 'bboxes':
                    bboxes = torch.cat([ret[i][k] for ret in rets])  # bboxes from all tasks can be concatenated directly.
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = img_metas[i]['box_type_3d'](
bboxes, self.bbox_coder.code_size)
elif k == 'scores':
                    scores = torch.cat([ret[i][k] for ret in rets])  # scores from all tasks can be concatenated directly.
elif k == 'labels':
flag = 0
                    for j, num_class in enumerate(self.num_classes):  # labels need an offset since each task predicts labels within its own class group.
rets[j][i][k] += flag
flag += num_class
labels = torch.cat([ret[i][k].int() for ret in rets])
ret_list.append([bboxes, scores, labels])
# ret_list: List[p_list0, p_list1, ...]
# p_list: List[(N, 9), (N, ), (N, )]
return ret_list
def _nms(self, heat, kernel=3):
pad = (kernel - 1) // 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=pad)
keep = (hmax == heat).float()
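        # Keep only local maxima: a location survives iff it equals the maximum of
        # its kernel x kernel neighborhood, suppressing duplicate heatmap peaks.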
return heat * keep
def get_centers(self, preds_dicts, img_metas, img=None, rescale=False):
rets = []
for task_id, preds_dict in enumerate(preds_dicts):
batch_size = preds_dict[0]['heatmap'].shape[0]
batch_heatmap = preds_dict[0]['heatmap'].sigmoid() # (B, n_cls, H, W)
batch_reg = preds_dict[0]['reg'] # (B, 2, H, W)
batch_hei = preds_dict[0]['height'] # (B, 1, H, W)
batch_heatmap = self._nms(batch_heatmap)
temp = self.bbox_coder.center_decode(
batch_heatmap,
batch_hei,
reg=batch_reg,
task_id=task_id)
batch_reg_preds = [box['centers'] for box in temp] # List[(K0, 9), (K1, 9), ...] len = bs
batch_cls_preds = [box['scores'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
batch_cls_labels = [box['labels'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
ret_list = [batch_reg_preds, batch_cls_preds, batch_cls_labels]
return ret_list
def get_task_detections(self, batch_cls_preds,
batch_reg_preds, batch_cls_labels, img_metas,
task_id):
"""Rotate nms for each task.
Args:
batch_cls_preds (list[torch.Tensor]): Prediction score with the
shape of [N]. # List[(K0, ), (K1, ), ...] len = bs
batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
shape of [N, 9]. # List[(K0, 9), (K1, 9), ...] len = bs
batch_cls_labels (list[torch.Tensor]): Prediction label with the
shape of [N]. # List[(K0, ), (K1, ), ...] len = bs
img_metas (list[dict]): Meta information of each sample.
Returns:
list[dict[str: torch.Tensor]]: contains the following keys:
-bboxes (torch.Tensor): Prediction bboxes after nms with the
shape of [N, 9].
-scores (torch.Tensor): Prediction scores after nms with the
shape of [N].
-labels (torch.Tensor): Prediction labels after nms with the
shape of [N].
List[p_dict0, p_dict1, ...] len = batch_size
p_dict: dict{
bboxes: (K', 9)
scores: (K', )
labels: (K', )
}
"""
predictions_dicts = []
        # iterate over the top-K predictions of each sample in the batch.
for i, (box_preds, cls_preds, cls_labels) in enumerate(
zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)):
# box_preds: (K, 9)
# cls_preds: (K, )
# cls_labels: (K, )
default_val = [1.0 for _ in range(len(self.task_heads))]
factor = self.test_cfg.get('nms_rescale_factor',
default_val)[task_id]
if isinstance(factor, list):
                # List[float, float, ..] len = number of classes handled by the current task.
                # scale box_preds by the per-class factor; small objects are usually enlarged and large ones shrunk.
for cid in range(len(factor)):
box_preds[cls_labels == cid, 3:6] = \
box_preds[cls_labels == cid, 3:6] * factor[cid]
else:
box_preds[:, 3:6] = box_preds[:, 3:6] * factor
# Apply NMS in birdeye view
top_labels = cls_labels.long() # (K, )
top_scores = cls_preds.squeeze(-1) if cls_preds.shape[0] > 1 \
else cls_preds # (K, )
if top_scores.shape[0] != 0:
boxes_for_nms = img_metas[i]['box_type_3d'](
box_preds[:, :], self.bbox_coder.code_size).bev # (K, 5) (x, y, dx, dy, yaw)
# the nms in 3d detection just remove overlap boxes.
if isinstance(self.test_cfg['nms_thr'], list):
nms_thresh = self.test_cfg['nms_thr'][task_id]
else:
nms_thresh = self.test_cfg['nms_thr']
selected = nms_bev(
boxes_for_nms,
top_scores,
thresh=nms_thresh,
pre_max_size=self.test_cfg['pre_max_size'],
post_max_size=self.test_cfg['post_max_size'],
xyxyr2xywhr=False,
)
else:
selected = []
            # after NMS, rescale the boxes back to their original sizes.
if isinstance(factor, list):
for cid in range(len(factor)):
box_preds[top_labels == cid, 3:6] = \
box_preds[top_labels == cid, 3:6] / factor[cid]
else:
box_preds[:, 3:6] = box_preds[:, 3:6] / factor
# if selected is not None:
selected_boxes = box_preds[selected] # (K', 9)
selected_labels = top_labels[selected] # (K', )
selected_scores = top_scores[selected] # (K', )
# finally generate predictions.
if selected_boxes.shape[0] != 0:
predictions_dict = dict(
bboxes=selected_boxes,
scores=selected_scores,
labels=selected_labels)
else:
dtype = batch_reg_preds[0].dtype
device = batch_reg_preds[0].device
predictions_dict = dict(
bboxes=torch.zeros([0, self.bbox_coder.code_size],
dtype=dtype,
device=device),
scores=torch.zeros([0], dtype=dtype, device=device),
labels=torch.zeros([0],
dtype=top_labels.dtype,
device=device))
predictions_dicts.append(predictions_dict)
return predictions_dicts
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
import numpy as np
from mmdet3d.models.builder import HEADS, build_loss
from ..losses.semkitti_loss import sem_scal_loss, geo_scal_loss
from ..losses.lovasz_softmax import lovasz_softmax
nusc_class_frequencies = np.array([
944004,
1897170,
152386,
2391677,
16957802,
724139,
189027,
2074468,
413451,
2384460,
5916653,
175883646,
4275424,
51393615,
61411620,
105975596,
116424404,
1892500630
])
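# The counts above are per-class voxel frequencies for the 18 nuScenes occupancy
# classes; when class_balance is enabled, they are converted into inverse-log
# class weights via 1 / log(frequency + 0.001).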
@HEADS.register_module()
class BEVOCCHead3D(BaseModule):
def __init__(self,
in_dim=32,
out_dim=32,
use_mask=True,
num_classes=18,
use_predicter=True,
class_balance=False,
loss_occ=None
):
super(BEVOCCHead3D, self).__init__()
        self.out_dim = out_dim
out_channels = out_dim if use_predicter else num_classes
self.final_conv = ConvModule(
in_dim,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True,
conv_cfg=dict(type='Conv3d')
)
self.use_predicter = use_predicter
if use_predicter:
self.predicter = nn.Sequential(
nn.Linear(self.out_dim, self.out_dim*2),
nn.Softplus(),
nn.Linear(self.out_dim*2, num_classes),
)
self.num_classes = num_classes
self.use_mask = use_mask
self.class_balance = class_balance
if self.class_balance:
class_weights = torch.from_numpy(1 / np.log(nusc_class_frequencies[:num_classes] + 0.001))
self.cls_weights = class_weights
loss_occ['class_weight'] = class_weights
self.loss_occ = build_loss(loss_occ)
def forward(self, img_feats):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx)
Returns:
"""
# (B, C, Dz, Dy, Dx) --> (B, C, Dz, Dy, Dx) --> (B, Dx, Dy, Dz, C)
occ_pred = self.final_conv(img_feats).permute(0, 4, 3, 2, 1)
if self.use_predicter:
# (B, Dx, Dy, Dz, C) --> (B, Dx, Dy, Dz, 2*C) --> (B, Dx, Dy, Dz, n_cls)
occ_pred = self.predicter(occ_pred)
return occ_pred
def loss(self, occ_pred, voxel_semantics, mask_camera):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, n_cls)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
loss = dict()
voxel_semantics = voxel_semantics.long()
if self.use_mask:
mask_camera = mask_camera.to(torch.int32) # (B, Dx, Dy, Dz)
# (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
voxel_semantics = voxel_semantics.reshape(-1)
# (B, Dx, Dy, Dz, n_cls) --> (B*Dx*Dy*Dz, n_cls)
preds = occ_pred.reshape(-1, self.num_classes)
# (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
mask_camera = mask_camera.reshape(-1)
if self.class_balance:
valid_voxels = voxel_semantics[mask_camera.bool()]
num_total_samples = 0
for i in range(self.num_classes):
num_total_samples += (valid_voxels == i).sum() * self.cls_weights[i]
else:
num_total_samples = mask_camera.sum()
loss_occ = self.loss_occ(
preds, # (B*Dx*Dy*Dz, n_cls)
voxel_semantics, # (B*Dx*Dy*Dz, )
mask_camera, # (B*Dx*Dy*Dz, )
avg_factor=num_total_samples
)
else:
voxel_semantics = voxel_semantics.reshape(-1)
preds = occ_pred.reshape(-1, self.num_classes)
if self.class_balance:
num_total_samples = 0
for i in range(self.num_classes):
num_total_samples += (voxel_semantics == i).sum() * self.cls_weights[i]
else:
num_total_samples = len(voxel_semantics)
loss_occ = self.loss_occ(
preds,
voxel_semantics,
avg_factor=num_total_samples
)
loss['loss_occ'] = loss_occ
return loss
def get_occ(self, occ_pred, img_metas=None):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, C)
img_metas:
Returns:
List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
occ_score = occ_pred.softmax(-1) # (B, Dx, Dy, Dz, C)
occ_res = occ_score.argmax(-1) # (B, Dx, Dy, Dz)
occ_res = occ_res.cpu().numpy().astype(np.uint8) # (B, Dx, Dy, Dz)
return list(occ_res)
@HEADS.register_module()
class BEVOCCHead2D(BaseModule):
def __init__(self,
in_dim=256,
out_dim=256,
Dz=16,
use_mask=True,
num_classes=18,
use_predicter=True,
class_balance=False,
loss_occ=None,
):
super(BEVOCCHead2D, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.Dz = Dz
out_channels = out_dim if use_predicter else num_classes * Dz
self.final_conv = ConvModule(
self.in_dim,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True,
conv_cfg=dict(type='Conv2d')
)
self.use_predicter = use_predicter
if use_predicter:
self.predicter = nn.Sequential(
nn.Linear(self.out_dim, self.out_dim * 2),
nn.Softplus(),
nn.Linear(self.out_dim * 2, num_classes * Dz),
)
self.use_mask = use_mask
self.num_classes = num_classes
self.class_balance = class_balance
if self.class_balance:
class_weights = torch.from_numpy(1 / np.log(nusc_class_frequencies[:num_classes] + 0.001))
self.cls_weights = class_weights
loss_occ['class_weight'] = class_weights # ce loss
self.loss_occ = build_loss(loss_occ)
#@torch.compiler.disable
def forward(self, img_feats):
"""
Args:
img_feats: (B, C, Dy, Dx)
Returns:
"""
# (B, C, Dy, Dx) --> (B, C, Dy, Dx) --> (B, Dx, Dy, C)
occ_pred = self.final_conv(img_feats).permute(0, 3, 2, 1)
bs, Dx, Dy = occ_pred.shape[:3]
if self.use_predicter:
# (B, Dx, Dy, C) --> (B, Dx, Dy, 2*C) --> (B, Dx, Dy, Dz*n_cls)
occ_pred = self.predicter(occ_pred)
occ_pred = occ_pred.view(bs, Dx, Dy, self.Dz, self.num_classes)
return occ_pred
def loss(self, occ_pred, voxel_semantics, mask_camera):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, n_cls)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
loss = dict()
voxel_semantics = voxel_semantics.long()
if self.use_mask:
mask_camera = mask_camera.to(torch.int32) # (B, Dx, Dy, Dz)
# (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
voxel_semantics = voxel_semantics.reshape(-1)
# (B, Dx, Dy, Dz, n_cls) --> (B*Dx*Dy*Dz, n_cls)
preds = occ_pred.reshape(-1, self.num_classes)
# (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
mask_camera = mask_camera.reshape(-1)
if self.class_balance:
valid_voxels = voxel_semantics[mask_camera.bool()]
num_total_samples = 0
for i in range(self.num_classes):
num_total_samples += (valid_voxels == i).sum() * self.cls_weights[i]
else:
num_total_samples = mask_camera.sum()
loss_occ = self.loss_occ(
preds, # (B*Dx*Dy*Dz, n_cls)
voxel_semantics, # (B*Dx*Dy*Dz, )
mask_camera, # (B*Dx*Dy*Dz, )
avg_factor=num_total_samples
)
loss['loss_occ'] = loss_occ
else:
voxel_semantics = voxel_semantics.reshape(-1)
preds = occ_pred.reshape(-1, self.num_classes)
if self.class_balance:
num_total_samples = 0
for i in range(self.num_classes):
num_total_samples += (voxel_semantics == i).sum() * self.cls_weights[i]
else:
num_total_samples = len(voxel_semantics)
loss_occ = self.loss_occ(
preds,
voxel_semantics,
avg_factor=num_total_samples
)
loss['loss_occ'] = loss_occ
return loss
def get_occ(self, occ_pred, img_metas=None):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, C)
img_metas:
Returns:
List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
occ_score = occ_pred.softmax(-1) # (B, Dx, Dy, Dz, C)
occ_res = occ_score.argmax(-1) # (B, Dx, Dy, Dz)
occ_res = occ_res.cpu().numpy().astype(np.uint8) # (B, Dx, Dy, Dz)
return list(occ_res)
@HEADS.register_module()
class BEVOCCHead2D_V2(BaseModule): # Use stronger loss setting
def __init__(self,
in_dim=256,
out_dim=256,
Dz=16,
use_mask=True,
num_classes=18,
use_predicter=True,
class_balance=False,
loss_occ=None,
):
super(BEVOCCHead2D_V2, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.Dz = Dz
# voxel-level prediction
self.occ_convs = nn.ModuleList()
self.final_conv = ConvModule(
in_dim,
self.out_dim,
kernel_size=3,
stride=1,
padding=1,
bias=True,
conv_cfg=dict(type='Conv2d')
)
self.use_predicter = use_predicter
if use_predicter:
self.predicter = nn.Sequential(
nn.Linear(self.out_dim, self.out_dim * 2),
nn.Softplus(),
nn.Linear(self.out_dim * 2, num_classes * Dz),
)
self.use_mask = use_mask
self.num_classes = num_classes
self.class_balance = class_balance
if self.class_balance:
class_weights = torch.from_numpy(1 / np.log(nusc_class_frequencies[:num_classes] + 0.001))
self.cls_weights = class_weights
self.loss_occ = build_loss(loss_occ)
def forward(self, img_feats):
"""
Args:
img_feats: (B, C, Dy=200, Dx=200)
img_feats: [(B, C, 100, 100), (B, C, 50, 50), (B, C, 25, 25)] if ms
Returns:
"""
# (B, C, Dy, Dx) --> (B, C, Dy, Dx) --> (B, Dx, Dy, C)
occ_pred = self.final_conv(img_feats).permute(0, 3, 2, 1)
bs, Dx, Dy = occ_pred.shape[:3]
if self.use_predicter:
# (B, Dx, Dy, C) --> (B, Dx, Dy, 2*C) --> (B, Dx, Dy, Dz*n_cls)
occ_pred = self.predicter(occ_pred)
occ_pred = occ_pred.view(bs, Dx, Dy, self.Dz, self.num_classes)
return occ_pred
def loss(self, occ_pred, voxel_semantics, mask_camera):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, n_cls)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
loss = dict()
voxel_semantics = voxel_semantics.long() # (B, Dx, Dy, Dz)
preds = occ_pred.permute(0, 4, 1, 2, 3).contiguous() # (B, n_cls, Dx, Dy, Dz)
loss_occ = self.loss_occ(
preds,
voxel_semantics,
weight=self.cls_weights.to(preds),
) * 100.0
loss['loss_occ'] = loss_occ
loss['loss_voxel_sem_scal'] = sem_scal_loss(preds, voxel_semantics)
loss['loss_voxel_geo_scal'] = geo_scal_loss(preds, voxel_semantics, non_empty_idx=17)
loss['loss_voxel_lovasz'] = lovasz_softmax(torch.softmax(preds, dim=1), voxel_semantics)
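        # Total objective: weighted cross-entropy (scaled by 100) plus the
        # semantic/geometric scale-invariant terms and Lovasz-softmax;
        # non_empty_idx=17 tells geo_scal_loss to treat the last class as empty space.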
return loss
def get_occ(self, occ_pred, img_metas=None):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, C)
img_metas:
Returns:
List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
occ_score = occ_pred.softmax(-1) # (B, Dx, Dy, Dz, C)
occ_res = occ_score.argmax(-1) # (B, Dx, Dy, Dz)
occ_res = occ_res.cpu().numpy().astype(np.uint8) # (B, Dx, Dy, Dz)
return list(occ_res)
def get_occ_gpu(self, occ_pred, img_metas=None):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, C)
img_metas:
Returns:
List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
occ_score = occ_pred.softmax(-1) # (B, Dx, Dy, Dz, C)
occ_res = occ_score.argmax(-1).int() # (B, Dx, Dy, Dz)
return list(occ_res)
from .bevdet import BEVDet
from .bevdepth import BEVDepth
from .bevdet4d import BEVDet4D
from .bevdepth4d import BEVDepth4D
from .bevstereo4d import BEVStereo4D
from .bevdet_occ import BEVDetOCC, BEVDepthOCC, BEVDepth4DOCC, BEVStereo4DOCC, BEVDepth4DPano, BEVDepthPano, BEVDepthPanoTRT
__all__ = ['BEVDet', 'BEVDepth', 'BEVDet4D', 'BEVDepth4D', 'BEVStereo4D', 'BEVDetOCC', 'BEVDepthOCC',
'BEVDepth4DOCC', 'BEVStereo4DOCC', 'BEVDepthPano', 'BEVDepth4DPano', 'BEVDepthPanoTRT']
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from .bevdet import BEVDet
from mmdet3d.models import builder
@DETECTORS.register_module()
class BEVDepth(BEVDet):
def __init__(self, img_backbone, img_neck, img_view_transformer, img_bev_encoder_backbone, img_bev_encoder_neck,
pts_bbox_head=None, **kwargs):
super(BEVDepth, self).__init__(img_backbone=img_backbone,
img_neck=img_neck,
img_view_transformer=img_view_transformer,
img_bev_encoder_backbone=img_bev_encoder_backbone,
img_bev_encoder_neck=img_bev_encoder_neck,
pts_bbox_head=pts_bbox_head
)
def image_encoder(self, img, stereo=False):
"""
Args:
img: (B, N, 3, H, W)
stereo: bool
Returns:
x: (B, N, C, fH, fW)
stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo) / None
"""
imgs = img
B, N, C, imH, imW = imgs.shape
imgs = imgs.view(B * N, C, imH, imW)
x = self.img_backbone(imgs)
stereo_feat = None
if stereo:
stereo_feat = x[0]
x = x[1:]
if self.with_img_neck:
x = self.img_neck(x)
if type(x) in [list, tuple]:
x = x[0]
        _, output_dim, output_H, output_W = x.shape
        x = x.view(B, N, output_dim, output_H, output_W)
return x, stereo_feat
@force_fp32()
def bev_encoder(self, x):
"""
Args:
x: (B, C, Dy, Dx)
Returns:
x: (B, C', 2*Dy, 2*Dx)
"""
x = self.img_bev_encoder_backbone(x)
x = self.img_bev_encoder_neck(x)
if type(x) in [list, tuple]:
x = x[0]
return x
def prepare_inputs(self, inputs):
# split the inputs into each frame
assert len(inputs) == 7
B, N, C, H, W = inputs[0].shape
imgs, sensor2egos, ego2globals, intrins, post_rots, post_trans, bda = \
inputs
sensor2egos = sensor2egos.view(B, N, 4, 4)
ego2globals = ego2globals.view(B, N, 4, 4)
# calculate the transformation from adj sensor to key ego
keyego2global = ego2globals[:, 0, ...].unsqueeze(1) # (B, 1, 4, 4)
global2keyego = torch.inverse(keyego2global.double()) # (B, 1, 4, 4)
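        # Chain of transforms per camera: sensor -> ego(frame t) -> global -> key ego,
        # i.e. T_sensor2keyego = T_global2keyego @ T_ego2global @ T_sensor2ego,
        # so all views (including adjacent frames) are expressed in the key frame's ego coordinates.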
sensor2keyegos = \
global2keyego @ ego2globals.double() @ sensor2egos.double() # (B, N_views, 4, 4)
sensor2keyegos = sensor2keyegos.float()
return [imgs, sensor2keyegos, ego2globals, intrins,
post_rots, post_trans, bda]
def extract_img_feat(self, img_inputs, img_metas, **kwargs):
""" Extract features of images.
img_inputs:
imgs: (B, N_views, 3, H, W)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
Returns:
x: [(B, C', H', W'), ]
depth: (B*N, D, fH, fW)
"""
imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda = self.prepare_inputs(img_inputs)
x, _ = self.image_encoder(imgs) # x: (B, N, C, fH, fW)
mlp_input = self.img_view_transformer.get_mlp_input(
sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda) # (B, N_views, 27)
x, depth = self.img_view_transformer([x, sensor2keyegos, ego2globals, intrins, post_rots,
post_trans, bda, mlp_input])
# x: (B, C, Dy, Dx)
# depth: (B*N, D, fH, fW)
x = self.bev_encoder(x)
return [x], depth
def extract_feat(self, points, img_inputs, img_metas, **kwargs):
"""Extract features from images and points."""
"""
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_inputs:
imgs: (B, N_views, 3, H, W)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
"""
img_feats, depth = self.extract_img_feat(img_inputs, img_metas, **kwargs)
pts_feats = None
return img_feats, pts_feats, depth
def forward_train(self,
points=None,
img_inputs=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
img_metas=None,
gt_bboxes=None,
gt_labels=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_inputs:
imgs: (B, N_views, 3, H, W) # N_views = 6 * (N_history + 1)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses = dict(loss_depth=loss_depth)
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_pts)
return losses
def forward_test(self,
points=None,
img_inputs=None,
img_metas=None,
**kwargs):
"""
Args:
points (list[torch.Tensor]): the outer list indicates test-time
augmentations and inner torch.Tensor should have a shape NxC,
which contains all points in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch
img (list[torch.Tensor], optional): the outer
list indicates test-time augmentations and inner
torch.Tensor should have a shape NxCxHxW, which contains
all images in the batch. Defaults to None.
"""
for var, name in [(img_inputs, 'img_inputs'),
(img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(img_inputs)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(img_inputs), len(img_metas)))
if not isinstance(img_inputs[0][0], list):
img_inputs = [img_inputs] if img_inputs is None else img_inputs
points = [points] if points is None else points
return self.simple_test(points[0], img_metas[0], img_inputs[0],
**kwargs)
else:
return self.aug_test(None, img_metas[0], img_inputs[0], **kwargs)
def aug_test(self, points, img_metas, img=None, rescale=False):
"""Test function without augmentaiton."""
assert False
def simple_test(self,
points,
img_metas,
img_inputs=None,
rescale=False,
**kwargs):
"""Test function without augmentaiton.
Returns:
bbox_list: List[dict0, dict1, ...] len = bs
dict: {
'pts_bbox': dict: {
'boxes_3d': (N, 9)
'scores_3d': (N, )
'labels_3d': (N, )
}
}
"""
img_feats, _, _ = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
bbox_list = [dict() for _ in range(len(img_metas))]
bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
# bbox_pts: List[dict0, dict1, ...], len = batch_size
# dict: {
# 'boxes_3d': (N, 9)
# 'scores_3d': (N, )
# 'labels_3d': (N, )
# }
for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
result_dict['pts_bbox'] = pts_bbox
return bbox_list
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
img_feats, _, _ = self.extract_feat(
points, img=img_inputs, img_metas=img_metas, **kwargs)
assert self.with_pts_bbox
outs = self.pts_bbox_head(img_feats)
return outs
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from mmdet3d.models import builder
from .bevdet4d import BEVDet4D
@DETECTORS.register_module()
class BEVDepth4D(BEVDet4D):
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
img (torch.Tensor optional): Images of each sample with shape
(N, C, H, W). Defaults to None.
proposals ([list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses = dict(loss_depth=loss_depth)
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_pts)
return losses
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from mmdet3d.models import CenterPoint
from mmdet3d.models import builder
@DETECTORS.register_module()
class BEVDet(CenterPoint):
def __init__(self, img_backbone, img_neck, img_view_transformer, img_bev_encoder_backbone, img_bev_encoder_neck,
pts_bbox_head=None, **kwargs):
super(BEVDet, self).__init__(img_backbone=img_backbone, img_neck=img_neck, pts_bbox_head=pts_bbox_head,
**kwargs)
self.img_view_transformer = builder.build_neck(img_view_transformer)
self.img_bev_encoder_backbone = builder.build_backbone(img_bev_encoder_backbone)
self.img_bev_encoder_neck = builder.build_neck(img_bev_encoder_neck)
def image_encoder(self, img, stereo=False):
"""
Args:
img: (B, N, 3, H, W)
stereo: bool
Returns:
x: (B, N, C, fH, fW)
stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo) / None
"""
imgs = img
B, N, C, imH, imW = imgs.shape
imgs = imgs.view(B * N, C, imH, imW)
#imgs = imgs.to(memory_format=torch.channels_last)
x = self.img_backbone(imgs)
stereo_feat = None
if stereo:
stereo_feat = x[0]
x = x[1:]
if self.with_img_neck:
x = self.img_neck(x)
if type(x) in [list, tuple]:
x = x[0]
        _, output_dim, output_H, output_W = x.shape
        x = x.view(B, N, output_dim, output_H, output_W)
return x, stereo_feat
@force_fp32()
def bev_encoder(self, x):
"""
Args:
x: (B, C, Dy, Dx)
Returns:
x: (B, C', 2*Dy, 2*Dx)
"""
x = self.img_bev_encoder_backbone(x)
x = self.img_bev_encoder_neck(x)
if type(x) in [list, tuple]:
x = x[0]
return x
def prepare_inputs(self, inputs):
# split the inputs into each frame
assert len(inputs) == 7
B, N, C, H, W = inputs[0].shape
imgs, sensor2egos, ego2globals, intrins, post_rots, post_trans, bda = \
inputs
sensor2egos = sensor2egos.view(B, N, 4, 4)
ego2globals = ego2globals.view(B, N, 4, 4)
# calculate the transformation from adj sensor to key ego
keyego2global = ego2globals[:, 0, ...].unsqueeze(1) # (B, 1, 4, 4)
global2keyego = torch.inverse(keyego2global.double()) # (B, 1, 4, 4)
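        # Chain of transforms per camera: sensor -> ego(frame t) -> global -> key ego,
        # i.e. T_sensor2keyego = T_global2keyego @ T_ego2global @ T_sensor2ego,
        # so all views (including adjacent frames) are expressed in the key frame's ego coordinates.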
sensor2keyegos = \
global2keyego @ ego2globals.double() @ sensor2egos.double() # (B, N_views, 4, 4)
sensor2keyegos = sensor2keyegos.float()
return [imgs, sensor2keyegos, ego2globals, intrins,
post_rots, post_trans, bda]
def extract_img_feat(self, img_inputs, img_metas, **kwargs):
""" Extract features of images.
img_inputs:
imgs: (B, N_views, 3, H, W)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
Returns:
x: [(B, C', H', W'), ]
depth: (B*N, D, fH, fW)
"""
img_inputs = self.prepare_inputs(img_inputs)
x, _ = self.image_encoder(img_inputs[0]) # x: (B, N, C, fH, fW)
x, depth = self.img_view_transformer([x] + img_inputs[1:7])
# x: (B, C, Dy, Dx)
# depth: (B*N, D, fH, fW)
x = self.bev_encoder(x)
return [x], depth
@torch.compile
def extract_feat(self, points, img_inputs, img_metas, **kwargs):
"""Extract features from images and points."""
"""
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_inputs:
imgs: (B, N_views, 3, H, W)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
"""
img_feats, depth = self.extract_img_feat(img_inputs, img_metas, **kwargs)
pts_feats = None
return img_feats, pts_feats, depth
def forward_train(self,
points=None,
img_inputs=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
img_metas=None,
gt_bboxes=None,
gt_labels=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_inputs:
imgs: (B, N_views, 3, H, W) # N_views = 6 * (N_history + 1)
sensor2egos: (B, N_views, 4, 4)
ego2globals: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
bda_rot: (B, 3, 3)
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
img_feats, pts_feats, _ = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
losses = dict()
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_pts)
return losses
def forward_test(self,
points=None,
img_inputs=None,
img_metas=None,
**kwargs):
"""
Args:
points (list[torch.Tensor]): the outer list indicates test-time
augmentations and inner torch.Tensor should have a shape NxC,
which contains all points in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch
img (list[torch.Tensor], optional): the outer
list indicates test-time augmentations and inner
torch.Tensor should have a shape NxCxHxW, which contains
all images in the batch. Defaults to None.
"""
for var, name in [(img_inputs, 'img_inputs'),
(img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
name, type(var)))
num_augs = len(img_inputs)
if num_augs != len(img_metas):
raise ValueError(
'num of augmentations ({}) != num of image meta ({})'.format(
len(img_inputs), len(img_metas)))
if not isinstance(img_inputs[0][0], list):
img_inputs = [img_inputs] if img_inputs is None else img_inputs
points = [points] if points is None else points
return self.simple_test(points[0], img_metas[0], img_inputs[0],
**kwargs)
else:
return self.aug_test(None, img_metas[0], img_inputs[0], **kwargs)
def aug_test(self, points, img_metas, img=None, rescale=False):
"""Test function without augmentaiton."""
assert False
def simple_test(self,
points,
img_metas,
img_inputs=None,
rescale=False,
**kwargs):
"""Test function without augmentaiton.
Returns:
bbox_list: List[dict0, dict1, ...] len = bs
dict: {
'pts_bbox': dict: {
'boxes_3d': (N, 9)
'scores_3d': (N, )
'labels_3d': (N, )
}
}
"""
img_feats, _, _ = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
bbox_list = [dict() for _ in range(len(img_metas))]
bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
# bbox_pts: List[dict0, dict1, ...], len = batch_size
# dict: {
# 'boxes_3d': (N, 9)
# 'scores_3d': (N, )
# 'labels_3d': (N, )
# }
for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
result_dict['pts_bbox'] = pts_bbox
return bbox_list
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
img_feats, _, _ = self.extract_feat(
points, img=img_inputs, img_metas=img_metas, **kwargs)
assert self.with_pts_bbox
outs = self.pts_bbox_head(img_feats)
return outs
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from mmdet3d.models import builder
from .bevdet import BEVDet
@DETECTORS.register_module()
class BEVDet4D(BEVDet):
r"""BEVDet4D paradigm for multi-camera 3D object detection.
Please refer to the `paper <https://arxiv.org/abs/2203.17054>`_
Args:
pre_process (dict | None): Configuration dict of BEV pre-process net.
        align_after_view_transfromation (bool): Whether to align the BEV
            feature of the previous frame to the key frame after the view
            transformation. By default (False), the previous BEV feature is
            aligned during the view transformation itself.
        num_adj (int): Number of adjacent (history) frames.
        with_prev (bool): Whether to use the BEV features of the previous
            frames. If False, they are replaced with all-zero features.
            Defaults to True.
"""
def __init__(self,
pre_process=None,
align_after_view_transfromation=False,
num_adj=1,
with_prev=True,
**kwargs):
super(BEVDet4D, self).__init__(**kwargs)
self.pre_process = pre_process is not None
if self.pre_process:
self.pre_process_net = builder.build_backbone(pre_process)
self.align_after_view_transfromation = align_after_view_transfromation
self.num_frame = num_adj + 1
self.with_prev = with_prev
self.grid = None
def gen_grid(self, input, sensor2keyegos, bda, bda_adj=None):
"""
Args:
input: (B, C, Dy, Dx) bev_feat
sensor2keyegos: List[
curr_sensor-->key_ego: (B, N_views, 4, 4)
prev_sensor-->key_ego: (B, N_views, 4, 4)
]
bda: (B, 3, 3)
bda_adj: None
Returns:
grid: (B, Dy, Dx, 2)
"""
B, C, H, W = input.shape
        v = sensor2keyegos[0].shape[0]  # batch size B (unused)
if self.grid is None:
# generate grid
xs = torch.linspace(
0, W - 1, W, dtype=input.dtype,
device=input.device).view(1, W).expand(H, W) # (Dy, Dx)
ys = torch.linspace(
0, H - 1, H, dtype=input.dtype,
device=input.device).view(H, 1).expand(H, W) # (Dy, Dx)
grid = torch.stack((xs, ys, torch.ones_like(xs)), -1) # (Dy, Dx, 3) 3: (x, y, 1)
self.grid = grid
else:
grid = self.grid
        # (Dy, Dx, 3) --> (1, Dy, Dx, 3) --> (B, Dy, Dx, 3) --> (B, Dy, Dx, 3, 1), 3: (grid_x, grid_y, 1)
grid = grid.view(1, H, W, 3).expand(B, H, W, 3).view(B, H, W, 3, 1)
curr_sensor2keyego = sensor2keyegos[0][:, 0:1, :, :] # (B, 1, 4, 4)
prev_sensor2keyego = sensor2keyegos[1][:, 0:1, :, :] # (B, 1, 4, 4)
# add bev data augmentation
bda_ = torch.zeros((B, 1, 4, 4), dtype=grid.dtype).to(grid) # (B, 1, 4, 4)
bda_[:, :, :3, :3] = bda.unsqueeze(1)
bda_[:, :, 3, 3] = 1
curr_sensor2keyego = bda_.matmul(curr_sensor2keyego) # (B, 1, 4, 4)
if bda_adj is not None:
bda_ = torch.zeros((B, 1, 4, 4), dtype=grid.dtype).to(grid)
bda_[:, :, :3, :3] = bda_adj.unsqueeze(1)
bda_[:, :, 3, 3] = 1
prev_sensor2keyego = bda_.matmul(prev_sensor2keyego) # (B, 1, 4, 4)
# transformation from current ego frame to adjacent ego frame
# key_ego --> prev_cam_front --> prev_ego
keyego2adjego = curr_sensor2keyego.matmul(torch.inverse(prev_sensor2keyego))
keyego2adjego = keyego2adjego.unsqueeze(dim=1) # (B, 1, 1, 4, 4)
# (B, 1, 1, 3, 3)
keyego2adjego = keyego2adjego[..., [True, True, False, True], :][..., [True, True, False, True]]
# x = grid_x * vx + x_min; y = grid_y * vy + y_min;
# feat2bev:
# [[vx, 0, x_min],
# [0, vy, y_min],
# [0, 0, 1 ]]
feat2bev = torch.zeros((3, 3), dtype=grid.dtype).to(grid)
feat2bev[0, 0] = self.img_view_transformer.grid_interval[0]
feat2bev[1, 1] = self.img_view_transformer.grid_interval[1]
feat2bev[0, 2] = self.img_view_transformer.grid_lower_bound[0]
feat2bev[1, 2] = self.img_view_transformer.grid_lower_bound[1]
feat2bev[2, 2] = 1
feat2bev = feat2bev.view(1, 3, 3) # (1, 3, 3)
# curr_feat_grid --> key ego --> prev_cam --> prev_ego --> prev_feat_grid
tf = torch.inverse(feat2bev).matmul(keyego2adjego).matmul(feat2bev) # (B, 1, 1, 3, 3)
grid = tf.matmul(grid) # (B, Dy, Dx, 3, 1) 3: (grid_x, grid_y, 1)
normalize_factor = torch.tensor([W - 1.0, H - 1.0],
dtype=input.dtype,
device=input.device) # (2, )
# (B, Dy, Dx, 2)
grid = grid[:, :, :, :2, 0] / normalize_factor.view(1, 1, 1, 2) * 2.0 - 1.0
return grid
@force_fp32()
def shift_feature(self, input, sensor2keyegos, bda, bda_adj=None):
"""
Args:
input: (B, C, Dy, Dx) bev_feat
sensor2keyegos: List[
curr_sensor-->key_ego: (B, N_views, 4, 4)
prev_sensor-->key_ego: (B, N_views, 4, 4)
]
bda: (B, 3, 3)
bda_adj: None
Returns:
output: aligned bev feat (B, C, Dy, Dx).
"""
        grid = self.gen_grid(input, sensor2keyegos, bda, bda_adj=bda_adj)  # grid: (B, Dy, Dx, 2), values in (-1, 1)
output = F.grid_sample(input, grid.to(input.dtype), align_corners=True) # (B, C, Dy, Dx)
return output
def prepare_bev_feat(self, img, sensor2egos, ego2globals, intrin, post_rot, post_tran,
bda, mlp_input):
"""
Args:
            img: (B, N_views, 3, H, W)
            sensor2egos: (B, N_views, 4, 4)
            ego2globals: (B, N_views, 4, 4)
            intrin: (B, N_views, 3, 3)
            post_rot: (B, N_views, 3, 3)
            post_tran: (B, N_views, 3)
            bda: (B, 3, 3)
            mlp_input: (B, N_views, 27)
Returns:
bev_feat: (B, C, Dy, Dx)
depth: (B*N, D, fH, fW)
"""
x, _ = self.image_encoder(img) # x: (B, N, C, fH, fW)
# bev_feat: (B, C * Dz(=1), Dy, Dx)
# depth: (B * N, D, fH, fW)
bev_feat, depth = self.img_view_transformer(
[x, sensor2egos, ego2globals, intrin, post_rot, post_tran, bda, mlp_input])
if self.pre_process:
bev_feat = self.pre_process_net(bev_feat)[0] # (B, C, Dy, Dx)
return bev_feat, depth
def extract_img_feat_sequential(self, inputs, feat_prev):
"""
Args:
inputs:
curr_img: (1, N_views, 3, H, W)
sensor2keyegos_curr: (N_prev, N_views, 4, 4)
ego2globals_curr: (N_prev, N_views, 4, 4)
intrins: (1, N_views, 3, 3)
sensor2keyegos_prev: (N_prev, N_views, 4, 4)
ego2globals_prev: (N_prev, N_views, 4, 4)
post_rots: (1, N_views, 3, 3)
post_trans: (1, N_views, 3, )
bda_curr: (N_prev, 3, 3)
feat_prev: (N_prev, C, Dy, Dx)
Returns:
"""
imgs, sensor2keyegos_curr, ego2globals_curr, intrins = inputs[:4]
sensor2keyegos_prev, _, post_rots, post_trans, bda = inputs[4:]
bev_feat_list = []
mlp_input = self.img_view_transformer.get_mlp_input(
sensor2keyegos_curr[0:1, ...], ego2globals_curr[0:1, ...],
intrins, post_rots, post_trans, bda[0:1, ...])
inputs_curr = (imgs, sensor2keyegos_curr[0:1, ...],
ego2globals_curr[0:1, ...], intrins, post_rots,
post_trans, bda[0:1, ...], mlp_input)
# (1, C, Dx, Dy), (1*N, D, fH, fW)
bev_feat, depth = self.prepare_bev_feat(*inputs_curr)
bev_feat_list.append(bev_feat)
# align the feat_prev
_, C, H, W = feat_prev.shape
# feat_prev: (N_prev, C, Dy, Dx)
feat_prev = \
self.shift_feature(feat_prev, # (N_prev, C, Dy, Dx)
[sensor2keyegos_curr, # (N_prev, N_views, 4, 4)
sensor2keyegos_prev], # (N_prev, N_views, 4, 4)
bda # (N_prev, 3, 3)
)
bev_feat_list.append(feat_prev.view(1, (self.num_frame - 1) * C, H, W)) # (1, N_prev*C, Dy, Dx)
bev_feat = torch.cat(bev_feat_list, dim=1) # (1, N_frames*C, Dy, Dx)
x = self.bev_encoder(bev_feat)
return [x], depth
def prepare_inputs(self, img_inputs, stereo=False):
"""
Args:
img_inputs:
imgs: (B, N, 3, H, W) # N = 6 * (N_history + 1)
sensor2egos: (B, N, 4, 4)
ego2globals: (B, N, 4, 4)
intrins: (B, N, 3, 3)
post_rots: (B, N, 3, 3)
post_trans: (B, N, 3)
bda_rot: (B, 3, 3)
stereo: bool
Returns:
imgs: List[(B, N_views, C, H, W), (B, N_views, C, H, W), ...] len = N_frames
sensor2keyegos: List[(B, N_views, 4, 4), (B, N_views, 4, 4), ...]
ego2globals: List[(B, N_views, 4, 4), (B, N_views, 4, 4), ...]
intrins: List[(B, N_views, 3, 3), (B, N_views, 3, 3), ...]
post_rots: List[(B, N_views, 3, 3), (B, N_views, 3, 3), ...]
post_trans: List[(B, N_views, 3), (B, N_views, 3), ...]
bda: (B, 3, 3)
"""
B, N, C, H, W = img_inputs[0].shape
N = N // self.num_frame # N_views = 6
imgs = img_inputs[0].view(B, N, self.num_frame, C, H, W) # (B, N_views, N_frames, C, H, W)
imgs = torch.split(imgs, 1, 2)
imgs = [t.squeeze(2) for t in imgs] # List[(B, N_views, C, H, W), (B, N_views, C, H, W), ...]
sensor2egos, ego2globals, intrins, post_rots, post_trans, bda = \
img_inputs[1:7]
sensor2egos = sensor2egos.view(B, self.num_frame, N, 4, 4)
ego2globals = ego2globals.view(B, self.num_frame, N, 4, 4)
# calculate the transformation from sensor to key ego
# key_ego --> global (B, 1, 1, 4, 4)
keyego2global = ego2globals[:, 0, 0, ...].unsqueeze(1).unsqueeze(1)
# global --> key_ego (B, 1, 1, 4, 4)
global2keyego = torch.inverse(keyego2global.double())
# sensor --> ego --> global --> key_ego
sensor2keyegos = \
global2keyego @ ego2globals.double() @ sensor2egos.double() # (B, N_frames, N_views, 4, 4)
sensor2keyegos = sensor2keyegos.float()
# -------------------- for stereo --------------------------
curr2adjsensor = None
if stereo:
# (B, N_frames, N_views, 4, 4), (B, N_frames, N_views, 4, 4)
sensor2egos_cv, ego2globals_cv = sensor2egos, ego2globals
sensor2egos_curr = \
sensor2egos_cv[:, :self.temporal_frame, ...].double() # (B, N_temporal=2, N_views, 4, 4)
ego2globals_curr = \
ego2globals_cv[:, :self.temporal_frame, ...].double() # (B, N_temporal=2, N_views, 4, 4)
sensor2egos_adj = \
sensor2egos_cv[:, 1:self.temporal_frame + 1, ...].double() # (B, N_temporal=2, N_views, 4, 4)
ego2globals_adj = \
ego2globals_cv[:, 1:self.temporal_frame + 1, ...].double() # (B, N_temporal=2, N_views, 4, 4)
# curr_sensor --> curr_ego --> global --> prev_ego --> prev_sensor
curr2adjsensor = \
torch.inverse(ego2globals_adj @ sensor2egos_adj) \
@ ego2globals_curr @ sensor2egos_curr # (B, N_temporal=2, N_views, 4, 4)
curr2adjsensor = curr2adjsensor.float() # (B, N_temporal=2, N_views, 4, 4)
curr2adjsensor = torch.split(curr2adjsensor, 1, 1)
curr2adjsensor = [p.squeeze(1) for p in curr2adjsensor]
curr2adjsensor.extend([None for _ in range(self.extra_ref_frames)])
# curr2adjsensor: List[(B, N_views, 4, 4), (B, N_views, 4, 4), None]
assert len(curr2adjsensor) == self.num_frame
# -------------------- for stereo --------------------------
extra = [
sensor2keyegos, # (B, N_frames, N_views, 4, 4)
ego2globals, # (B, N_frames, N_views, 4, 4)
intrins.view(B, self.num_frame, N, 3, 3), # (B, N_frames, N_views, 3, 3)
post_rots.view(B, self.num_frame, N, 3, 3), # (B, N_frames, N_views, 3, 3)
post_trans.view(B, self.num_frame, N, 3) # (B, N_frames, N_views, 3)
]
extra = [torch.split(t, 1, 1) for t in extra]
extra = [[p.squeeze(1) for p in t] for t in extra]
sensor2keyegos, ego2globals, intrins, post_rots, post_trans = extra
return imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, \
bda, curr2adjsensor
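    # Note on the pose chain computed above (illustrative comment only): for every
    # frame f and camera view v the method builds
    #     sensor2keyego[f, v] = inv(keyego2global) @ ego2global[f, v] @ sensor2ego[f, v]
    # i.e. camera --> its own ego --> global --> key-frame ego, so that all frames are
    # expressed in the key frame's ego coordinate system before view transformation.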
def extract_img_feat(self,
img_inputs,
img_metas,
pred_prev=False,
sequential=False,
**kwargs):
"""
Args:
img_inputs:
imgs: (B, N, 3, H, W) # N = 6 * (N_history + 1)
sensor2egos: (B, N, 4, 4)
ego2globals: (B, N, 4, 4)
intrins: (B, N, 3, 3)
post_rots: (B, N, 3, 3)
post_trans: (B, N, 3)
bda_rot: (B, 3, 3)
img_metas:
**kwargs:
Returns:
x: [(B, C', H', W'), ]
depth: (B*N_views, D, fH, fW)
"""
if sequential:
return self.extract_img_feat_sequential(img_inputs, kwargs['feat_prev'])
imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, \
bda, _ = self.prepare_inputs(img_inputs)
"""Extract features of images."""
bev_feat_list = []
depth_list = []
key_frame = True # back propagation for key frame only
for img, sensor2keyego, ego2global, intrin, post_rot, post_tran in zip(
imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans):
if key_frame or self.with_prev:
if self.align_after_view_transfromation:
sensor2keyego, ego2global = sensor2keyegos[0], ego2globals[0]
mlp_input = self.img_view_transformer.get_mlp_input(
sensor2keyegos[0], ego2globals[0], intrin, post_rot, post_tran, bda) # (B, N_views, 27)
inputs_curr = (img, sensor2keyego, ego2global, intrin, post_rot,
post_tran, bda, mlp_input)
if key_frame:
# bev_feat: (B, C, Dy, Dx)
# depth: (B*N_views, D, fH, fW)
bev_feat, depth = self.prepare_bev_feat(*inputs_curr)
else:
with torch.no_grad():
bev_feat, depth = self.prepare_bev_feat(*inputs_curr)
else:
# https://github.com/HuangJunJie2017/BEVDet/issues/275
bev_feat = torch.zeros_like(bev_feat_list[0])
depth = None
bev_feat_list.append(bev_feat)
depth_list.append(depth)
key_frame = False
# bev_feat_list: List[(B, C, Dy, Dx), (B, C, Dy, Dx), ...]
# depth_list: List[(B*N_views, D, fH, fW), (B*N_views, D, fH, fW), ...]
if pred_prev:
assert self.align_after_view_transfromation
assert sensor2keyegos[0].shape[0] == 1 # batch_size = 1
feat_prev = torch.cat(bev_feat_list[1:], dim=0)
# (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
ego2globals_curr = \
ego2globals[0].repeat(self.num_frame - 1, 1, 1, 1)
# (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
sensor2keyegos_curr = \
sensor2keyegos[0].repeat(self.num_frame - 1, 1, 1, 1)
ego2globals_prev = torch.cat(ego2globals[1:], dim=0) # (N_prev, N_views, 4, 4)
sensor2keyegos_prev = torch.cat(sensor2keyegos[1:], dim=0) # (N_prev, N_views, 4, 4)
bda_curr = bda.repeat(self.num_frame - 1, 1, 1) # (N_prev, 3, 3)
return feat_prev, [imgs[0], # (1, N_views, 3, H, W)
sensor2keyegos_curr, # (N_prev, N_views, 4, 4)
ego2globals_curr, # (N_prev, N_views, 4, 4)
intrins[0], # (1, N_views, 3, 3)
sensor2keyegos_prev, # (N_prev, N_views, 4, 4)
ego2globals_prev, # (N_prev, N_views, 4, 4)
post_rots[0], # (1, N_views, 3, 3)
post_trans[0], # (1, N_views, 3, )
bda_curr] # (N_prev, 3, 3)
if self.align_after_view_transfromation:
for adj_id in range(1, self.num_frame):
bev_feat_list[adj_id] = self.shift_feature(
bev_feat_list[adj_id], # (B, C, Dy, Dx)
[sensor2keyegos[0], # (B, N_views, 4, 4)
sensor2keyegos[adj_id] # (B, N_views, 4, 4)
],
bda # (B, 3, 3)
) # (B, C, Dy, Dx)
bev_feat = torch.cat(bev_feat_list, dim=1) # (B, N_frames*C, Dy, Dx)
x = self.bev_encoder(bev_feat)
return [x], depth_list[0]
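# ---------------------------------------------------------------------------
# Minimal illustrative sketch (not used by BEVDet4D): the core of
# gen_grid()/shift_feature() is a grid_sample-based warp of the previous BEV
# feature into the key frame. The 2x3 affine matrix below is a toy stand-in
# for the feat2bev / keyego2adjego chain computed in gen_grid().
def _demo_align_bev_feature(bev_feat, theta):
    """bev_feat: (B, C, Dy, Dx); theta: (B, 2, 3) normalized affine transform."""
    B, C, H, W = bev_feat.shape
    # Build a normalized sampling grid in [-1, 1], which is what gen_grid() returns.
    grid = F.affine_grid(theta, size=(B, C, H, W), align_corners=True)  # (B, Dy, Dx, 2)
    # Sample the previous-frame BEV feature at the warped locations.
    return F.grid_sample(bev_feat, grid, align_corners=True)  # (B, C, Dy, Dx)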
# Copyright (c) Phigent Robotics. All rights reserved.
from ...ops import TRTBEVPoolv2
from .bevdet import BEVDet
from .bevdepth import BEVDepth
from .bevdepth4d import BEVDepth4D
from .bevstereo4d import BEVStereo4D
from mmdet3d.models import DETECTORS
from mmdet3d.models.builder import build_head
import torch
import torch.nn.functional as F
from mmdet3d.core import bbox3d2result
import numpy as np
from multiprocessing.dummy import Pool as ThreadPool
from ...ops import nearest_assign
# pool = ThreadPool(processes=4)  # create a thread pool
# for pano
grid_config_occ = {
'x': [-40, 40, 0.4],
'y': [-40, 40, 0.4],
'z': [-1, 5.4, 6.4],
'depth': [1.0, 45.0, 1.0],
}
# det
det_class_name = ['car', 'truck', 'trailer', 'bus', 'construction_vehicle',
'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone',
'barrier']
# occ
occ_class_names = [
'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
'driveable_surface', 'other_flat', 'sidewalk',
'terrain', 'manmade', 'vegetation', 'free'
]
det_ind = [2, 3, 4, 5, 6, 7, 9, 10]
occ_ind = [5, 3, 0, 4, 6, 7, 2, 1]
detind2occind = {
0:4,
1:10,
2:9,
3:3,
4:5,
5:2,
6:6,
7:7,
8:8,
9:1,
}
occind2detind = {
4:0,
10:1,
9:2,
3:3,
5:4,
2:5,
6:6,
7:7,
8:8,
1:9,
}
occind2detind_cuda = [-1, -1, 5, 3, 0, 4, 6, 7, -1, 2, 1]
inst_occ = np.zeros([200, 200, 16])
X1, Y1, Z1 = 200, 200, 16
coords_x = torch.arange(X1).float()
coords_y = torch.arange(Y1).float()
coords_z = torch.arange(Z1).float()
coords = torch.stack(torch.meshgrid([coords_x, coords_y, coords_z])).permute(1, 2, 3, 0) # W, H, D, 3
# coords = coords.cpu().numpy()
st = [grid_config_occ['x'][0], grid_config_occ['y'][0], grid_config_occ['z'][0]]
sx = [grid_config_occ['x'][2], grid_config_occ['y'][2], 0.4]
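# Minimal sketch (hypothetical helper, not referenced elsewhere): converts metric
# instance-center coordinates (x, y, z) in the ego frame into integer voxel
# indices of the 200x200x16 occupancy grid, mirroring the
# `(inst_xyz - st) / sx` step used in the pano post-processing below.
def _demo_center_to_voxel_index(centers_xyz):
    """centers_xyz: (N, 3) metric coordinates -> (N, 3) integer voxel indices."""
    lower = torch.tensor(st, dtype=centers_xyz.dtype, device=centers_xyz.device)
    voxel_size = torch.tensor(sx, dtype=centers_xyz.dtype, device=centers_xyz.device)
    return ((centers_xyz - lower) / voxel_size).int()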
@DETECTORS.register_module()
class BEVDetOCC(BEVDet):
def __init__(self,
occ_head=None,
upsample=False,
**kwargs):
super(BEVDetOCC, self).__init__(**kwargs)
self.occ_head = build_head(occ_head)
self.pts_bbox_head = None
self.upsample = upsample
#@torch.compile
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
losses = dict()
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
loss_occ = self.forward_occ_train(occ_bev_feature, voxel_semantics, mask_camera)
losses.update(loss_occ)
return losses
def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
outs = self.occ_head(img_feats)
# assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
loss_occ = self.occ_head.loss(
outs, # (B, Dx, Dy, Dz, n_cls)
voxel_semantics, # (B, Dx, Dy, Dz)
mask_camera, # (B, Dx, Dy, Dz)
)
return loss_occ
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
occ_list = self.simple_test_occ(occ_bev_feature, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_list
def simple_test_occ(self, img_feats, img_metas=None):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
img_metas:
Returns:
occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
outs = self.occ_head(img_feats)
if not hasattr(self.occ_head, "get_occ_gpu"):
occ_preds = self.occ_head.get_occ(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
else:
occ_preds = self.occ_head.get_occ_gpu(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_preds
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs = self.occ_head(occ_bev_feature)
return outs
@DETECTORS.register_module()
class BEVDepthOCC(BEVDepth):
def __init__(self,
occ_head=None,
upsample=False,
**kwargs):
super(BEVDepthOCC, self).__init__(**kwargs)
self.occ_head = build_head(occ_head)
self.pts_bbox_head = None
self.upsample = upsample
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
losses = dict()
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses['loss_depth'] = loss_depth
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
loss_occ = self.forward_occ_train(occ_bev_feature, voxel_semantics, mask_camera)
losses.update(loss_occ)
return losses
def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
outs = self.occ_head(img_feats)
# assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
loss_occ = self.occ_head.loss(
outs, # (B, Dx, Dy, Dz, n_cls)
voxel_semantics, # (B, Dx, Dy, Dz)
mask_camera, # (B, Dx, Dy, Dz)
)
return loss_occ
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
occ_list = self.simple_test_occ(occ_bev_feature, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_list
def simple_test_occ(self, img_feats, img_metas=None):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
img_metas:
Returns:
occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
outs = self.occ_head(img_feats)
# occ_preds = self.occ_head.get_occ(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
occ_preds = self.occ_head.get_occ_gpu(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_preds
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs = self.occ_head(occ_bev_feature)
return outs
@DETECTORS.register_module()
class BEVDepthPano(BEVDepthOCC):
def __init__(self,
aux_centerness_head=None,
**kwargs):
super(BEVDepthPano, self).__init__(**kwargs)
self.aux_centerness_head = None
if aux_centerness_head:
train_cfg = kwargs['train_cfg']
test_cfg = kwargs['test_cfg']
pts_train_cfg = train_cfg.pts if train_cfg else None
aux_centerness_head.update(train_cfg=pts_train_cfg)
pts_test_cfg = test_cfg.pts if test_cfg else None
aux_centerness_head.update(test_cfg=pts_test_cfg)
self.aux_centerness_head = build_head(aux_centerness_head)
if 'inst_class_ids' in kwargs:
self.inst_class_ids = kwargs['inst_class_ids']
else:
self.inst_class_ids = [2, 3, 4, 5, 6, 7, 9, 10]
X1, Y1, Z1 = 200, 200, 16
coords_x = torch.arange(X1).float()
coords_y = torch.arange(Y1).float()
coords_z = torch.arange(Z1).float()
self.coords = torch.stack(torch.meshgrid([coords_x, coords_y, coords_z])).permute(1, 2, 3, 0) # W, H, D, 3
self.st = torch.tensor([grid_config_occ['x'][0], grid_config_occ['y'][0], grid_config_occ['z'][0]])
self.sx = torch.tensor([grid_config_occ['x'][2], grid_config_occ['y'][2], 0.4])
self.is_to_d = False
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
losses = dict()
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses['loss_depth'] = loss_depth
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
loss_occ = self.forward_occ_train(occ_bev_feature, voxel_semantics, mask_camera)
losses.update(loss_occ)
losses_aux_centerness = self.forward_aux_centerness_train([occ_bev_feature], gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_aux_centerness)
return losses
def forward_aux_centerness_train(self,
pts_feats,
gt_bboxes_3d,
gt_labels_3d,
img_metas,
gt_bboxes_ignore=None):
outs = self.aux_centerness_head(pts_feats)
loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
losses = self.aux_centerness_head.loss(*loss_inputs)
return losses
def simple_test_aux_centerness(self, x, img_metas, rescale=False, **kwargs):
"""Test function of point cloud branch."""
# outs = self.aux_centerness_head(x)
tx = self.aux_centerness_head.shared_conv(x[0]) # (B, C'=share_conv_channel, H, W)
outs_inst_center_reg = self.aux_centerness_head.task_heads[0].reg(tx)
outs_inst_center_height = self.aux_centerness_head.task_heads[0].height(tx)
outs_inst_center_heatmap = self.aux_centerness_head.task_heads[0].heatmap(tx)
outs = ([{
"reg" : outs_inst_center_reg,
"height" : outs_inst_center_height,
"heatmap" : outs_inst_center_heatmap,
}],)
# # bbox_list = self.aux_centerness_head.get_bboxes(
# # outs, img_metas, rescale=rescale)
# # bbox_results = [
# # bbox3d2result(bboxes, scores, labels)
# # for bboxes, scores, labels in bbox_list
# # ]
ins_cen_list = self.aux_centerness_head.get_centers(
outs, img_metas, rescale=rescale)
# return bbox_results, ins_cen_list
return None, ins_cen_list
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
result_list = [dict() for _ in range(len(img_metas))]
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
w_pano = kwargs['w_pano'] if 'w_pano' in kwargs else True
if w_pano == True:
bbox_pts, ins_cen_list = self.simple_test_aux_centerness([occ_bev_feature], img_metas, rescale=rescale, **kwargs)
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
occ_list = self.simple_test_occ(occ_bev_feature, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
for result_dict, occ_pred in zip(result_list, occ_list):
result_dict['pred_occ'] = occ_pred
w_panoproc = kwargs['w_panoproc'] if 'w_panoproc' in kwargs else True # 37.53 ms
if w_panoproc == True:
# # for pano
inst_xyz = ins_cen_list[0][0]
if self.is_to_d == False:
self.st = self.st.to(inst_xyz)
self.sx = self.sx.to(inst_xyz)
self.coords = self.coords.to(inst_xyz)
self.is_to_d = True
inst_xyz = ((inst_xyz - self.st) / self.sx).int()
inst_cls = ins_cen_list[2][0].int()
inst_num = 18 # 37.62 ms
inst_occ = occ_pred.clone().detach() # 37.61 ms
if len(inst_cls) > 0:
cls_sort, indices = inst_cls.sort()
l2s = {}
                # map each present detection class to the index of its first occurrence
                # in the sorted class list
                l2s[cls_sort[0].item()] = 0
                tind_list = (cls_sort[1:] - cls_sort[:-1]) != 0
                if len(tind_list) > 0:
                    for tind in torch.nonzero(tind_list).flatten():
                        l2s[cls_sort[1 + int(tind.item())].item()] = int(tind.item()) + 1
is_cuda = True
# is_cuda = False
if is_cuda == True:
inst_id_list = indices + inst_num
l2s_key = indices.new_tensor([detind2occind[k] for k in l2s.keys()]).to(torch.int)
inst_occ = nearest_assign(
occ_pred.to(torch.int),
l2s_key.to(torch.int),
indices.new_tensor(occind2detind_cuda).to(torch.int),
inst_cls.to(torch.int),
inst_xyz.to(torch.int),
inst_id_list.to(torch.int)
)
else:
for cls_label_num_in_occ in self.inst_class_ids:
mask = occ_pred == cls_label_num_in_occ
if mask.sum() == 0:
continue
else:
cls_label_num_in_inst = occind2detind[cls_label_num_in_occ]
select_mask = inst_cls==cls_label_num_in_inst
if sum(select_mask) > 0:
indices = self.coords[mask]
inst_index_same_cls = inst_xyz[select_mask]
select_ind = ((indices[:,None,:] - inst_index_same_cls[None,:,:])**2).sum(-1).argmin(axis=1).int()
inst_occ[mask] = select_ind + inst_num + l2s[cls_label_num_in_inst]
result_list[0]['pano_inst'] = inst_occ #.cpu().numpy()
return result_list
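# ---------------------------------------------------------------------------
# Reference sketch (simplified, hypothetical `_demo_` names) of the CPU fallback
# above (`is_cuda == False`): every occupied voxel of a "thing" class gets the id
# of the nearest predicted instance center of the same class. It ignores the
# per-class id offset (`l2s`) kept by the full implementation.
def _demo_nearest_center_assign(occ_pred, coords_grid, inst_xyz, inst_cls, cls_id, inst_num=18):
    """
    occ_pred: (Dx, Dy, Dz) semantic labels, coords_grid: (Dx, Dy, Dz, 3) voxel indices,
    inst_xyz: (K, 3) voxel indices of centers, inst_cls: (K,) detection class ids,
    cls_id (int): occupancy id of a thing class (a key of occind2detind).
    """
    inst_occ = occ_pred.clone()
    mask = occ_pred == cls_id                      # voxels of this class
    select = inst_cls == occind2detind[cls_id]     # centers of this class
    if mask.any() and select.any():
        vox = coords_grid[mask]                    # (M, 3)
        centers = inst_xyz[select].to(vox.dtype)   # (S, 3)
        dist2 = ((vox[:, None, :] - centers[None, :, :]) ** 2).sum(-1)  # (M, S)
        inst_occ[mask] = (dist2.argmin(dim=1) + inst_num).to(inst_occ.dtype)
    return inst_occ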
@DETECTORS.register_module()
class BEVDepth4DOCC(BEVDepth4D):
def __init__(self,
occ_head=None,
upsample=False,
**kwargs):
super(BEVDepth4DOCC, self).__init__(**kwargs)
self.occ_head = build_head(occ_head)
self.pts_bbox_head = None
self.upsample = upsample
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
losses = dict()
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses['loss_depth'] = loss_depth
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
loss_occ = self.forward_occ_train(img_feats[0], voxel_semantics, mask_camera)
losses.update(loss_occ)
return losses
def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
outs = self.occ_head(img_feats)
assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
loss_occ = self.occ_head.loss(
outs, # (B, Dx, Dy, Dz, n_cls)
voxel_semantics, # (B, Dx, Dy, Dz)
mask_camera, # (B, Dx, Dy, Dz)
)
return loss_occ
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_list = self.simple_test_occ(img_feats[0], img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_list
def simple_test_occ(self, img_feats, img_metas=None):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
img_metas:
Returns:
occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
outs = self.occ_head(img_feats)
# occ_preds = self.occ_head.get_occ(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
occ_preds = self.occ_head.get_occ_gpu(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_preds
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs = self.occ_head(occ_bev_feature)
return outs
@DETECTORS.register_module()
class BEVDepth4DPano(BEVDepth4DOCC):
def __init__(self,
aux_centerness_head=None,
**kwargs):
super(BEVDepth4DPano, self).__init__(**kwargs)
self.aux_centerness_head = None
if aux_centerness_head:
train_cfg = kwargs['train_cfg']
test_cfg = kwargs['test_cfg']
pts_train_cfg = train_cfg.pts if train_cfg else None
aux_centerness_head.update(train_cfg=pts_train_cfg)
pts_test_cfg = test_cfg.pts if test_cfg else None
aux_centerness_head.update(test_cfg=pts_test_cfg)
self.aux_centerness_head = build_head(aux_centerness_head)
if 'inst_class_ids' in kwargs:
self.inst_class_ids = kwargs['inst_class_ids']
else:
self.inst_class_ids = [2, 3, 4, 5, 6, 7, 9, 10]
X1, Y1, Z1 = 200, 200, 16
coords_x = torch.arange(X1).float()
coords_y = torch.arange(Y1).float()
coords_z = torch.arange(Z1).float()
self.coords = torch.stack(torch.meshgrid([coords_x, coords_y, coords_z])).permute(1, 2, 3, 0) # W, H, D, 3
self.st = torch.tensor([grid_config_occ['x'][0], grid_config_occ['y'][0], grid_config_occ['z'][0]])
self.sx = torch.tensor([grid_config_occ['x'][2], grid_config_occ['y'][2], 0.4])
self.is_to_d = False
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
losses = dict()
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses['loss_depth'] = loss_depth
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
loss_occ = self.forward_occ_train(img_feats[0], voxel_semantics, mask_camera)
losses.update(loss_occ)
losses_aux_centerness = self.forward_aux_centerness_train([img_feats[0]], gt_bboxes_3d,
gt_labels_3d, img_metas,
gt_bboxes_ignore)
losses.update(losses_aux_centerness)
return losses
def forward_aux_centerness_train(self,
pts_feats,
gt_bboxes_3d,
gt_labels_3d,
img_metas,
gt_bboxes_ignore=None):
outs = self.aux_centerness_head(pts_feats)
loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
losses = self.aux_centerness_head.loss(*loss_inputs)
return losses
def simple_test_aux_centerness(self, x, img_metas, rescale=False, **kwargs):
"""Test function of point cloud branch."""
outs = self.aux_centerness_head(x)
bbox_list = self.aux_centerness_head.get_bboxes(
outs, img_metas, rescale=rescale)
bbox_results = [
bbox3d2result(bboxes, scores, labels)
for bboxes, scores, labels in bbox_list
]
ins_cen_list = self.aux_centerness_head.get_centers(
outs, img_metas, rescale=rescale)
return bbox_results, ins_cen_list
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
result_list = [dict() for _ in range(len(img_metas))]
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
w_pano = kwargs['w_pano'] if 'w_pano' in kwargs else True
if w_pano == True:
bbox_pts, ins_cen_list = self.simple_test_aux_centerness([occ_bev_feature], img_metas, rescale=rescale, **kwargs)
occ_list = self.simple_test_occ(occ_bev_feature, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
for result_dict, occ_pred in zip(result_list, occ_list):
result_dict['pred_occ'] = occ_pred
w_panoproc = kwargs['w_panoproc'] if 'w_panoproc' in kwargs else True
if w_panoproc == True:
# # for pano
inst_xyz = ins_cen_list[0][0]
if self.is_to_d == False:
self.st = self.st.to(inst_xyz)
self.sx = self.sx.to(inst_xyz)
self.coords = self.coords.to(inst_xyz)
self.is_to_d = True
inst_xyz = ((inst_xyz - self.st) / self.sx).int()
inst_cls = ins_cen_list[2][0].int()
inst_num = 18 # 37.62 ms
inst_occ = occ_pred.clone().detach() # 37.61 ms
if len(inst_cls) > 0:
cls_sort, indices = inst_cls.sort()
l2s = {}
                # map each present detection class to the index of its first occurrence
                # in the sorted class list
                l2s[cls_sort[0].item()] = 0
                tind_list = (cls_sort[1:] - cls_sort[:-1]) != 0
                if len(tind_list) > 0:
                    for tind in torch.nonzero(tind_list).flatten():
                        l2s[cls_sort[1 + int(tind.item())].item()] = int(tind.item()) + 1
is_cuda = True
# is_cuda = False
if is_cuda == True:
inst_id_list = indices + inst_num
l2s_key = indices.new_tensor([detind2occind[k] for k in l2s.keys()]).to(torch.int)
inst_occ = nearest_assign(
occ_pred.to(torch.int),
l2s_key.to(torch.int),
indices.new_tensor(occind2detind_cuda).to(torch.int),
inst_cls.to(torch.int),
inst_xyz.to(torch.int),
inst_id_list.to(torch.int)
)
else:
for cls_label_num_in_occ in self.inst_class_ids:
mask = occ_pred == cls_label_num_in_occ
if mask.sum() == 0:
continue
else:
cls_label_num_in_inst = occind2detind[cls_label_num_in_occ]
select_mask = inst_cls==cls_label_num_in_inst
if sum(select_mask) > 0:
indices = self.coords[mask]
inst_index_same_cls = inst_xyz[select_mask]
select_ind = ((indices[:,None,:] - inst_index_same_cls[None,:,:])**2).sum(-1).argmin(axis=1).int()
inst_occ[mask] = select_ind + inst_num + l2s[cls_label_num_in_inst]
result_list[0]['pano_inst'] = inst_occ #.cpu().numpy()
return result_list
@DETECTORS.register_module()
class BEVStereo4DOCC(BEVStereo4D):
def __init__(self,
occ_head=None,
upsample=False,
**kwargs):
super(BEVStereo4DOCC, self).__init__(**kwargs)
self.occ_head = build_head(occ_head)
self.pts_bbox_head = None
self.upsample = upsample
def forward_train(self,
points=None,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
gt_labels=None,
gt_bboxes=None,
img_inputs=None,
proposals=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
Defaults to None.
img_metas (list[dict], optional): Meta information of each sample.
Defaults to None.
gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
Ground truth 3D boxes. Defaults to None.
gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
of 3D boxes. Defaults to None.
gt_labels (list[torch.Tensor], optional): Ground truth labels
of 2D boxes in images. Defaults to None.
gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
images. Defaults to None.
            img (torch.Tensor, optional): Images of each sample with shape
                (N, C, H, W). Defaults to None.
            proposals (list[torch.Tensor], optional): Predicted proposals
used for training Fast RCNN. Defaults to None.
gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
2D boxes in images to be ignored. Defaults to None.
Returns:
dict: Losses of different branches.
"""
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
gt_depth = kwargs['gt_depth'] # (B, N_views, img_H, img_W)
losses = dict()
loss_depth = self.img_view_transformer.get_depth_loss(gt_depth, depth)
losses['loss_depth'] = loss_depth
voxel_semantics = kwargs['voxel_semantics'] # (B, Dx, Dy, Dz)
mask_camera = kwargs['mask_camera'] # (B, Dx, Dy, Dz)
loss_occ = self.forward_occ_train(img_feats[0], voxel_semantics, mask_camera)
losses.update(loss_occ)
return losses
def forward_occ_train(self, img_feats, voxel_semantics, mask_camera):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
outs = self.occ_head(img_feats)
assert voxel_semantics.min() >= 0 and voxel_semantics.max() <= 17
loss_occ = self.occ_head.loss(
outs, # (B, Dx, Dy, Dz, n_cls)
voxel_semantics, # (B, Dx, Dy, Dz)
mask_camera, # (B, Dx, Dy, Dz)
)
return loss_occ
def simple_test(self,
points,
img_metas,
img=None,
rescale=False,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, _, _ = self.extract_feat(
points, img_inputs=img, img_metas=img_metas, **kwargs)
occ_list = self.simple_test_occ(img_feats[0], img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_list
def simple_test_occ(self, img_feats, img_metas=None):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx) / (B, C, Dy, Dx)
img_metas:
Returns:
occ_preds: List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
outs = self.occ_head(img_feats)
# occ_preds = self.occ_head.get_occ(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
occ_preds = self.occ_head.get_occ_gpu(outs, img_metas) # List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
return occ_preds
def forward_dummy(self,
points=None,
img_metas=None,
img_inputs=None,
**kwargs):
# img_feats: List[(B, C, Dz, Dy, Dx)/(B, C, Dy, Dx) , ]
# pts_feats: None
# depth: (B*N_views, D, fH, fW)
img_feats, pts_feats, depth = self.extract_feat(
points, img_inputs=img_inputs, img_metas=img_metas, **kwargs)
occ_bev_feature = img_feats[0]
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs = self.occ_head(occ_bev_feature)
return outs
@DETECTORS.register_module()
class BEVDetOCCTRT(BEVDetOCC):
def __init__(self,
wocc=True,
wdet3d=True,
uni_train=True,
**kwargs):
super(BEVDetOCCTRT, self).__init__(**kwargs)
self.wocc = wocc
self.wdet3d = wdet3d
self.uni_train = uni_train
def result_serialize(self, outs_det3d=None, outs_occ=None):
outs_ = []
if outs_det3d is not None:
for out in outs_det3d:
for key in ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']:
outs_.append(out[0][key])
if outs_occ is not None:
outs_.append(outs_occ)
return outs_
def result_deserialize(self, outs):
outs_ = []
keys = ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']
for head_id in range(len(outs) // 6):
outs_head = [dict()]
for kid, key in enumerate(keys):
outs_head[0][key] = outs[head_id * 6 + kid]
outs_.append(outs_head)
return outs_
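    # Layout produced by result_serialize(): 6 tensors per detection task head in the
    # fixed order [reg, height, dim, rot, vel, heatmap], optionally followed by the
    # occupancy logits; result_deserialize() above inverts the detection part.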
def forward_part1(
self,
img,
):
x = self.img_backbone(img)
x = self.img_neck(x)
x = self.img_view_transformer.depth_net(x[0])
depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
tran_feat = x[:, self.img_view_transformer.D:(
self.img_view_transformer.D +
self.img_view_transformer.out_channels)]
tran_feat = tran_feat.permute(0, 2, 3, 1)
# depth = depth.reshape(-1)
# tran_feat = tran_feat.flatten(0,2)
return tran_feat.flatten(0,2), depth.reshape(-1)
def forward_part2(
self,
tran_feat,
depth,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
):
tran_feat = tran_feat.reshape(6, 16, 44, 64)
depth = depth.reshape(6, 16, 44, 44)
x = TRTBEVPoolv2.apply(depth.contiguous(), tran_feat.contiguous(),
ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths,
int(self.img_view_transformer.grid_size[0].item()),
int(self.img_view_transformer.grid_size[1].item()),
int(self.img_view_transformer.grid_size[2].item())
) # -> [1, 64, 200, 200]
return x.reshape(-1)
def forward_part3(
self,
x
):
x = x.reshape(1, 200, 200, 64)
x = x.permute(0, 3, 1, 2).contiguous()
# return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
bev_feature = self.img_bev_encoder_backbone(x)
occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
outs_occ = None
if self.wocc == True:
if self.uni_train == True:
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs_occ = self.occ_head(occ_bev_feature)
outs_det3d = None
if self.wdet3d == True:
outs_det3d = self.pts_bbox_head([occ_bev_feature])
outs = self.result_serialize(outs_det3d, outs_occ)
return outs
def forward_ori(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
):
x = self.img_backbone(img)
x = self.img_neck(x)
x = self.img_view_transformer.depth_net(x[0])
depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
tran_feat = x[:, self.img_view_transformer.D:(
self.img_view_transformer.D +
self.img_view_transformer.out_channels)]
tran_feat = tran_feat.permute(0, 2, 3, 1)
x = TRTBEVPoolv2.apply(depth.contiguous(), tran_feat.contiguous(),
ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths,
int(self.img_view_transformer.grid_size[0].item()),
int(self.img_view_transformer.grid_size[1].item()),
int(self.img_view_transformer.grid_size[2].item())
)
x = x.permute(0, 3, 1, 2).contiguous()
# return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
bev_feature = self.img_bev_encoder_backbone(x)
occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
outs_occ = None
if self.wocc == True:
if self.uni_train == True:
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs_occ = self.occ_head(occ_bev_feature)
outs_det3d = None
if self.wdet3d == True:
outs_det3d = self.pts_bbox_head([occ_bev_feature])
outs = self.result_serialize(outs_det3d, outs_occ)
return outs
def forward_with_argmax(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
):
outs = self.forward_ori(
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
)
pred_occ_label = outs[0].argmax(-1)
return pred_occ_label
def get_bev_pool_input(self, input):
input = self.prepare_inputs(input)
coor = self.img_view_transformer.get_lidar_coor(*input[1:7])
return self.img_view_transformer.voxel_pooling_prepare_v2(coor)
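# ---------------------------------------------------------------------------
# Illustrative sketch (assumed usage, not part of any deployment script):
# forward_part1/2/3 split BEVDetOCCTRT into image encoding, BEV pooling and the
# BEV heads so each stage can be exported or profiled separately; chained in
# order they reproduce forward_ori().
def _demo_run_bevdetocc_trt_parts(model, img, ranks_depth, ranks_feat, ranks_bev,
                                  interval_starts, interval_lengths):
    tran_feat, depth = model.forward_part1(img)                      # image encoder + depth net
    bev = model.forward_part2(tran_feat, depth, ranks_depth, ranks_feat,
                              ranks_bev, interval_starts, interval_lengths)  # BEV pooling
    return model.forward_part3(bev)                                  # BEV encoder + heads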
@DETECTORS.register_module()
class BEVDepthOCCTRT(BEVDetOCC):
def __init__(self,
wocc=True,
wdet3d=True,
uni_train=True,
**kwargs):
super(BEVDepthOCCTRT, self).__init__(**kwargs)
self.wocc = wocc
self.wdet3d = wdet3d
self.uni_train = uni_train
def result_serialize(self, outs_det3d=None, outs_occ=None):
outs_ = []
if outs_det3d is not None:
for out in outs_det3d:
for key in ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']:
outs_.append(out[0][key])
if outs_occ is not None:
outs_.append(outs_occ)
return outs_
def result_deserialize(self, outs):
outs_ = []
keys = ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']
for head_id in range(len(outs) // 6):
outs_head = [dict()]
for kid, key in enumerate(keys):
outs_head[0][key] = outs[head_id * 6 + kid]
outs_.append(outs_head)
return outs_
def forward_ori(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
):
x = self.img_backbone(img)
x = self.img_neck(x)
x = self.img_view_transformer.depth_net(x[0], mlp_input)
depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
tran_feat = x[:, self.img_view_transformer.D:(
self.img_view_transformer.D +
self.img_view_transformer.out_channels)]
tran_feat = tran_feat.permute(0, 2, 3, 1)
x = TRTBEVPoolv2.apply(depth.contiguous(), tran_feat.contiguous(),
ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths,
int(self.img_view_transformer.grid_size[0].item()),
int(self.img_view_transformer.grid_size[1].item()),
int(self.img_view_transformer.grid_size[2].item())
)
x = x.permute(0, 3, 1, 2).contiguous()
# return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
bev_feature = self.img_bev_encoder_backbone(x)
occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
outs_occ = None
if self.wocc == True:
if self.uni_train == True:
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs_occ = self.occ_head(occ_bev_feature)
outs_det3d = None
if self.wdet3d == True:
outs_det3d = self.pts_bbox_head([occ_bev_feature])
outs = self.result_serialize(outs_det3d, outs_occ)
return outs
def forward_with_argmax(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
):
outs = self.forward_ori(
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
)
pred_occ_label = outs[0].argmax(-1)
return pred_occ_label
def get_bev_pool_input(self, input):
input = self.prepare_inputs(input)
coor = self.img_view_transformer.get_lidar_coor(*input[1:7])
mlp_input = self.img_view_transformer.get_mlp_input(*input[1:7])
# sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda) # (B, N_views, 27)
return self.img_view_transformer.voxel_pooling_prepare_v2(coor), mlp_input
@DETECTORS.register_module()
class BEVDepthPanoTRT(BEVDepthPano):
def __init__(self,
wocc=True,
wdet3d=True,
uni_train=True,
**kwargs):
super(BEVDepthPanoTRT, self).__init__(**kwargs)
self.wocc = wocc
self.wdet3d = wdet3d
self.uni_train = uni_train
def result_serialize(self, outs_det3d=None, outs_occ=None):
outs_ = []
if outs_det3d is not None:
for out in outs_det3d:
for key in ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']:
outs_.append(out[0][key])
if outs_occ is not None:
outs_.append(outs_occ)
return outs_
def result_deserialize(self, outs):
outs_ = []
keys = ['reg', 'height', 'dim', 'rot', 'vel', 'heatmap']
for head_id in range(len(outs) // 6):
outs_head = [dict()]
for kid, key in enumerate(keys):
outs_head[0][key] = outs[head_id * 6 + kid]
outs_.append(outs_head)
return outs_
def forward_part1(
self,
img,
mlp_input,
):
x = self.img_backbone(img)
x = self.img_neck(x)
x = self.img_view_transformer.depth_net(x[0], mlp_input)
depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
tran_feat = x[:, self.img_view_transformer.D:(
self.img_view_transformer.D +
self.img_view_transformer.out_channels)]
tran_feat = tran_feat.permute(0, 2, 3, 1)
# depth = depth.reshape(-1)
# tran_feat = tran_feat.flatten(0,2)
return tran_feat.flatten(0,2), depth.reshape(-1)
def forward_part2(
self,
tran_feat,
depth,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
):
tran_feat = tran_feat.reshape(6, 16, 44, 64)
depth = depth.reshape(6, 16, 44, 44)
x = TRTBEVPoolv2.apply(depth.contiguous(), tran_feat.contiguous(),
ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths,
int(self.img_view_transformer.grid_size[0].item()),
int(self.img_view_transformer.grid_size[1].item()),
int(self.img_view_transformer.grid_size[2].item())
) # -> [1, 64, 200, 200]
return x.reshape(-1)
def forward_part3(
self,
x
):
x = x.reshape(1, 200, 200, 64)
x = x.permute(0, 3, 1, 2).contiguous()
# return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
bev_feature = self.img_bev_encoder_backbone(x)
occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
outs_occ = None
if self.wocc == True:
if self.uni_train == True:
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs_occ = self.occ_head(occ_bev_feature)
outs_det3d = None
if self.wdet3d == True:
outs_det3d = self.pts_bbox_head([occ_bev_feature])
outs = self.result_serialize(outs_det3d, outs_occ)
# outs_inst_center = self.aux_centerness_head([occ_bev_feature])
x = self.aux_centerness_head.shared_conv(occ_bev_feature) # (B, C'=share_conv_channel, H, W)
        # run each task head separately
outs_inst_center_reg = self.aux_centerness_head.task_heads[0].reg(x)
outs.append(outs_inst_center_reg)
outs_inst_center_height = self.aux_centerness_head.task_heads[0].height(x)
outs.append(outs_inst_center_height)
outs_inst_center_heatmap = self.aux_centerness_head.task_heads[0].heatmap(x)
        outs.append(outs_inst_center_heatmap)
        return outs
def forward_ori(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
):
x = self.img_backbone(img)
x = self.img_neck(x)
x = self.img_view_transformer.depth_net(x[0], mlp_input)
depth = x[:, :self.img_view_transformer.D].softmax(dim=1)
tran_feat = x[:, self.img_view_transformer.D:(
self.img_view_transformer.D +
self.img_view_transformer.out_channels)]
tran_feat = tran_feat.permute(0, 2, 3, 1)
x = TRTBEVPoolv2.apply(depth.contiguous(), tran_feat.contiguous(),
ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths,
int(self.img_view_transformer.grid_size[0].item()),
int(self.img_view_transformer.grid_size[1].item()),
int(self.img_view_transformer.grid_size[2].item())
)
x = x.permute(0, 3, 1, 2).contiguous()
# return [x, 2*x, 3*x, 4*x, 5*x, 6*x, 7*x]
bev_feature = self.img_bev_encoder_backbone(x)
occ_bev_feature = self.img_bev_encoder_neck(bev_feature)
outs_occ = None
if self.wocc == True:
if self.uni_train == True:
if self.upsample:
occ_bev_feature = F.interpolate(occ_bev_feature, scale_factor=2,
mode='bilinear', align_corners=True)
outs_occ = self.occ_head(occ_bev_feature)
outs_det3d = None
if self.wdet3d == True:
outs_det3d = self.pts_bbox_head([occ_bev_feature])
outs = self.result_serialize(outs_det3d, outs_occ)
# outs_inst_center = self.aux_centerness_head([occ_bev_feature])
x = self.aux_centerness_head.shared_conv(occ_bev_feature) # (B, C'=share_conv_channel, H, W)
        # run each task head separately
outs_inst_center_reg = self.aux_centerness_head.task_heads[0].reg(x)
outs.append(outs_inst_center_reg)
outs_inst_center_height = self.aux_centerness_head.task_heads[0].height(x)
outs.append(outs_inst_center_height)
outs_inst_center_heatmap = self.aux_centerness_head.task_heads[0].heatmap(x)
outs.append(outs_inst_center_heatmap)
return outs
def forward_with_argmax(
self,
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
):
outs = self.forward_ori(
img,
ranks_depth,
ranks_feat,
ranks_bev,
interval_starts,
interval_lengths,
mlp_input,
)
pred_occ_label = outs[0].argmax(-1)
return pred_occ_label, *outs[1:]
def get_bev_pool_input(self, input):
input = self.prepare_inputs(input)
coor = self.img_view_transformer.get_lidar_coor(*input[1:7])
mlp_input = self.img_view_transformer.get_mlp_input(*input[1:7])
# sensor2keyegos, ego2globals, intrins, post_rots, post_trans, bda) # (B, N_views, 27)
return self.img_view_transformer.voxel_pooling_prepare_v2(coor), mlp_input
# Copyright (c) Phigent Robotics. All rights reserved.
import torch
import torch.nn.functional as F
from mmcv.runner import force_fp32
from mmdet3d.models import DETECTORS
from mmdet3d.models import builder
from .bevdepth4d import BEVDepth4D
from mmdet.models.backbones.resnet import ResNet
@DETECTORS.register_module()
class BEVStereo4D(BEVDepth4D):
def __init__(self, **kwargs):
super(BEVStereo4D, self).__init__(**kwargs)
self.extra_ref_frames = 1
self.temporal_frame = self.num_frame
self.num_frame += self.extra_ref_frames
def extract_stereo_ref_feat(self, x):
"""
Args:
x: (B, N_views, 3, H, W)
Returns:
x: (B*N_views, C_stereo, fH_stereo, fW_stereo)
"""
B, N, C, imH, imW = x.shape
x = x.view(B * N, C, imH, imW) # (B*N_views, 3, H, W)
if isinstance(self.img_backbone, ResNet):
if self.img_backbone.deep_stem:
x = self.img_backbone.stem(x)
else:
x = self.img_backbone.conv1(x)
x = self.img_backbone.norm1(x)
x = self.img_backbone.relu(x)
x = self.img_backbone.maxpool(x)
for i, layer_name in enumerate(self.img_backbone.res_layers):
res_layer = getattr(self.img_backbone, layer_name)
x = res_layer(x)
return x
else:
x = self.img_backbone.patch_embed(x)
hw_shape = (self.img_backbone.patch_embed.DH,
self.img_backbone.patch_embed.DW)
if self.img_backbone.use_abs_pos_embed:
x = x + self.img_backbone.absolute_pos_embed
x = self.img_backbone.drop_after_pos(x)
for i, stage in enumerate(self.img_backbone.stages):
x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
out = out.view(-1, *out_hw_shape,
self.img_backbone.num_features[i])
out = out.permute(0, 3, 1, 2).contiguous()
return out
def prepare_bev_feat(self, img, sensor2keyego, ego2global, intrin,
post_rot, post_tran, bda, mlp_input, feat_prev_iv,
k2s_sensor, extra_ref_frame):
"""
Args:
img: (B, N_views, 3, H, W)
sensor2keyego: (B, N_views, 4, 4)
ego2global: (B, N_views, 4, 4)
intrin: (B, N_views, 3, 3)
post_rot: (B, N_views, 3, 3)
post_tran: (B, N_views, 3)
bda: (B, 3, 3)
mlp_input: (B, N_views, 27)
feat_prev_iv: (B*N_views, C_stereo, fH_stereo, fW_stereo) or None
k2s_sensor: (B, N_views, 4, 4) or None
            extra_ref_frame (bool): Whether this is the extra stereo reference frame.
Returns:
bev_feat: (B, C, Dy, Dx)
depth: (B*N, D, fH, fW)
stereo_feat: (B*N_views, C_stereo, fH_stereo, fW_stereo)
"""
if extra_ref_frame:
stereo_feat = self.extract_stereo_ref_feat(img) # (B*N_views, C_stereo, fH_stereo, fW_stereo)
return None, None, stereo_feat
# x: (B, N_views, C, fH, fW)
# stereo_feat: (B*N, C_stereo, fH_stereo, fW_stereo)
x, stereo_feat = self.image_encoder(img, stereo=True)
        # Information needed to build the stereo cost volume.
metas = dict(k2s_sensor=k2s_sensor, # (B, N_views, 4, 4)
intrins=intrin, # (B, N_views, 3, 3)
post_rots=post_rot, # (B, N_views, 3, 3)
post_trans=post_tran, # (B, N_views, 3)
frustum=self.img_view_transformer.cv_frustum.to(x), # (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
cv_downsample=4,
downsample=self.img_view_transformer.downsample,
grid_config=self.img_view_transformer.grid_config,
cv_feat_list=[feat_prev_iv, stereo_feat]
)
# bev_feat: (B, C * Dz(=1), Dy, Dx)
# depth: (B * N, D, fH, fW)
bev_feat, depth = self.img_view_transformer(
[x, sensor2keyego, ego2global, intrin, post_rot, post_tran, bda,
mlp_input], metas)
if self.pre_process:
bev_feat = self.pre_process_net(bev_feat)[0] # (B, C, Dy, Dx)
return bev_feat, depth, stereo_feat
def extract_img_feat_sequential(self, inputs, feat_prev):
"""
Args:
inputs:
curr_img: (1, N_views, 3, H, W)
sensor2keyegos_curr: (N_prev, N_views, 4, 4)
ego2globals_curr: (N_prev, N_views, 4, 4)
intrins: (1, N_views, 3, 3)
sensor2keyegos_prev: (N_prev, N_views, 4, 4)
ego2globals_prev: (N_prev, N_views, 4, 4)
post_rots: (1, N_views, 3, 3)
post_trans: (1, N_views, 3, )
bda_curr: (N_prev, 3, 3)
feat_prev_iv:
curr2adjsensor: (1, N_views, 4, 4)
feat_prev: (N_prev, C, Dy, Dx)
Returns:
"""
imgs, sensor2keyegos_curr, ego2globals_curr, intrins = inputs[:4]
sensor2keyegos_prev, _, post_rots, post_trans, bda = inputs[4:9]
feat_prev_iv, curr2adjsensor = inputs[9:]
bev_feat_list = []
mlp_input = self.img_view_transformer.get_mlp_input(
sensor2keyegos_curr[0:1, ...], ego2globals_curr[0:1, ...],
intrins, post_rots, post_trans, bda[0:1, ...])
inputs_curr = (imgs, sensor2keyegos_curr[0:1, ...],
ego2globals_curr[0:1, ...], intrins, post_rots,
post_trans, bda[0:1, ...], mlp_input, feat_prev_iv,
curr2adjsensor, False)
# (1, C, Dx, Dy), (1*N, D, fH, fW)
bev_feat, depth, _ = self.prepare_bev_feat(*inputs_curr)
bev_feat_list.append(bev_feat)
# align the feat_prev
_, C, H, W = feat_prev.shape
# feat_prev: (N_prev, C, Dy, Dx)
feat_prev = \
self.shift_feature(feat_prev, # (N_prev, C, Dy, Dx)
[sensor2keyegos_curr, # (N_prev, N_views, 4, 4)
sensor2keyegos_prev], # (N_prev, N_views, 4, 4)
bda # (N_prev, 3, 3)
)
bev_feat_list.append(feat_prev.view(1, (self.num_frame - 2) * C, H, W)) # (1, N_prev*C, Dy, Dx)
bev_feat = torch.cat(bev_feat_list, dim=1) # (1, N_frames*C, Dy, Dx)
x = self.bev_encoder(bev_feat)
return [x], depth
def extract_img_feat(self,
img_inputs,
img_metas,
pred_prev=False,
sequential=False,
**kwargs):
"""
Args:
img_inputs:
imgs: (B, N, 3, H, W) # N = 6 * (N_history + 1)
sensor2egos: (B, N, 4, 4)
ego2globals: (B, N, 4, 4)
intrins: (B, N, 3, 3)
post_rots: (B, N, 3, 3)
post_trans: (B, N, 3)
bda_rot: (B, 3, 3)
img_metas:
**kwargs:
Returns:
x: [(B, C', H', W'), ]
depth: (B*N_views, D, fH, fW)
"""
if sequential:
return self.extract_img_feat_sequential(img_inputs, kwargs['feat_prev'])
imgs, sensor2keyegos, ego2globals, intrins, post_rots, post_trans, \
bda, curr2adjsensor = self.prepare_inputs(img_inputs, stereo=True)
"""Extract features of images."""
bev_feat_list = []
depth_key_frame = None
feat_prev_iv = None
for fid in range(self.num_frame-1, -1, -1):
img, sensor2keyego, ego2global, intrin, post_rot, post_tran = \
imgs[fid], sensor2keyegos[fid], ego2globals[fid], intrins[fid], \
post_rots[fid], post_trans[fid]
key_frame = fid == 0
extra_ref_frame = fid == self.num_frame-self.extra_ref_frames
if key_frame or self.with_prev:
if self.align_after_view_transfromation:
sensor2keyego, ego2global = sensor2keyegos[0], ego2globals[0]
mlp_input = self.img_view_transformer.get_mlp_input(
sensor2keyegos[0], ego2globals[0], intrin,
post_rot, post_tran, bda) # (B, N_views, 27)
inputs_curr = (img, sensor2keyego, ego2global, intrin,
post_rot, post_tran, bda, mlp_input,
feat_prev_iv, curr2adjsensor[fid],
extra_ref_frame)
if key_frame:
bev_feat, depth, feat_curr_iv = \
self.prepare_bev_feat(*inputs_curr)
depth_key_frame = depth
else:
with torch.no_grad():
bev_feat, depth, feat_curr_iv = \
self.prepare_bev_feat(*inputs_curr)
if not extra_ref_frame:
bev_feat_list.append(bev_feat)
if not key_frame:
feat_prev_iv = feat_curr_iv
if pred_prev:
assert self.align_after_view_transfromation
assert sensor2keyegos[0].shape[0] == 1 # batch_size = 1
feat_prev = torch.cat(bev_feat_list[1:], dim=0)
# (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
ego2globals_curr = \
ego2globals[0].repeat(self.num_frame - 2, 1, 1, 1)
# (1, N_views, 4, 4) --> (N_prev, N_views, 4, 4)
sensor2keyegos_curr = \
sensor2keyegos[0].repeat(self.num_frame - 2, 1, 1, 1)
ego2globals_prev = torch.cat(ego2globals[1:-1], dim=0) # (N_prev, N_views, 4, 4)
sensor2keyegos_prev = torch.cat(sensor2keyegos[1:-1], dim=0) # (N_prev, N_views, 4, 4)
bda_curr = bda.repeat(self.num_frame - 2, 1, 1) # (N_prev, 3, 3)
return feat_prev, [imgs[0], # (1, N_views, 3, H, W)
sensor2keyegos_curr, # (N_prev, N_views, 4, 4)
ego2globals_curr, # (N_prev, N_views, 4, 4)
intrins[0], # (1, N_views, 3, 3)
sensor2keyegos_prev, # (N_prev, N_views, 4, 4)
ego2globals_prev, # (N_prev, N_views, 4, 4)
post_rots[0], # (1, N_views, 3, 3)
post_trans[0], # (1, N_views, 3, )
bda_curr, # (N_prev, 3, 3)
feat_prev_iv,
curr2adjsensor[0]]
if not self.with_prev:
bev_feat_key = bev_feat_list[0]
if len(bev_feat_key.shape) == 4:
b, c, h, w = bev_feat_key.shape
bev_feat_list = \
[torch.zeros([b,
c * (self.num_frame -
self.extra_ref_frames - 1),
h, w]).to(bev_feat_key), bev_feat_key]
else:
b, c, z, h, w = bev_feat_key.shape
bev_feat_list = \
[torch.zeros([b,
c * (self.num_frame -
self.extra_ref_frames - 1), z,
h, w]).to(bev_feat_key), bev_feat_key]
if self.align_after_view_transfromation:
for adj_id in range(self.num_frame-2):
bev_feat_list[adj_id] = self.shift_feature(
bev_feat_list[adj_id], # (B, C, Dy, Dx)
[sensor2keyegos[0], # (B, N_views, 4, 4)
sensor2keyegos[self.num_frame-2-adj_id]], # (B, N_views, 4, 4)
bda # (B, 3, 3)
) # (B, C, Dy, Dx)
bev_feat = torch.cat(bev_feat_list, dim=1)
x = self.bev_encoder(bev_feat)
return [x], depth_key_frame
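    # Note on the loop in extract_img_feat() above (a summary, not new logic):
    # frames are processed from the oldest index (fid = num_frame - 1, the extra
    # stereo reference frame, which only contributes a stereo feature) down to
    # the key frame (fid = 0). The stereo feature produced for frame `fid` is
    # carried to the next iteration through `feat_prev_iv`, where it becomes the
    # previous-frame feature for that frame's cost volume.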
from .cross_entropy_loss import CrossEntropyLoss
from .focal_loss import CustomFocalLoss
__all__ = ['CrossEntropyLoss', 'CustomFocalLoss']
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss
def cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
        avg_non_ignore (bool): The flag decides whether the loss is
            only averaged over non-ignored targets. Default: False.
Returns:
torch.Tensor: The calculated loss
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index
# element-wise losses
loss = F.cross_entropy(
pred,
label,
weight=class_weight,
reduction='none',
ignore_index=ignore_index)
# average loss over non-ignored elements
# pytorch's official cross_entropy average loss over non-ignored elements
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = label.numel() - (label == ignore_index).sum().item()
# apply weights and do the reduction
if weight is not None:
weight = weight.float()
loss = weight_reduce_loss(
loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
return loss
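# Illustrative usage sketch for cross_entropy() above (shapes are assumptions):
#   pred = torch.randn(4, 3)                # (N, C) logits
#   label = torch.tensor([0, 2, 1, -100])   # entries equal to ignore_index are skipped
#   loss = cross_entropy(pred, label, avg_non_ignore=True)  # mean over the 3 valid labels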
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(
valid_mask & (labels < label_channels), as_tuple=False)
if inds.numel() > 0:
bin_labels[inds, labels[inds]] = 1
valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
label_channels).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
bin_label_weights *= valid_mask
return bin_labels, bin_label_weights, valid_mask
def binary_cross_entropy(pred,
label,
weight=None,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=-100,
avg_non_ignore=False):
"""Calculate the binary CrossEntropy loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 1) or (N, ).
When the shape of pred is (N, 1), label will be expanded to
one-hot format, and when the shape of pred is (N, ), label
will not be expanded to one-hot format.
label (torch.Tensor): The learning label of the prediction,
with shape (N, ).
weight (torch.Tensor, optional): Sample-wise loss weight.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (int | None): The label index to be ignored.
If None, it will be set to default value. Default: -100.
        avg_non_ignore (bool): The flag decides whether the loss is
            only averaged over non-ignored targets. Default: False.
Returns:
torch.Tensor: The calculated loss.
"""
# The default value of ignore_index is the same as F.cross_entropy
ignore_index = -100 if ignore_index is None else ignore_index
if pred.dim() != label.dim():
label, weight, valid_mask = _expand_onehot_labels(
label, weight, pred.size(-1), ignore_index)
else:
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
# The inplace writing method will have a mismatched broadcast
# shape error if the weight and valid_mask dimensions
# are inconsistent such as (B,N,1) and (B,N,C).
weight = weight * valid_mask
else:
weight = valid_mask
# average loss over non-ignored elements
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':
avg_factor = valid_mask.sum().item()
# weighted element-wise losses
weight = weight.float()
loss = F.binary_cross_entropy_with_logits(
pred, label.float(), pos_weight=class_weight, reduction='none')
# do the reduction for the weighted loss
loss = weight_reduce_loss(
loss, weight, reduction=reduction, avg_factor=avg_factor)
return loss
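# Illustrative usage sketch for binary_cross_entropy() above; class-index labels
# are expanded to one-hot internally (shapes are assumptions):
#   pred = torch.randn(4, 3)              # (N, C) logits
#   label = torch.tensor([0, 2, 1, 1])    # (N,) class indices
#   loss = binary_cross_entropy(pred, label, avg_non_ignore=True)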
def mask_cross_entropy(pred,
target,
label,
reduction='mean',
avg_factor=None,
class_weight=None,
ignore_index=None,
**kwargs):
"""Calculate the CrossEntropy loss for masks.
Args:
pred (torch.Tensor): The prediction with shape (N, C, *), C is the
number of classes. The trailing * indicates arbitrary shape.
target (torch.Tensor): The learning label of the prediction.
        label (torch.Tensor): ``label`` indicates the class label of the mask's
            corresponding object. This will be used to select the mask of the
            class which the object belongs to when the mask prediction is not
            class-agnostic.
reduction (str, optional): The method used to reduce the loss.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
class_weight (list[float], optional): The weight for each class.
ignore_index (None): Placeholder, to be consistent with other loss.
Default: None.
Returns:
torch.Tensor: The calculated loss
Example:
>>> N, C = 3, 11
>>> H, W = 2, 2
>>> pred = torch.randn(N, C, H, W) * 1000
>>> target = torch.rand(N, H, W)
>>> label = torch.randint(0, C, size=(N,))
>>> reduction = 'mean'
>>> avg_factor = None
>>> class_weights = None
>>> loss = mask_cross_entropy(pred, target, label, reduction,
>>> avg_factor, class_weights)
>>> assert loss.shape == (1,)
"""
assert ignore_index is None, 'BCE loss does not support ignore_index'
# TODO: handle these two reserved arguments
assert reduction == 'mean' and avg_factor is None
num_rois = pred.size()[0]
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
pred_slice = pred[inds, label].squeeze(1)
return F.binary_cross_entropy_with_logits(
pred_slice, target, weight=class_weight, reduction='mean')[None]
@LOSSES.register_module(force=True)
class CrossEntropyLoss(nn.Module):
def __init__(self,
use_sigmoid=False,
use_mask=False,
reduction='mean',
class_weight=None,
ignore_index=None,
loss_weight=1.0,
avg_non_ignore=False):
"""CrossEntropyLoss.
Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                or softmax. Defaults to False.
use_mask (bool, optional): Whether to use mask cross entropy loss.
Defaults to False.
            reduction (str, optional): The method used to reduce the loss. Defaults to 'mean'.
Options are "none", "mean" and "sum".
class_weight (list[float], optional): Weight of each class.
Defaults to None.
ignore_index (int | None): The label index to be ignored.
Defaults to None.
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
            avg_non_ignore (bool): The flag decides whether the loss is
                only averaged over non-ignored targets. Default: False.
"""
super(CrossEntropyLoss, self).__init__()
assert (use_sigmoid is False) or (use_mask is False)
self.use_sigmoid = use_sigmoid
self.use_mask = use_mask
self.reduction = reduction
self.loss_weight = loss_weight
self.class_weight = class_weight
self.ignore_index = ignore_index
self.avg_non_ignore = avg_non_ignore
if ((ignore_index is not None) and not self.avg_non_ignore
and self.reduction == 'mean'):
            warnings.warn(
                'Default ``avg_non_ignore`` is False. If you would like to '
                'ignore certain labels and average the loss only over '
                'non-ignored labels, which matches the official PyTorch '
                'cross_entropy, set ``avg_non_ignore=True``.')
if self.use_sigmoid:
self.cls_criterion = binary_cross_entropy
elif self.use_mask:
self.cls_criterion = mask_cross_entropy
else:
self.cls_criterion = cross_entropy
def extra_repr(self):
"""Extra repr."""
s = f'avg_non_ignore={self.avg_non_ignore}'
return s
def forward(self,
cls_score,
label,
weight=None,
avg_factor=None,
reduction_override=None,
ignore_index=None,
**kwargs):
"""Forward function.
Args:
cls_score (torch.Tensor): The prediction.
label (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The method used to reduce the
loss. Options are "none", "mean" and "sum".
ignore_index (int | None): The label index to be ignored.
If not None, it will override the default value. Default: None.
Returns:
torch.Tensor: The calculated loss.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if ignore_index is None:
ignore_index = self.ignore_index
if self.class_weight is not None:
class_weight = cls_score.new_tensor(
self.class_weight, device=cls_score.device)
else:
class_weight = None
loss_cls = self.loss_weight * self.cls_criterion(
cls_score,
label,
weight,
class_weight=class_weight,
reduction=reduction,
avg_factor=avg_factor,
ignore_index=ignore_index,
avg_non_ignore=self.avg_non_ignore,
**kwargs)
return loss_cls
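# Hedged example of referencing the registered loss from an mmdet-style config
# dict (keyword names follow the constructor above; the values are placeholders,
# not project defaults):
#   loss_occ = dict(
#       type='CrossEntropyLoss',
#       use_sigmoid=False,
#       ignore_index=255,
#       avg_non_ignore=True,
#       loss_weight=1.0)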
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weight_reduce_loss
import numpy as np
# This method is only for debugging
def py_sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the
number of classes
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = loss * weight
loss = loss.sum(-1).mean()
# loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
def py_focal_loss_with_prob(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
"""PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
Different from `py_sigmoid_focal_loss`, this function accepts probability
as input.
Args:
pred (torch.Tensor): The prediction probability with shape (N, C),
C is the number of classes.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
target = target.type_as(pred)
pt = (1 - pred) * target + pred * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy(
pred, target, reduction='none') * focal_weight
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
def sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
r"""A wrapper of cuda version `Focal Loss
<https://arxiv.org/abs/1708.02002>`_.
Args:
pred (torch.Tensor): The prediction with shape (N, C), C is the number
of classes.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): Sample-wise loss weight.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
alpha, None, 'none')
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
# For most cases, weight is of shape (num_priors, ),
# which means it does not have the second axis num_class
weight = weight.view(-1, 1)
else:
# Sometimes, weight per anchor per class is also needed. e.g.
# in FSAF. But it may be flattened of shape
# (num_priors x num_class, ), while loss is still of shape
# (num_priors, num_class).
assert weight.numel() == loss.numel()
weight = weight.view(loss.size(0), -1)
assert weight.ndim == loss.ndim
loss = loss * weight
loss = loss.sum(-1).mean()
# loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
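# Hedged usage sketch for the CUDA focal-loss wrapper above (the mmcv op needs
# GPU tensors; shapes are assumptions):
#   pred = torch.randn(8, 10, device='cuda')             # (N, C) logits
#   target = torch.randint(0, 10, (8,), device='cuda')   # (N,) class indices
#   weight = torch.ones(8, device='cuda')                 # per-sample weight
#   loss = sigmoid_focal_loss(pred, target, weight=weight)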
@LOSSES.register_module()
class CustomFocalLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
reduction='mean',
loss_weight=100.0,
activated=False):
"""`Focal Loss <https://arxiv.org/abs/1708.02002>`_
Args:
            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
                or softmax. Defaults to True.
gamma (float, optional): The gamma for calculating the modulating
factor. Defaults to 2.0.
alpha (float, optional): A balanced form for Focal Loss.
Defaults to 0.25.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'. Options are "none", "mean" and
"sum".
            loss_weight (float, optional): Weight of loss. Defaults to 100.0.
activated (bool, optional): Whether the input is activated.
If True, it means the input has been activated and can be
treated as probabilities. Else, it should be treated as logits.
Defaults to False.
"""
super(CustomFocalLoss, self).__init__()
assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
self.use_sigmoid = use_sigmoid
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight
self.activated = activated
H, W = 200, 200
xy, yx = torch.meshgrid([torch.arange(H) - H / 2, torch.arange(W) - W / 2])
c = torch.stack([xy, yx], 2)
c = torch.norm(c, 2, -1)
c_max = c.max()
self.c = (c / c_max + 1).cuda()
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
ignore_index=255,
reduction_override=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning label of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
B, H, W, D = target.shape
c = self.c[None, :, :, None].repeat(B, 1, 1, D).reshape(-1)
visible_mask = (target != ignore_index).reshape(-1).nonzero().squeeze(-1)
weight_mask = weight[None, :] * c[visible_mask, None]
# visible_mask[:, None]
num_classes = pred.size(1)
pred = pred.permute(0, 2, 3, 4, 1).reshape(-1, num_classes)[visible_mask]
target = target.reshape(-1)[visible_mask]
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.use_sigmoid:
if self.activated:
calculate_loss_func = py_focal_loss_with_prob
else:
if torch.cuda.is_available() and pred.is_cuda:
calculate_loss_func = sigmoid_focal_loss
else:
num_classes = pred.size(1)
target = F.one_hot(target, num_classes=num_classes + 1)
target = target[:, :num_classes]
calculate_loss_func = py_sigmoid_focal_loss
loss_cls = self.loss_weight * calculate_loss_func(
pred,
target.to(torch.long),
weight_mask,
gamma=self.gamma,
alpha=self.alpha,
reduction=reduction,
avg_factor=avg_factor)
else:
raise NotImplementedError
return loss_cls
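# Hedged usage sketch for CustomFocalLoss above. The constructor moves its
# distance prior to the GPU, so CUDA is required; the 200x200x16 grid and 18
# classes below are assumptions, not values fixed by this file:
#   focal = CustomFocalLoss(loss_weight=100.0)
#   pred = torch.randn(1, 18, 200, 200, 16, device='cuda')           # (B, C, Dx, Dy, Dz) logits
#   target = torch.randint(0, 18, (1, 200, 200, 16), device='cuda')  # (B, Dx, Dy, Dz), 255 = ignore
#   class_weight = torch.ones(18, device='cuda')                      # per-class weight
#   loss = focal(pred, target, weight=class_weight)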
# -*- coding:utf-8 -*-
# author: Xinge
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""
from __future__ import print_function, division
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
from itertools import ifilterfalse
except ImportError: # py3k
from itertools import filterfalse as ifilterfalse
from torch.cuda.amp import autocast
def lovasz_grad(gt_sorted):
"""
Computes gradient of the Lovasz extension w.r.t sorted errors
See Alg. 1 in paper
"""
p = len(gt_sorted)
gts = gt_sorted.sum()
intersection = gts - gt_sorted.float().cumsum(0)
union = gts + (1 - gt_sorted).float().cumsum(0)
jaccard = 1. - intersection / union
if p > 1: # cover 1-pixel case
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
"""
IoU for foreground class
binary: 1 foreground, 0 background
"""
if not per_image:
preds, labels = (preds,), (labels,)
ious = []
for pred, label in zip(preds, labels):
intersection = ((label == 1) & (pred == 1)).sum()
union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
if not union:
iou = EMPTY
else:
iou = float(intersection) / float(union)
ious.append(iou)
    iou = mean(ious)    # mean across images if per_image
return 100 * iou
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
"""
Array of IoU for each (non ignored) class
"""
if not per_image:
preds, labels = (preds,), (labels,)
ious = []
for pred, label in zip(preds, labels):
iou = []
for i in range(C):
if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
intersection = ((label == i) & (pred == i)).sum()
union = ((label == i) | ((pred == i) & (label != ignore))).sum()
if not union:
iou.append(EMPTY)
else:
iou.append(float(intersection) / float(union))
ious.append(iou)
    ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
return 100 * np.array(ious)
# --------------------------- BINARY LOSSES ---------------------------
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
"""
Binary Lovasz hinge loss
logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
per_image: compute the loss per image instead of per batch
ignore: void class id
"""
if per_image:
loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
for log, lab in zip(logits, labels))
else:
loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
return loss
def lovasz_hinge_flat(logits, labels):
"""
Binary Lovasz hinge loss
logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
labels: [P] Tensor, binary ground truth labels (0 or 1)
ignore: label to ignore
"""
if len(labels) == 0:
# only void pixels, the gradients should be 0
return logits.sum() * 0.
signs = 2. * labels.float() - 1.
errors = (1. - logits * Variable(signs))
errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
perm = perm.data
gt_sorted = labels[perm]
grad = lovasz_grad(gt_sorted)
loss = torch.dot(F.relu(errors_sorted), Variable(grad))
return loss
def flatten_binary_scores(scores, labels, ignore=None):
"""
Flattens predictions in the batch (binary case)
Remove labels equal to 'ignore'
"""
scores = scores.view(-1)
labels = labels.view(-1)
if ignore is None:
return scores, labels
valid = (labels != ignore)
vscores = scores[valid]
vlabels = labels[valid]
return vscores, vlabels
class StableBCELoss(torch.nn.modules.Module):
def __init__(self):
super(StableBCELoss, self).__init__()
def forward(self, input, target):
neg_abs = - input.abs()
loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
return loss.mean()
def binary_xloss(logits, labels, ignore=None):
"""
Binary Cross entropy loss
logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
ignore: void class id
"""
logits, labels = flatten_binary_scores(logits, labels, ignore)
loss = StableBCELoss()(logits, Variable(labels.float()))
return loss
# --------------------------- MULTICLASS LOSSES ---------------------------
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
"""
Multi-class Lovasz-Softmax loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
per_image: compute the loss per image instead of per batch
ignore: void class labels
"""
if per_image:
loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
for prob, lab in zip(probas, labels))
else:
with autocast(False):
loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
return loss
def lovasz_softmax_flat(probas, labels, classes='present'):
"""
Multi-class Lovasz-Softmax loss
probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
labels: [P] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
"""
if probas.numel() == 0:
# only void pixels, the gradients should be 0
return probas * 0.
C = probas.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
for c in class_to_sum:
fg = (labels == c).float() # foreground for class c
        if classes == 'present' and fg.sum() == 0:
continue
if C == 1:
if len(classes) > 1:
raise ValueError('Sigmoid output possible only with 1 class')
class_pred = probas[:, 0]
else:
class_pred = probas[:, c]
errors = (Variable(fg) - class_pred).abs()
errors_sorted, perm = torch.sort(errors, 0, descending=True)
perm = perm.data
fg_sorted = fg[perm]
losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
return mean(losses)
def flatten_probas(probas, labels, ignore=None):
"""
Flattens predictions in the batch
"""
if probas.dim() == 2:
if ignore is not None:
valid = (labels != ignore)
probas = probas[valid]
labels = labels[valid]
return probas, labels
elif probas.dim() == 3:
# assumes output of a sigmoid layer
B, H, W = probas.size()
probas = probas.view(B, 1, H, W)
elif probas.dim() == 5:
        # 3D segmentation
B, C, L, H, W = probas.size()
probas = probas.contiguous().view(B, C, L, H*W)
B, C, H, W = probas.size()
probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
labels = labels.view(-1)
if ignore is None:
return probas, labels
valid = (labels != ignore)
vprobas = probas[valid.nonzero().squeeze()]
vlabels = labels[valid]
return vprobas, vlabels
def xloss(logits, labels, ignore=None):
"""
Cross entropy loss
"""
return F.cross_entropy(logits, Variable(labels), ignore_index=255)
def jaccard_loss(probas, labels, ignore=None, smooth=100, bk_class=None):
"""
    Multi-class smoothed Jaccard loss
    probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
    labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
    ignore: void class labels
    smooth: smoothing constant added to numerator and denominator
    bk_class: optional background class excluded from the one-hot target
"""
vprobas, vlabels = flatten_probas(probas, labels, ignore)
true_1_hot = torch.eye(vprobas.shape[1])[vlabels]
if bk_class:
one_hot_assignment = torch.ones_like(vlabels)
one_hot_assignment[vlabels == bk_class] = 0
one_hot_assignment = one_hot_assignment.float().unsqueeze(1)
true_1_hot = true_1_hot*one_hot_assignment
true_1_hot = true_1_hot.to(vprobas.device)
intersection = torch.sum(vprobas * true_1_hot)
cardinality = torch.sum(vprobas + true_1_hot)
    loss = ((intersection + smooth) / (cardinality - intersection + smooth)).mean()
    return (1 - loss) * smooth
def hinge_jaccard_loss(probas, labels, ignore=None, classes='present', hinge=0.1, smooth=100):
"""
Multi-class Hinge Jaccard loss
probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
ignore: void class labels
"""
vprobas, vlabels = flatten_probas(probas, labels, ignore)
C = vprobas.size(1)
losses = []
class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
for c in class_to_sum:
if c in vlabels:
c_sample_ind = vlabels == c
            cprobas = vprobas[c_sample_ind, :]
            non_c_ind = np.array([a for a in class_to_sum if a != c])
            class_pred = cprobas[:, c]
            max_non_class_pred = torch.max(cprobas[:, non_c_ind], dim=1)[0]
            TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max=hinge) + 1.) + smooth
            FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min=-hinge) + hinge)
            if (~c_sample_ind).sum() == 0:
                FP = 0
            else:
                nonc_probas = vprobas[~c_sample_ind, :]
                class_pred = nonc_probas[:, c]
                max_non_class_pred = torch.max(nonc_probas[:, non_c_ind], dim=1)[0]
                FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max=hinge) + 1.)
            losses.append(1 - TP / (TP + FP + FN))
if len(losses) == 0: return 0
return mean(losses)
# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
return x != x
def mean(l, ignore_nan=False, empty=0):
"""
nanmean compatible with generators.
"""
l = iter(l)
if ignore_nan:
l = ifilterfalse(isnan, l)
try:
n = 1
acc = next(l)
except StopIteration:
if empty == 'raise':
raise ValueError('Empty mean')
return empty
for n, v in enumerate(l, 2):
acc += v
if n == 1:
return acc
return acc / n
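# Minimal smoke test for the multi-class Lovasz-Softmax loss above; the random
# shapes are assumptions and carry no dataset meaning.
if __name__ == '__main__':
    probas = torch.softmax(torch.randn(2, 5, 8, 8), dim=1)   # (B, C, H, W) class probabilities
    labels = torch.randint(0, 5, (2, 8, 8))                  # (B, H, W) labels in [0, C-1]
    print('lovasz_softmax:', float(lovasz_softmax(probas, labels)))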
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# from mmcv.runner import BaseModule, force_fp32
from torch.cuda.amp import autocast
semantic_kitti_class_frequencies = np.array(
[
5.41773033e09,
1.57835390e07,
1.25136000e05,
1.18809000e05,
6.46799000e05,
8.21951000e05,
2.62978000e05,
2.83696000e05,
2.04750000e05,
6.16887030e07,
4.50296100e06,
4.48836500e07,
2.26992300e06,
5.68402180e07,
1.57196520e07,
1.58442623e08,
2.06162300e06,
3.69705220e07,
1.15198800e06,
3.34146000e05,
]
)
kitti_class_names = [
"empty",
"car",
"bicycle",
"motorcycle",
"truck",
"other-vehicle",
"person",
"bicyclist",
"motorcyclist",
"road",
"parking",
"sidewalk",
"other-ground",
"building",
"fence",
"vegetation",
"trunk",
"terrain",
"pole",
"traffic-sign",
]
def inverse_sigmoid(x, sign='A'):
x = x.to(torch.float32)
while x >= 1-1e-5:
x = x - 1e-5
    while x < 1e-5:
x = x + 1e-5
return -torch.log((1 / x) - 1)
def KL_sep(p, target):
"""
KL divergence on nonzeros classes
"""
nonzeros = target != 0
nonzero_p = p[nonzeros]
kl_term = F.kl_div(torch.log(nonzero_p), target[nonzeros], reduction="sum")
return kl_term
def geo_scal_loss(pred, ssc_target, ignore_index=255, non_empty_idx=0):
# Get softmax probabilities
pred = F.softmax(pred, dim=1)
# Compute empty and nonempty probabilities
empty_probs = pred[:, non_empty_idx]
nonempty_probs = 1 - empty_probs
# Remove unknown voxels
mask = ssc_target != ignore_index
nonempty_target = ssc_target != non_empty_idx
nonempty_target = nonempty_target[mask].float()
nonempty_probs = nonempty_probs[mask]
empty_probs = empty_probs[mask]
eps = 1e-5
intersection = (nonempty_target * nonempty_probs).sum()
precision = intersection / (nonempty_probs.sum()+eps)
recall = intersection / (nonempty_target.sum()+eps)
spec = ((1 - nonempty_target) * (empty_probs)).sum() / ((1 - nonempty_target).sum()+eps)
with autocast(False):
return (
F.binary_cross_entropy_with_logits(inverse_sigmoid(precision, 'A'), torch.ones_like(precision))
+ F.binary_cross_entropy_with_logits(inverse_sigmoid(recall, 'B'), torch.ones_like(recall))
+ F.binary_cross_entropy_with_logits(inverse_sigmoid(spec, 'C'), torch.ones_like(spec))
)
def sem_scal_loss(pred_, ssc_target, ignore_index=255):
# Get softmax probabilities
with autocast(False):
pred = F.softmax(pred_, dim=1) # (B, n_class, Dx, Dy, Dz)
loss = 0
count = 0
mask = ssc_target != ignore_index
n_classes = pred.shape[1]
begin = 0
for i in range(begin, n_classes-1):
# Get probability of class i
p = pred[:, i] # (B, Dx, Dy, Dz)
# Remove unknown voxels
target_ori = ssc_target # (B, Dx, Dy, Dz)
p = p[mask]
target = ssc_target[mask]
completion_target = torch.ones_like(target)
completion_target[target != i] = 0
completion_target_ori = torch.ones_like(target_ori).float()
completion_target_ori[target_ori != i] = 0
if torch.sum(completion_target) > 0:
count += 1.0
nominator = torch.sum(p * completion_target)
loss_class = 0
if torch.sum(p) > 0:
precision = nominator / (torch.sum(p)+ 1e-5)
loss_precision = F.binary_cross_entropy_with_logits(
inverse_sigmoid(precision, 'D'), torch.ones_like(precision)
)
loss_class += loss_precision
if torch.sum(completion_target) > 0:
recall = nominator / (torch.sum(completion_target) +1e-5)
# loss_recall = F.binary_cross_entropy(recall, torch.ones_like(recall))
loss_recall = F.binary_cross_entropy_with_logits(inverse_sigmoid(recall, 'E'), torch.ones_like(recall))
loss_class += loss_recall
if torch.sum(1 - completion_target) > 0:
specificity = torch.sum((1 - p) * (1 - completion_target)) / (
torch.sum(1 - completion_target) + 1e-5
)
loss_specificity = F.binary_cross_entropy_with_logits(
inverse_sigmoid(specificity, 'F'), torch.ones_like(specificity)
)
loss_class += loss_specificity
loss += loss_class
# print(i, loss_class, loss_recall, loss_specificity)
l = loss/count
if torch.isnan(l):
from IPython import embed
embed()
exit()
return l
def CE_ssc_loss(pred, target, class_weights=None, ignore_index=255):
"""
:param: prediction: the predicted tensor, must be [BS, C, ...]
"""
criterion = nn.CrossEntropyLoss(
weight=class_weights, ignore_index=ignore_index, reduction="mean"
)
# from IPython import embed
# embed()
# exit()
with autocast(False):
loss = criterion(pred, target.long())
return loss
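# Hedged sketch of combining the semantic-scene-completion losses above on a
# voxel prediction (the 18-class 200x200x16 grid is an assumption):
#   pred = torch.randn(1, 18, 200, 200, 16)            # (B, n_class, Dx, Dy, Dz) logits
#   target = torch.randint(0, 18, (1, 200, 200, 16))   # (B, Dx, Dy, Dz) labels, 255 = ignore
#   loss = CE_ssc_loss(pred, target) \
#        + geo_scal_loss(pred, target) \
#        + sem_scal_loss(pred, target)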
def vel_loss(pred, gt):
with autocast(False):
return F.l1_loss(pred, gt)
from .depthnet import DepthNet
__all__ = ['DepthNet']
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.backbones.resnet import BasicBlock
from mmcv.cnn import build_conv_layer
from torch.cuda.amp.autocast_mode import autocast
from torch.utils.checkpoint import checkpoint
class _ASPPModule(nn.Module):
def __init__(self, inplanes, planes, kernel_size, padding, dilation,
BatchNorm):
super(_ASPPModule, self).__init__()
self.atrous_conv = nn.Conv2d(
inplanes,
planes,
kernel_size=kernel_size,
stride=1,
padding=padding,
dilation=dilation,
bias=False)
self.bn = BatchNorm(planes)
self.relu = nn.ReLU()
self._init_weight()
def forward(self, x):
x = self.atrous_conv(x)
x = self.bn(x)
return self.relu(x)
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class ASPP(nn.Module):
def __init__(self, inplanes, mid_channels=256, BatchNorm=nn.BatchNorm2d):
super(ASPP, self).__init__()
dilations = [1, 6, 12, 18]
self.aspp1 = _ASPPModule(
inplanes,
mid_channels,
1,
padding=0,
dilation=dilations[0],
BatchNorm=BatchNorm)
self.aspp2 = _ASPPModule(
inplanes,
mid_channels,
3,
padding=dilations[1],
dilation=dilations[1],
BatchNorm=BatchNorm)
self.aspp3 = _ASPPModule(
inplanes,
mid_channels,
3,
padding=dilations[2],
dilation=dilations[2],
BatchNorm=BatchNorm)
self.aspp4 = _ASPPModule(
inplanes,
mid_channels,
3,
padding=dilations[3],
dilation=dilations[3],
BatchNorm=BatchNorm)
self.global_avg_pool = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Conv2d(inplanes, mid_channels, 1, stride=1, bias=False),
BatchNorm(mid_channels),
nn.ReLU(),
)
self.conv1 = nn.Conv2d(
int(mid_channels * 5), inplanes, 1, bias=False)
self.bn1 = BatchNorm(inplanes)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
self._init_weight()
def forward(self, x):
"""
Args:
x: (B*N, C, fH, fW)
Returns:
x: (B*N, C, fH, fW)
"""
x1 = self.aspp1(x)
x2 = self.aspp2(x)
x3 = self.aspp3(x)
x4 = self.aspp4(x)
x5 = self.global_avg_pool(x)
x5 = F.interpolate(
x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x1, x2, x3, x4, x5), dim=1) # (B*N, 5*C', fH, fW)
x = self.conv1(x) # (B*N, C, fH, fW)
x = self.bn1(x)
x = self.relu(x)
return self.dropout(x)
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
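# Hedged shape sketch for the ASPP block above: it preserves the spatial size
# and returns `inplanes` channels (the 256/96 widths are assumptions):
#   aspp = ASPP(inplanes=256, mid_channels=96)
#   y = aspp(torch.randn(2, 256, 16, 44))   # -> (2, 256, 16, 44)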
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.ReLU,
drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop)
def forward(self, x):
"""
Args:
x: (B*N_views, 27)
Returns:
x: (B*N_views, C)
"""
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class SELayer(nn.Module):
def __init__(self, channels, act_layer=nn.ReLU, gate_layer=nn.Sigmoid):
super().__init__()
self.conv_reduce = nn.Conv2d(channels, channels, 1, bias=True)
self.act1 = act_layer()
self.conv_expand = nn.Conv2d(channels, channels, 1, bias=True)
self.gate = gate_layer()
def forward(self, x, x_se):
"""
Args:
x: (B*N_views, C_mid, fH, fW)
x_se: (B*N_views, C_mid, 1, 1)
Returns:
x: (B*N_views, C_mid, fH, fW)
"""
x_se = self.conv_reduce(x_se) # (B*N_views, C_mid, 1, 1)
x_se = self.act1(x_se) # (B*N_views, C_mid, 1, 1)
x_se = self.conv_expand(x_se) # (B*N_views, C_mid, 1, 1)
return x * self.gate(x_se) # (B*N_views, C_mid, fH, fW)
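# Hedged shape sketch for the camera-aware SE recalibration above (sizes are
# assumptions): the (B*N, C, 1, 1) excitation gates the (B*N, C, fH, fW) map.
#   se = SELayer(channels=256)
#   out = se(torch.randn(6, 256, 16, 44), torch.randn(6, 256, 1, 1))  # -> (6, 256, 16, 44)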
class DepthNet(nn.Module):
def __init__(self,
in_channels,
mid_channels,
context_channels,
depth_channels,
use_dcn=True,
use_aspp=True,
with_cp=False,
stereo=False,
bias=0.0,
aspp_mid_channels=-1):
super(DepthNet, self).__init__()
self.reduce_conv = nn.Sequential(
nn.Conv2d(
in_channels, mid_channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
)
        # Generate the context features.
self.context_conv = nn.Conv2d(
mid_channels, context_channels, kernel_size=1, stride=1, padding=0)
self.bn = nn.BatchNorm1d(27)
self.depth_mlp = Mlp(in_features=27, hidden_features=mid_channels, out_features=mid_channels)
self.depth_se = SELayer(channels=mid_channels) # NOTE: add camera-aware
self.context_mlp = Mlp(in_features=27, hidden_features=mid_channels, out_features=mid_channels)
self.context_se = SELayer(channels=mid_channels) # NOTE: add camera-aware
depth_conv_input_channels = mid_channels
downsample = None
if stereo:
depth_conv_input_channels += depth_channels
downsample = nn.Conv2d(depth_conv_input_channels,
mid_channels, 1, 1, 0)
cost_volumn_net = []
            for stage in range(2):
cost_volumn_net.extend([
nn.Conv2d(depth_channels, depth_channels, kernel_size=3,
stride=2, padding=1),
nn.BatchNorm2d(depth_channels)])
self.cost_volumn_net = nn.Sequential(*cost_volumn_net)
self.bias = bias
        # 3 residual blocks
depth_conv_list = [BasicBlock(depth_conv_input_channels, mid_channels,
downsample=downsample),
BasicBlock(mid_channels, mid_channels),
BasicBlock(mid_channels, mid_channels)]
if use_aspp:
if aspp_mid_channels < 0:
aspp_mid_channels = mid_channels
depth_conv_list.append(ASPP(mid_channels, aspp_mid_channels))
if use_dcn:
depth_conv_list.append(
build_conv_layer(
cfg=dict(
type='DCN',
in_channels=mid_channels,
out_channels=mid_channels,
kernel_size=3,
padding=1,
groups=4,
im2col_step=128,
)))
depth_conv_list.append(
nn.Conv2d(
mid_channels,
depth_channels,
kernel_size=1,
stride=1,
padding=0))
self.depth_conv = nn.Sequential(*depth_conv_list)
self.with_cp = with_cp
self.depth_channels = depth_channels
    # ----------------------------------------- for building the cost volume ----------------------------------
def gen_grid(self, metas, B, N, D, H, W, hi, wi):
"""
Args:
metas: dict{
k2s_sensor: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
frustum: (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
cv_downsample: 4,
downsample: self.img_view_transformer.downsample=16,
grid_config: self.img_view_transformer.grid_config,
cv_feat_list: [feat_prev_iv, stereo_feat]
}
B: batchsize
N: N_views
D: D
H: fH_stereo
W: fW_stereo
hi: H_img
wi: W_img
Returns:
grid: (B*N_views, D*fH_stereo, fW_stereo, 2)
"""
frustum = metas['frustum'] # (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
        # Undo the image-space augmentation:
points = frustum - metas['post_trans'].view(B, N, 1, 1, 1, 3)
points = torch.inverse(metas['post_rots']).view(B, N, 1, 1, 1, 3, 3) \
.matmul(points.unsqueeze(-1)) # (B, N_views, D, fH_stereo, fW_stereo, 3, 1)
# (u, v, d) --> (du, dv, d)
# (B, N_views, D, fH_stereo, fW_stereo, 3, 1)
points = torch.cat(
(points[..., :2, :] * points[..., 2:3, :], points[..., 2:3, :]), 5)
# cur_pixel --> curr_camera --> prev_camera
rots = metas['k2s_sensor'][:, :, :3, :3].contiguous()
trans = metas['k2s_sensor'][:, :, :3, 3].contiguous()
combine = rots.matmul(torch.inverse(metas['intrins']))
points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points)
points += trans.view(B, N, 1, 1, 1, 3, 1) # (B, N_views, D, fH_stereo, fW_stereo, 3, 1)
neg_mask = points[..., 2, 0] < 1e-3
# prev_camera --> prev_pixel
points = metas['intrins'].view(B, N, 1, 1, 1, 3, 3).matmul(points)
# (du, dv, d) --> (u, v) (B, N_views, D, fH_stereo, fW_stereo, 2, 1)
points = points[..., :2, :] / points[..., 2:3, :]
        # Re-apply the image-space augmentation
points = metas['post_rots'][..., :2, :2].view(B, N, 1, 1, 1, 2, 2).matmul(
points).squeeze(-1)
points += metas['post_trans'][..., :2].view(B, N, 1, 1, 1, 2) # (B, N_views, D, fH_stereo, fW_stereo, 2)
px = points[..., 0] / (wi - 1.0) * 2.0 - 1.0
py = points[..., 1] / (hi - 1.0) * 2.0 - 1.0
px[neg_mask] = -2
py[neg_mask] = -2
grid = torch.stack([px, py], dim=-1) # (B, N_views, D, fH_stereo, fW_stereo, 2)
grid = grid.view(B * N, D * H, W, 2) # (B*N_views, D*fH_stereo, fW_stereo, 2)
return grid
def calculate_cost_volumn(self, metas):
"""
Args:
metas: dict{
k2s_sensor: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
frustum: (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
cv_downsample: 4,
downsample: self.img_view_transformer.downsample=16,
grid_config: self.img_view_transformer.grid_config,
cv_feat_list: [feat_prev_iv, stereo_feat]
}
Returns:
cost_volumn: (B*N_views, D, fH_stereo, fW_stereo)
"""
prev, curr = metas['cv_feat_list'] # (B*N_views, C_stereo, fH_stereo, fW_stereo)
group_size = 4
        _, c, hf, wf = curr.shape
hi, wi = hf * 4, wf * 4 # H_img, W_img
B, N, _ = metas['post_trans'].shape
D, H, W, _ = metas['frustum'].shape
grid = self.gen_grid(metas, B, N, D, H, W, hi, wi).to(curr.dtype) # (B*N_views, D*fH_stereo, fW_stereo, 2)
prev = prev.view(B * N, -1, H, W) # (B*N_views, C_stereo, fH_stereo, fW_stereo)
curr = curr.view(B * N, -1, H, W) # (B*N_views, C_stereo, fH_stereo, fW_stereo)
cost_volumn = 0
# process in group wise to save memory
for fid in range(curr.shape[1] // group_size):
# (B*N_views, group_size, fH_stereo, fW_stereo)
prev_curr = prev[:, fid * group_size:(fid + 1) * group_size, ...]
wrap_prev = F.grid_sample(prev_curr, grid,
align_corners=True,
padding_mode='zeros') # (B*N_views, group_size, D*fH_stereo, fW_stereo)
# (B*N_views, group_size, fH_stereo, fW_stereo)
curr_tmp = curr[:, fid * group_size:(fid + 1) * group_size, ...]
# (B*N_views, group_size, 1, fH_stereo, fW_stereo) - (B*N_views, group_size, D, fH_stereo, fW_stereo)
# --> (B*N_views, group_size, D, fH_stereo, fW_stereo)
# https://github.com/HuangJunJie2017/BEVDet/issues/278
cost_volumn_tmp = curr_tmp.unsqueeze(2) - \
wrap_prev.view(B * N, -1, D, H, W)
cost_volumn_tmp = cost_volumn_tmp.abs().sum(dim=1) # (B*N_views, D, fH_stereo, fW_stereo)
cost_volumn += cost_volumn_tmp # (B*N_views, D, fH_stereo, fW_stereo)
if not self.bias == 0:
invalid = wrap_prev[:, 0, ...].view(B * N, D, H, W) == 0
cost_volumn[invalid] = cost_volumn[invalid] + self.bias
# matching cost --> prob
cost_volumn = - cost_volumn
cost_volumn = cost_volumn.softmax(dim=1)
return cost_volumn
    # ----------------------------------------- for building the cost volume --------------------------------------
def forward(self, x, mlp_input, stereo_metas=None):
"""
Args:
x: (B*N_views, C, fH, fW)
mlp_input: (B, N_views, 27)
stereo_metas: None or dict{
k2s_sensor: (B, N_views, 4, 4)
intrins: (B, N_views, 3, 3)
post_rots: (B, N_views, 3, 3)
post_trans: (B, N_views, 3)
frustum: (D, fH_stereo, fW_stereo, 3) 3:(u, v, d)
cv_downsample: 4,
downsample: self.img_view_transformer.downsample=16,
grid_config: self.img_view_transformer.grid_config,
cv_feat_list: [feat_prev_iv, stereo_feat]
}
Returns:
x: (B*N_views, D+C_context, fH, fW)
"""
mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1])) # (B*N_views, 27)
x = self.reduce_conv(x) # (B*N_views, C_mid, fH, fW)
# (B*N_views, 27) --> (B*N_views, C_mid) --> (B*N_views, C_mid, 1, 1)
context_se = self.context_mlp(mlp_input)[..., None, None]
context = self.context_se(x, context_se) # (B*N_views, C_mid, fH, fW)
context = self.context_conv(context) # (B*N_views, C_context, fH, fW)
# (B*N_views, 27) --> (B*N_views, C_mid) --> (B*N_views, C_mid, 1, 1)
depth_se = self.depth_mlp(mlp_input)[..., None, None]
depth = self.depth_se(x, depth_se) # (B*N_views, C_mid, fH, fW)
        if stereo_metas is not None:
if stereo_metas['cv_feat_list'][0] is None:
BN, _, H, W = x.shape
scale_factor = float(stereo_metas['downsample'])/\
stereo_metas['cv_downsample']
cost_volumn = \
torch.zeros((BN, self.depth_channels,
int(H*scale_factor),
int(W*scale_factor))).to(x)
else:
with torch.no_grad():
# https://github.com/HuangJunJie2017/BEVDet/issues/278
cost_volumn = self.calculate_cost_volumn(stereo_metas) # (B*N_views, D, fH_stereo, fW_stereo)
cost_volumn = self.cost_volumn_net(cost_volumn) # (B*N_views, D, fH, fW)
depth = torch.cat([depth, cost_volumn], dim=1) # (B*N_views, C_mid+D, fH, fW)
if self.with_cp:
depth = checkpoint(self.depth_conv, depth)
else:
                # 3 residual blocks + ASPP/DCN + Conv(C_mid --> D)
depth = self.depth_conv(depth) # x: (B*N_views, C_mid, fH, fW) --> (B*N_views, D, fH, fW)
return torch.cat([depth, context], dim=1)
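# Hedged smoke-test sketch for DepthNet above without stereo (stereo_metas=None).
# Channel widths are assumptions; use_dcn=False avoids the DCN CUDA op:
#   net = DepthNet(in_channels=512, mid_channels=256, context_channels=80,
#                  depth_channels=88, use_dcn=False, use_aspp=True)
#   feat = torch.randn(6, 512, 16, 44)   # (B*N_views, C, fH, fW)
#   mlp_input = torch.randn(1, 6, 27)    # (B, N_views, 27)
#   out = net(feat, mlp_input)           # (6, 88 + 80, 16, 44)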
class DepthAggregation(nn.Module):
"""pixel cloud feature extraction."""
def __init__(self, in_channels, mid_channels, out_channels):
super(DepthAggregation, self).__init__()
self.reduce_conv = nn.Sequential(
nn.Conv2d(
in_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
)
self.conv = nn.Sequential(
nn.Conv2d(
mid_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(
mid_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
)
self.out_conv = nn.Sequential(
nn.Conv2d(
mid_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True),
# nn.BatchNorm3d(out_channels),
# nn.ReLU(inplace=True),
)
@autocast(False)
def forward(self, x):
x = checkpoint(self.reduce_conv, x)
short_cut = x
x = checkpoint(self.conv, x)
x = short_cut + x
x = self.out_conv(x)
return x
from .fpn import CustomFPN
from .view_transformer import LSSViewTransformer, LSSViewTransformerBEVDepth, LSSViewTransformerBEVStereo
from .lss_fpn import FPN_LSS
__all__ = ['CustomFPN', 'FPN_LSS', 'LSSViewTransformer', 'LSSViewTransformerBEVDepth', 'LSSViewTransformerBEVStereo']