Commit d2b71343 authored by 雍大凯's avatar 雍大凯
Browse files

add code

parent 69e57885
from mmdet.models.backbones import ResNet
from .resnet import CustomResNet
from .swin import SwinTransformer
__all__ = ['ResNet', 'CustomResNet', 'SwinTransformer']
# Copyright (c) Phigent Robotics. All rights reserved.
import torch.utils.checkpoint as checkpoint
from torch import nn
from mmcv.cnn.bricks.conv_module import ConvModule
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet3d.models import BACKBONES
@BACKBONES.register_module()
class CustomResNet(nn.Module):
def __init__(
self,
numC_input,
num_layer=[2, 2, 2],
num_channels=None,
stride=[2, 2, 2],
backbone_output_ids=None,
norm_cfg=dict(type='BN'),
with_cp=False,
block_type='Basic',
):
super(CustomResNet, self).__init__()
# build backbone
assert len(num_layer) == len(stride)
num_channels = [numC_input*2**(i+1) for i in range(len(num_layer))] \
if num_channels is None else num_channels
self.backbone_output_ids = range(len(num_layer)) \
if backbone_output_ids is None else backbone_output_ids
layers = []
if block_type == 'BottleNeck':
curr_numC = numC_input
for i in range(len(num_layer)):
# 在第一个block中对输入进行downsample
layer = [Bottleneck(inplanes=curr_numC, planes=num_channels[i]//4, stride=stride[i],
downsample=nn.Conv2d(curr_numC, num_channels[i], 3, stride[i], 1),
norm_cfg=norm_cfg)]
curr_numC = num_channels[i]
layer.extend([Bottleneck(inplanes=curr_numC, planes=num_channels[i]//4, stride=1,
downsample=None, norm_cfg=norm_cfg) for _ in range(num_layer[i] - 1)])
layers.append(nn.Sequential(*layer))
elif block_type == 'Basic':
curr_numC = numC_input
for i in range(len(num_layer)):
# 在第一个block中对输入进行downsample
layer = [BasicBlock(inplanes=curr_numC, planes=num_channels[i], stride=stride[i],
downsample=nn.Conv2d(curr_numC, num_channels[i], 3, stride[i], 1),
norm_cfg=norm_cfg)]
curr_numC = num_channels[i]
layer.extend([BasicBlock(inplanes=curr_numC, planes=num_channels[i], stride=1,
downsample=None, norm_cfg=norm_cfg) for _ in range(num_layer[i] - 1)])
layers.append(nn.Sequential(*layer))
else:
assert False
self.layers = nn.Sequential(*layers)
self.with_cp = with_cp
def forward(self, x):
"""
Args:
x: (B, C=64, Dy, Dx)
Returns:
feats: List[
(B, 2*C, Dy/2, Dx/2),
(B, 4*C, Dy/4, Dx/4),
(B, 8*C, Dy/8, Dx/8),
]
"""
feats = []
x_tmp = x
for lid, layer in enumerate(self.layers):
if self.with_cp:
x_tmp = checkpoint.checkpoint(layer, x_tmp)
else:
x_tmp = layer(x_tmp)
if lid in self.backbone_output_ids:
feats.append(x_tmp)
return feats
class BasicBlock3D(nn.Module):
def __init__(self,
channels_in, channels_out, stride=1, downsample=None):
super(BasicBlock3D, self).__init__()
self.conv1 = ConvModule(
channels_in,
channels_out,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d', ),
act_cfg=dict(type='ReLU',inplace=True))
self.conv2 = ConvModule(
channels_out,
channels_out,
kernel_size=3,
stride=1,
padding=1,
bias=False,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d', ),
act_cfg=None)
self.downsample = downsample
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
if self.downsample is not None:
identity = self.downsample(x)
else:
identity = x
x = self.conv1(x)
x = self.conv2(x)
x = x + identity
return self.relu(x)
@BACKBONES.register_module()
class CustomResNet3D(nn.Module):
def __init__(
self,
numC_input,
num_layer=[2, 2, 2],
num_channels=None,
stride=[2, 2, 2],
backbone_output_ids=None,
with_cp=False,
):
super(CustomResNet3D, self).__init__()
# build backbone
assert len(num_layer) == len(stride)
num_channels = [numC_input * 2 ** (i + 1) for i in range(len(num_layer))] \
if num_channels is None else num_channels
self.backbone_output_ids = range(len(num_layer)) \
if backbone_output_ids is None else backbone_output_ids
layers = []
curr_numC = numC_input
for i in range(len(num_layer)):
layer = [
BasicBlock3D(
curr_numC,
num_channels[i],
stride=stride[i],
downsample=ConvModule(
curr_numC,
num_channels[i],
kernel_size=3,
stride=stride[i],
padding=1,
bias=False,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d', ),
act_cfg=None))
]
curr_numC = num_channels[i]
layer.extend([
BasicBlock3D(curr_numC, curr_numC)
for _ in range(num_layer[i] - 1)
])
layers.append(nn.Sequential(*layer))
self.layers = nn.Sequential(*layers)
self.with_cp = with_cp
def forward(self, x):
"""
Args:
x: (B, C, Dz, Dy, Dx)
Returns:
feats: List[
(B, C, Dz, Dy, Dx),
(B, 2C, Dz/2, Dy/2, Dx/2),
(B, 4C, Dz/4, Dy/4, Dx/4),
]
"""
feats = []
x_tmp = x
for lid, layer in enumerate(self.layers):
if self.with_cp:
x_tmp = checkpoint.checkpoint(layer, x_tmp)
else:
x_tmp = layer(x_tmp)
if lid in self.backbone_output_ids:
feats.append(x_tmp)
return feats
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer, trunc_normal_init, build_conv_layer
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmcv.cnn.utils.weight_init import constant_init
from mmcv.runner import _load_checkpoint
from mmcv.runner.base_module import BaseModule, ModuleList
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm
import torch.utils.checkpoint as checkpoint
from mmseg.ops import resize
from mmdet3d.utils import get_root_logger
from mmdet3d.models.builder import BACKBONES
from mmcv.cnn.bricks.registry import ATTENTION
from torch.nn.modules.utils import _pair as to_2tuple
from collections import OrderedDict
def swin_convert(ckpt):
new_ckpt = OrderedDict()
def correct_unfold_reduction_order(x):
out_channel, in_channel = x.shape
x = x.reshape(out_channel, 4, in_channel // 4)
x = x[:, [0, 2, 1, 3], :].transpose(1,
2).reshape(out_channel, in_channel)
return x
def correct_unfold_norm_order(x):
in_channel = x.shape[0]
x = x.reshape(4, in_channel // 4)
x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
return x
for k, v in ckpt.items():
if k.startswith('head'):
continue
elif k.startswith('layers'):
new_v = v
if 'attn.' in k:
new_k = k.replace('attn.', 'attn.w_msa.')
elif 'mlp.' in k:
if 'mlp.fc1.' in k:
new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
elif 'mlp.fc2.' in k:
new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
else:
new_k = k.replace('mlp.', 'ffn.')
elif 'downsample' in k:
new_k = k
if 'reduction.' in k:
new_v = correct_unfold_reduction_order(v)
elif 'norm.' in k:
new_v = correct_unfold_norm_order(v)
else:
new_k = k
new_k = new_k.replace('layers', 'stages', 1)
elif k.startswith('patch_embed'):
new_v = v
if 'proj' in k:
new_k = k.replace('proj', 'projection')
else:
new_k = k
else:
new_v = v
new_k = k
new_ckpt[new_k] = new_v
return new_ckpt
# Modified from Pytorch-Image-Models
class PatchEmbed(BaseModule):
"""Image to Patch Embedding V2.
We use a conv layer to implement PatchEmbed.
Args:
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_type (dict, optional): The config dict for conv layers type
selection. Default: None.
kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int): The slide stride of embedding conv.
Default: None (Default to be equal with kernel_size).
padding (int): The padding length of embedding conv. Default: 0.
dilation (int): The dilation rate of embedding conv. Default: 1.
pad_to_patch_size (bool, optional): Whether to pad feature map shape
to multiple patch size. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
in_channels=3,
embed_dims=768,
conv_type=None,
kernel_size=16,
stride=16,
padding=0,
dilation=1,
pad_to_patch_size=True,
norm_cfg=None,
init_cfg=None):
super(PatchEmbed, self).__init__()
self.embed_dims = embed_dims
self.init_cfg = init_cfg
if stride is None:
stride = kernel_size
self.pad_to_patch_size = pad_to_patch_size
# The default setting of patch size is equal to kernel size.
patch_size = kernel_size
if isinstance(patch_size, int):
patch_size = to_2tuple(patch_size)
elif isinstance(patch_size, tuple):
if len(patch_size) == 1:
patch_size = to_2tuple(patch_size[0])
assert len(patch_size) == 2, \
f'The size of patch should have length 1 or 2, ' \
f'but got {len(patch_size)}'
self.patch_size = patch_size
# Use conv layer to embed
conv_type = conv_type or 'Conv2d'
self.projection = build_conv_layer(
dict(type=conv_type),
in_channels=in_channels,
out_channels=embed_dims,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
def forward(self, x):
H, W = x.shape[2], x.shape[3]
# TODO: Process overlapping op
if self.pad_to_patch_size:
# Modify H, W to multiple of patch size.
if H % self.patch_size[0] != 0:
x = F.pad(
x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
if W % self.patch_size[1] != 0:
x = F.pad(
x, (0, self.patch_size[1] - W % self.patch_size[1], 0, 0))
x = self.projection(x)
self.DH, self.DW = x.shape[2], x.shape[3]
x = x.flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x
class PatchMerging(BaseModule):
"""Merge patch feature map.
This layer use nn.Unfold to group feature map by kernel_size, and use norm
and linear layer to embed grouped feature map.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.
stride (int | tuple): the stride of the sliding length in the
unfold layer. Defaults: 2. (Default to be equal with kernel_size).
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults: False.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults: dict(type='LN').
init_cfg (dict, optional): The extra config for initialization.
Defaults: None.
"""
def __init__(self,
in_channels,
out_channels,
stride=2,
bias=False,
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.sampler = nn.Unfold(
kernel_size=stride, dilation=1, padding=0, stride=stride)
sample_dim = stride**2 * in_channels
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
def forward(self, x, hw_shape):
"""
x: x.shape -> [B, H*W, C]
hw_shape: (H, W)
"""
B, L, C = x.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
# stride is fixed to be equal to kernel_size.
if (H % self.stride != 0) or (W % self.stride != 0):
x = F.pad(x, (0, W % self.stride, 0, H % self.stride))
# Use nn.Unfold to merge patch. About 25% faster than original method,
# but need to modify pretrained model for compatibility
x = self.sampler(x) # B, 4*C, H/2*W/2
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
x = self.norm(x) if self.norm else x
x = self.reduction(x)
down_hw_shape = (H + 1) // 2, (W + 1) // 2
return x, down_hw_shape
@ATTENTION.register_module()
class WindowMSA(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Args:
embed_dims (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.0
init_cfg (dict | None, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
init_cfg=None):
super().__init__()
self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
self.init_cfg = init_cfg
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# About 2x faster than original impl
Wh, Ww = self.window_size
rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
rel_position_index = rel_index_coords + rel_index_coords.T
rel_position_index = rel_position_index.flip(1).contiguous()
self.register_buffer('relative_position_index', rel_position_index)
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
self.softmax = nn.Softmax(dim=-1)
def init_weights(self):
trunc_normal_init(self.relative_position_bias_table, std=0.02)
def forward(self, x, mask=None):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C)
mask (tensor | None, Optional): mask with shape of (num_windows,
Wh*Ww, Wh*Ww), value should be between (-inf, 0].
"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@staticmethod
def double_step_seq(step1, len1, step2, len2):
seq1 = torch.arange(0, step1 * len1, step1)
seq2 = torch.arange(0, step2 * len2, step2)
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
"""Shift Window Multihead Self-Attention Module.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window.
shift_size (int, optional): The shift step of each window towards
right-bottom. If zero, act as regular window-msa. Defaults to 0.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Defaults: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Defaults: 0.
proj_drop_rate (float, optional): Dropout ratio of output.
Defaults: 0.
dropout_layer (dict, optional): The dropout_layer used before output.
Defaults: dict(type='DropPath', drop_prob=0.).
init_cfg (dict, optional): The extra config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
shift_size=0,
qkv_bias=True,
qk_scale=None,
attn_drop_rate=0,
proj_drop_rate=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
init_cfg=None):
super().__init__(init_cfg)
self.window_size = window_size
self.shift_size = shift_size
assert 0 <= self.shift_size < self.window_size
self.w_msa = WindowMSA(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=to_2tuple(window_size),
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=proj_drop_rate,
init_cfg=None)
self.drop = build_dropout(dropout_layer)
def forward(self, query, hw_shape):
B, L, C = query.shape
H, W = hw_shape
assert L == H * W, 'input feature has wrong size'
query = query.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
H_pad, W_pad = query.shape[1], query.shape[2]
# cyclic shift
if self.shift_size > 0:
shifted_query = torch.roll(
query,
shifts=(-self.shift_size, -self.shift_size),
dims=(1, 2))
# calculate attention mask for SW-MSA
img_mask = torch.zeros((1, H_pad, W_pad, 1),
device=query.device) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size,
-self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = self.window_partition(img_mask)
mask_windows = mask_windows.view(
-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))
else:
shifted_query = query
attn_mask = None
# nW*B, window_size, window_size, C
query_windows = self.window_partition(shifted_query)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, self.window_size**2, C)
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size,
self.window_size, C)
# B H' W' C
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x,
shifts=(self.shift_size, self.shift_size),
dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = self.drop(x)
return x
def window_reverse(self, windows, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
window_size = self.window_size
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
def window_partition(self, x):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
window_size = self.window_size
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
class SwinBlock(BaseModule):
""""
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
window size (int, optional): The local window scale. Default: 7.
shift (bool): whether to shift window or not. Default False.
qkv_bias (int, optional): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of nomalization.
Default: dict(type='LN').
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
window_size=7,
shift=False,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(SwinBlock, self).__init__()
self.init_cfg = init_cfg
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ShiftWindowMSA(
embed_dims=embed_dims,
num_heads=num_heads,
window_size=window_size,
shift_size=window_size // 2 if shift else 0,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop_rate=attn_drop_rate,
proj_drop_rate=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
init_cfg=None)
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=2,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=True,
init_cfg=None)
def forward(self, x, hw_shape):
identity = x
x = self.norm1(x)
x = self.attn(x, hw_shape)
x = x + identity
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
class SwinBlockSequence(BaseModule):
"""Implements one stage in Swin Transformer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
depth (int): The number of blocks in this stage.
window size (int): The local window scale. Default: 7.
qkv_bias (int): enable bias for qkv if True. Default: True.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
drop_rate (float, optional): Dropout rate. Default: 0.
attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
drop_path_rate (float, optional): Stochastic depth rate. Default: 0.2.
downsample (BaseModule | None, optional): The downsample operation
module. Default: None.
act_cfg (dict, optional): The config dict of activation function.
Default: dict(type='GELU').
norm_cfg (dict, optional): The config dict of nomalization.
Default: dict(type='LN').
init_cfg (dict | list | None, optional): The init config.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
depth,
window_size=7,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
downsample=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None,
with_cp=True):
super().__init__()
self.init_cfg = init_cfg
drop_path_rate = drop_path_rate if isinstance(
drop_path_rate,
list) else [deepcopy(drop_path_rate) for _ in range(depth)]
self.blocks = ModuleList()
for i in range(depth):
block = SwinBlock(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=feedforward_channels,
window_size=window_size,
shift=False if i % 2 == 0 else True,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate[i],
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=None)
self.blocks.append(block)
self.downsample = downsample
self.with_cp = with_cp
def forward(self, x, hw_shape):
for block in self.blocks:
if self.with_cp:
x = checkpoint.checkpoint(block, x, hw_shape)
else:
x = block(x, hw_shape)
if self.downsample:
x_down, down_hw_shape = self.downsample(x, hw_shape)
return x_down, down_hw_shape, x, hw_shape
else:
return x, hw_shape, x, hw_shape
@BACKBONES.register_module()
class SwinTransformer(BaseModule):
""" Swin Transformer
A PyTorch implement of : `Swin Transformer:
Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/abs/2103.14030
Inspiration from
https://github.com/microsoft/Swin-Transformer
Args:
pretrain_img_size (int | tuple[int]): The size of input image when
pretrain. Defaults: 224.
in_channels (int): The num of input channels.
Defaults: 3.
embed_dims (int): The feature dimension. Default: 96.
patch_size (int | tuple[int]): Patch size. Default: 4.
window_size (int): Window size. Default: 7.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
Default: 4.
depths (tuple[int]): Depths of each Swin Transformer stage.
Default: (2, 2, 6, 2).
num_heads (tuple[int]): Parallel attention heads of each Swin
Transformer stage. Default: (3, 6, 12, 24).
strides (tuple[int]): The patch merging or patch embedding stride of
each Swin Transformer stage. (In swin, we set kernel size equal to
stride.) Default: (4, 2, 2, 2).
out_indices (tuple[int]): Output from which stages.
Default: (0, 1, 2, 3).
qkv_bias (bool, optional): If True, add a learnable bias to query, key,
value. Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
patch_norm (bool): If add a norm layer for patch embed and patch
merging. Default: True.
drop_rate (float): Dropout rate. Defaults: 0.
attn_drop_rate (float): Attention dropout rate. Default: 0.
drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults: False.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='LN').
norm_cfg (dict): Config dict for normalization layer at
output of backone. Defaults: dict(type='LN').
pretrain_style (str): Choose to use official or mmcls pretrain weights.
Default: official.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
pretrain_img_size=224,
in_channels=3,
embed_dims=96,
patch_size=4,
window_size=7,
mlp_ratio=4,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
strides=(4, 2, 2, 2),
out_indices=(0, 1, 2, 3),
qkv_bias=True,
qk_scale=None,
patch_norm=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
pretrain_style='official',
pretrained=None,
init_cfg=None,
with_cp=True,
return_stereo_feat=False,
output_missing_index_as_none=False,
frozen_stages=-1):
super(SwinTransformer, self).__init__()
if isinstance(pretrain_img_size, int):
pretrain_img_size = to_2tuple(pretrain_img_size)
elif isinstance(pretrain_img_size, tuple):
if len(pretrain_img_size) == 1:
pretrain_img_size = to_2tuple(pretrain_img_size[0])
assert len(pretrain_img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(pretrain_img_size)}'
assert pretrain_style in ['official', 'mmcls'], 'We only support load '
'official ckpt and mmcls ckpt.'
if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
else:
raise TypeError('pretrained must be a str or None')
num_layers = len(depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.pretrain_style = pretrain_style
self.pretrained = pretrained
self.init_cfg = init_cfg
self.frozen_stages = frozen_stages
assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dims=embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=strides[0],
pad_to_patch_size=True,
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
if self.use_abs_pos_embed:
patch_row = pretrain_img_size[0] // patch_size
patch_col = pretrain_img_size[1] // patch_size
num_patches = patch_row * patch_col
self.absolute_pos_embed = nn.Parameter(
torch.zeros((1, num_patches, embed_dims)))
self.drop_after_pos = nn.Dropout(p=drop_rate)
# stochastic depth
total_depth = sum(depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
self.stages = ModuleList()
in_channels = embed_dims
for i in range(num_layers):
if i < num_layers - 1:
downsample = PatchMerging(
in_channels=in_channels,
out_channels=2 * in_channels,
stride=strides[i + 1],
norm_cfg=norm_cfg if patch_norm else None,
init_cfg=None)
else:
downsample = None
stage = SwinBlockSequence(
embed_dims=in_channels,
num_heads=num_heads[i],
feedforward_channels=mlp_ratio * in_channels,
depth=depths[i],
window_size=window_size,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[:depths[i]],
downsample=downsample,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=None,
with_cp=with_cp)
self.stages.append(stage)
dpr = dpr[depths[i]:]
if downsample:
in_channels = downsample.out_channels
self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
# Add a norm layer for each output
for i in out_indices:
layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
layer_name = f'norm{i}'
self.add_module(layer_name, layer)
self.output_missing_index_as_none = output_missing_index_as_none
self._freeze_stages()
self.return_stereo_feat = return_stereo_feat
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
if self.frozen_stages >= 1 and self.use_abs_pos_embed:
self.absolute_pos_embed.requires_grad = False
if self.frozen_stages >= 2:
self.drop_after_pos.eval()
for i in range(0, self.frozen_stages - 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self):
if self.pretrained is None:
super().init_weights()
if self.use_abs_pos_embed:
trunc_normal_init(self.absolute_pos_embed, std=0.02)
for m in self.modules():
if isinstance(m, Linear):
trunc_normal_init(m.weight, std=.02)
if m.bias is not None:
constant_init(m.bias, 0)
elif isinstance(m, LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
elif isinstance(self.pretrained, str):
logger = get_root_logger()
ckpt = _load_checkpoint(
self.pretrained, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
state_dict = ckpt['state_dict']
elif 'model' in ckpt:
state_dict = ckpt['model']
else:
state_dict = ckpt
if self.pretrain_style == 'official':
state_dict = swin_convert(state_dict)
# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
state_dict = {k[7:]: v for k, v in state_dict.items()}
# if list(state_dict.keys())[0].startswith('backbone.'):
# state_dict = {k[9:]: v for k, v in state_dict.items()}
# reshape absolute position embedding
if state_dict.get('absolute_pos_embed') is not None:
absolute_pos_embed = state_dict['absolute_pos_embed']
N1, L, C1 = absolute_pos_embed.size()
N2, C2, H, W = self.absolute_pos_embed.size()
if N1 != N2 or C1 != C2 or L != H * W:
logger.warning('Error in loading absolute_pos_embed, pass')
else:
state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
# interpolate position bias table if needed
relative_position_bias_table_keys = [
k for k in state_dict.keys()
if 'relative_position_bias_table' in k
]
for table_key in relative_position_bias_table_keys:
table_pretrained = state_dict[table_key]
table_current = self.state_dict()[table_key]
L1, nH1 = table_pretrained.size()
L2, nH2 = table_current.size()
if nH1 != nH2:
logger.warning(f'Error in loading {table_key}, pass')
else:
if L1 != L2:
S1 = int(L1**0.5)
S2 = int(L2**0.5)
table_pretrained_resized = resize(
table_pretrained.permute(1, 0).reshape(
1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
state_dict[table_key] = table_pretrained_resized.view(
nH2, L2).permute(1, 0).contiguous()
# load state_dict
self.load_state_dict(state_dict, False)
def forward(self, x):
x = self.patch_embed(x)
hw_shape = (self.patch_embed.DH, self.patch_embed.DW)
if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
if i == 0 and self.return_stereo_feat:
out = out.view(-1, *out_hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(out)
out = out.view(-1, *out_hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
elif self.output_missing_index_as_none:
outs.append(None)
return outs
def train(self, mode=True):
"""Convert the model into training mode while keep normalization layer
freezed."""
super(SwinTransformer, self).train(mode)
self._freeze_stages()
\ No newline at end of file
from .bev_centerpoint_head import BEV_CenterHead, Centerness_Head
from .bev_occ_head import BEVOCCHead2D, BEVOCCHead3D, BEVOCCHead2D_V2
__all__ = ['Centerness_Head', 'BEV_CenterHead', 'BEVOCCHead2D', 'BEVOCCHead3D', 'BEVOCCHead2D_V2']
\ No newline at end of file
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import torch
from mmcv.cnn import ConvModule, build_conv_layer
from mmcv.runner import BaseModule
from torch import nn
from mmdet3d.core import (circle_nms, draw_heatmap_gaussian, gaussian_radius,
xywhr2xyxyr)
from ...core.post_processing import nms_bev
from mmdet3d.models import builder
from mmdet3d.models.utils import clip_sigmoid
from mmdet.core import build_bbox_coder, multi_apply, reduce_mean
from mmdet3d.models.builder import HEADS, build_loss
@HEADS.register_module(force=True)
class SeparateHead(BaseModule):
"""SeparateHead for CenterHead.
Args:
in_channels (int): Input channels for conv_layer.
heads (dict): Conv information.
head_conv (int, optional): Output channels.
Default: 64.
final_kernel (int, optional): Kernel size for the last conv layer.
Default: 1.
init_bias (float, optional): Initial bias. Default: -2.19.
conv_cfg (dict, optional): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict, optional): Config of norm layer.
Default: dict(type='BN2d').
bias (str, optional): Type of bias. Default: 'auto'.
"""
def __init__(self,
in_channels,
heads,
head_conv=64,
final_kernel=1,
init_bias=-2.19,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
init_cfg=None,
**kwargs):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(SeparateHead, self).__init__(init_cfg=init_cfg)
self.heads = heads
self.init_bias = init_bias
for head in self.heads:
# 该head的输出通道和卷积数量.
classes, num_conv = self.heads[head]
conv_layers = []
c_in = in_channels
for i in range(num_conv - 1):
conv_layers.append(
ConvModule(
c_in,
head_conv,
kernel_size=final_kernel,
stride=1,
padding=final_kernel // 2,
bias=bias,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg))
c_in = head_conv
conv_layers.append(
build_conv_layer(
conv_cfg,
head_conv,
classes,
kernel_size=final_kernel,
stride=1,
padding=final_kernel // 2,
bias=True))
conv_layers = nn.Sequential(*conv_layers)
self.__setattr__(head, conv_layers)
if init_cfg is None:
self.init_cfg = dict(type='Kaiming', layer='Conv2d')
def init_weights(self):
"""Initialize weights."""
super().init_weights()
for head in self.heads:
if head == 'heatmap':
self.__getattr__(head)[-1].bias.data.fill_(self.init_bias)
def forward(self, x):
"""Forward function for SepHead.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
dict[str: torch.Tensor]: contains the following keys:
-reg (torch.Tensor): 2D regression value with the
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the
shape of [B, 1, H, W].
-dim (torch.Tensor): Size value with the shape
of [B, 3, H, W].
-rot (torch.Tensor): Rotation value with the
shape of [B, 2, H, W].
-vel (torch.Tensor): Velocity value with the
shape of [B, 2, H, W].
-heatmap (torch.Tensor): Heatmap with the shape of
[B, N, H, W].
"""
ret_dict = dict()
for head in self.heads:
ret_dict[head] = self.__getattr__(head)(x)
return ret_dict
@HEADS.register_module(force=True)
class DCNSeparateHead(BaseModule):
r"""DCNSeparateHead for CenterHead.
.. code-block:: none
/-----> DCN for heatmap task -----> heatmap task.
feature
\-----> DCN for regression tasks -----> regression tasks
Args:
in_channels (int): Input channels for conv_layer.
num_cls (int): Number of classes.
heads (dict): Conv information.
dcn_config (dict): Config of dcn layer.
head_conv (int, optional): Output channels.
Default: 64.
final_kernel (int, optional): Kernel size for the last conv
layer. Default: 1.
init_bias (float, optional): Initial bias. Default: -2.19.
conv_cfg (dict, optional): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict, optional): Config of norm layer.
Default: dict(type='BN2d').
bias (str, optional): Type of bias. Default: 'auto'.
""" # noqa: W605
def __init__(self,
in_channels,
num_cls,
heads,
dcn_config,
head_conv=64,
final_kernel=1,
init_bias=-2.19,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
init_cfg=None,
**kwargs):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(DCNSeparateHead, self).__init__(init_cfg=init_cfg)
if 'heatmap' in heads:
heads.pop('heatmap')
# feature adaptation with dcn
# use separate features for classification / regression
self.feature_adapt_cls = build_conv_layer(dcn_config)
self.feature_adapt_reg = build_conv_layer(dcn_config)
# heatmap prediction head
cls_head = [
ConvModule(
in_channels,
head_conv,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
bias=bias,
norm_cfg=norm_cfg),
build_conv_layer(
conv_cfg,
head_conv,
num_cls,
kernel_size=3,
stride=1,
padding=1,
bias=bias)
]
self.cls_head = nn.Sequential(*cls_head)
self.init_bias = init_bias
# other regression target
self.task_head = SeparateHead(
in_channels,
heads,
head_conv=head_conv,
final_kernel=final_kernel,
bias=bias)
if init_cfg is None:
self.init_cfg = dict(type='Kaiming', layer='Conv2d')
def init_weights(self):
"""Initialize weights."""
super().init_weights()
self.cls_head[-1].bias.data.fill_(self.init_bias)
def forward(self, x):
"""Forward function for DCNSepHead.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
dict[str: torch.Tensor]: contains the following keys:
-reg (torch.Tensor): 2D regression value with the
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the
shape of [B, 1, H, W].
-dim (torch.Tensor): Size value with the shape
of [B, 3, H, W].
-rot (torch.Tensor): Rotation value with the
shape of [B, 2, H, W].
-vel (torch.Tensor): Velocity value with the
shape of [B, 2, H, W].
-heatmap (torch.Tensor): Heatmap with the shape of
[B, N, H, W].
"""
center_feat = self.feature_adapt_cls(x)
reg_feat = self.feature_adapt_reg(x)
cls_score = self.cls_head(center_feat)
ret = self.task_head(reg_feat)
ret['heatmap'] = cls_score
return ret
@HEADS.register_module()
class BEV_CenterHead(BaseModule):
"""CenterHead for CenterPoint.
Args:
in_channels (list[int] | int, optional): Channels of the input
feature map. Default: [128].
tasks (list[dict], optional): Task information including class number
and class names. Default: None.
train_cfg (dict, optional): Train-time configs. Default: None.
test_cfg (dict, optional): Test-time configs. Default: None.
bbox_coder (dict, optional): Bbox coder configs. Default: None.
common_heads (dict, optional): Conv information for common heads.
Default: dict().
loss_cls (dict, optional): Config of classification loss function.
Default: dict(type='GaussianFocalLoss', reduction='mean').
loss_bbox (dict, optional): Config of regression loss function.
Default: dict(type='L1Loss', reduction='none').
separate_head (dict, optional): Config of separate head. Default: dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3)
share_conv_channel (int, optional): Output channels for share_conv
layer. Default: 64.
num_heatmap_convs (int, optional): Number of conv layers for heatmap
conv layer. Default: 2.
conv_cfg (dict, optional): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict, optional): Config of norm layer.
Default: dict(type='BN2d').
bias (str, optional): Type of bias. Default: 'auto'.
"""
def __init__(self,
in_channels=[128],
tasks=None,
train_cfg=None,
test_cfg=None,
bbox_coder=None,
common_heads=dict(),
loss_cls=dict(type='GaussianFocalLoss', reduction='mean'),
loss_bbox=dict(
type='L1Loss', reduction='none', loss_weight=0.25),
separate_head=dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3),
share_conv_channel=64,
num_heatmap_convs=2,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
norm_bbox=True,
init_cfg=None,
task_specific=True):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(BEV_CenterHead, self).__init__(init_cfg=init_cfg)
num_classes = [len(t['class_names']) for t in tasks] # 记录不同task(SeparateHead)负责检测的类别数.
self.class_names = [t['class_names'] for t in tasks] # 记录不同task(SeparateHead)负责检测的类别名.
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.in_channels = in_channels
self.num_classes = num_classes
self.norm_bbox = norm_bbox
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes]
self.fp16_enabled = False
# a shared convolution
self.shared_conv = ConvModule(
in_channels,
share_conv_channel,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=bias)
# 每个task建立对应的head.
self.task_heads = nn.ModuleList()
for num_cls in num_classes:
# common_heads = dict(
# reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
heads = copy.deepcopy(common_heads)
heads.update(dict(heatmap=(num_cls, num_heatmap_convs)))
separate_head.update(
in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
self.task_heads.append(builder.build_head(separate_head))
self.with_velocity = 'vel' in common_heads.keys()
self.task_specific = task_specific
def forward_single(self, x):
"""Forward function for CenterPoint.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
list[dict]: Output results for tasks.
"""
ret_dicts = []
x = self.shared_conv(x) # (B, C'=share_conv_channel, H, W)
# 运行不同task_head,
for task in self.task_heads:
ret_dicts.append(task(x))
# ret_dicts: [dict0, dict1, ...] len = SeparateHead的数量
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
return ret_dicts
def forward(self, feats):
"""Forward pass.
Args:
feats (list[torch.Tensor]): Multi-level features, e.g.,
features produced by FPN.
Returns:
results: Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
), len = SeparateHead的数量, 负责预测指定类别的目标.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
"""
return multi_apply(self.forward_single, feats)
def _gather_feat(self, feat, ind, mask=None):
"""Gather feature map.
Given feature map and index, return indexed feature map.
Args:
feat (torch.tensor): Feature map with the shape of [B, H*W, 10].
ind (torch.Tensor): Index of the ground truth boxes with the
shape of [B, max_obj].
mask (torch.Tensor, optional): Mask of the feature map with the
shape of [B, max_obj]. Default: None.
Returns:
torch.Tensor: Feature map after gathering with the shape
of [B, max_obj, 10].
"""
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def get_targets(self, gt_bboxes_3d, gt_labels_3d):
"""Generate targets.
How each output is transformed:
Each nested list is transposed so that all same-index elements in
each sub-list (1, ..., N) become the new sub-lists.
[ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
The new transposed nested list is converted into a list of N
tensors generated by concatenating tensors in the new sub-lists.
[ tensor0, tensor1, tensor2, ... ]
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes. # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
gt_labels_3d (list[torch.Tensor]): Labels of boxes. # List[(N_gt0, ), (N_gt1, ), ...]
Returns:
Returns:
tuple[list[torch.Tensor]]: (
heatmaps: List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
anno_boxes:
inds:
masks:
)
"""
heatmaps, anno_boxes, inds, masks = multi_apply(
self.get_targets_single, gt_bboxes_3d, gt_labels_3d)
# heatmaps: # Tuple(List[(N_cls0, H, W), (N_cls1, H, W), ...], ...) len = batch_size
# anno_boxes: # Tuple(List[(max_objs, 10), (max_objs, 10), ...], ...) len = batch_size
# inds: # Tuple(List[(max_objs, ), (max_objs, ), ...], ...)
# masks: # Tuple(List[(max_objs, ), (max_objs, ), ...], ...)
# Transpose heatmaps
# List[List[(N_cls0, H, W), (N_cls0, H, W), ...], List[(N_cls1, H, W), (N_cls1, H, W), ...], ...] len = num of SeparateHead
heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps] # List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
# Transpose anno_boxes
anno_boxes = list(map(list, zip(*anno_boxes)))
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes] # List[(B, max_objs, 10), (B, max_objs, 10), ...] len = num of SeparateHead
# Transpose inds
inds = list(map(list, zip(*inds)))
inds = [torch.stack(inds_) for inds_ in inds] # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
# Transpose inds
masks = list(map(list, zip(*masks)))
masks = [torch.stack(masks_) for masks_ in masks] # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
return heatmaps, anno_boxes, inds, masks
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
"""Generate training targets for a single sample.
Args:
gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes. # (N_gt, 7/9)
gt_labels_3d (torch.Tensor): Labels of boxes. # (N_gt, )
Returns:
tuple[list[torch.Tensor]]: Tuple of target including
the following results in order.
- heatmaps: list[torch.Tensor]: Heatmap scores. # List[(N_cls0, H, W), (N_cls1, H, W), ...]
len = num of tasks
- anno_boxes: list[torch.Tensor]: Ground truth boxes. # List[(max_objs, 10), (max_objs, 10), ...]
- inds: list[torch.Tensor]: Indexes indicating the position
of the valid boxes. # List[(max_objs, ), (max_objs, ), ...]
- masks: list[torch.Tensor]: Masks indicating which boxes
are valid. # List[(max_objs, ), (max_objs, ), ...]
"""
device = gt_labels_3d.device
gt_bboxes_3d = torch.cat(
(gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
dim=1).to(device) # (N_gt, 7/9)
max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
grid_size = torch.tensor(self.train_cfg['grid_size']) # (Dx, Dy, Dz)
pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
voxel_size = torch.tensor(self.train_cfg['voxel_size'])
feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor'] # (W, H)
# reorganize the gt_dict by tasks
task_masks = []
flag = 0
for class_name in self.class_names:
# class_name: 不同task(SeparateHead)负责检测的类别名.
task_masks.append([
torch.where(gt_labels_3d == class_name.index(i) + flag)
for i in class_name
])
flag += len(class_name)
# task_masks: List[task_mask0, task_mask1, ...] len = number of SeparateHeads
# task_mask: List[((N_gt0, ), ), ((N_gt1, ), ), ...] len = number of class
task_boxes = []
task_classes = []
flag2 = 0
for idx, mask in enumerate(task_masks):
# mask: 不同task(SeparateHead)的mask, 每个task负责检测一组不同类别的目标.
# List[((N_gt0, ), ), ((N_gt1, ), ), ...], # N_gt_task=N_gt0+N_gt1+..., 表示当前task负责检测的gt_boxes的数量.
task_box = []
task_class = []
for m in mask:
task_box.append(gt_bboxes_3d[m])
# 0 is background for each task, so we need to add 1 here.
task_class.append(gt_labels_3d[m] + 1 - flag2)
task_boxes.append(torch.cat(task_box, axis=0).to(device))
task_classes.append(torch.cat(task_class).long().to(device))
flag2 += len(mask)
# 记录不同task负责检测的gt_boxes和gt_classes:
# task_boxes: List[(N_gt_task0, 7/9), (N_gt_task1, 7/9), ...]
# task_classes: List[(N_gt_task0, ), (N_gt_task1, ), ...]
draw_gaussian = draw_heatmap_gaussian
heatmaps, anno_boxes, inds, masks = [], [], [], []
for idx, task_head in enumerate(self.task_heads):
heatmap = gt_bboxes_3d.new_zeros(
(len(self.class_names[idx]), feature_map_size[1],
feature_map_size[0])) # (N_cls, H, W) N_cls表示当前task_head负责检测的类别数目.
if self.with_velocity:
anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
dtype=torch.float32) # (max_objs, 10)
else:
anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
dtype=torch.float32)
ind = gt_labels_3d.new_zeros((max_objs, ), dtype=torch.int64) # (max_objs, )
mask = gt_bboxes_3d.new_zeros((max_objs, ), dtype=torch.uint8) # (max_objs, )
num_objs = min(task_boxes[idx].shape[0], max_objs) # 当前task_head负责检测的目标.
for k in range(num_objs):
cls_id = task_classes[idx][k] - 1 # 当前目标的cls_id, cls_id是相对task group内的.
width = task_boxes[idx][k][3] # dx
length = task_boxes[idx][k][4] # dy
# 当前目标在feature map上的width和length
width = width / voxel_size[0] / self.train_cfg[
'out_size_factor']
length = length / voxel_size[1] / self.train_cfg[
'out_size_factor']
if width > 0 and length > 0:
# 计算gaussian半径
radius = gaussian_radius(
(length, width),
min_overlap=self.train_cfg['gaussian_overlap'])
radius = max(self.train_cfg['min_radius'], int(radius))
# be really careful for the coordinate system of
# your box annotation.
x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][
1], task_boxes[idx][k][2] # 当前目标的中心坐标.
# 计算gt_box中心点在feature map中对应的位置.
coor_x = (
x - pc_range[0]
) / voxel_size[0] / self.train_cfg['out_size_factor']
coor_y = (
y - pc_range[1]
) / voxel_size[1] / self.train_cfg['out_size_factor']
center = torch.tensor([coor_x, coor_y],
dtype=torch.float32,
device=device)
center_int = center.to(torch.int32)
# throw out not in range objects to avoid out of array
# area when creating the heatmap
if not (0 <= center_int[0] < feature_map_size[0]
and 0 <= center_int[1] < feature_map_size[1]):
continue
# 根据目标中心点在feature map中对应的位置、高斯半径来设置heatmap.
draw_gaussian(heatmap[cls_id], center_int, radius)
new_idx = k
x, y = center_int[0], center_int[1]
assert (y * feature_map_size[0] + x <
feature_map_size[0] * feature_map_size[1])
# 记录正样本在feature map中的位置.
ind[new_idx] = y * feature_map_size[0] + x
mask[new_idx] = 1
# TODO: support other outdoor dataset
rot = task_boxes[idx][k][6]
box_dim = task_boxes[idx][k][3:6]
if self.norm_bbox:
box_dim = box_dim.log()
if self.with_velocity:
vx, vy = task_boxes[idx][k][7:]
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device), # tx, ty
z.unsqueeze(0), box_dim, # z, log(dx), log(dy), log(dz)
torch.sin(rot).unsqueeze(0), # sin(rot)
torch.cos(rot).unsqueeze(0), # cos(rot)
vx.unsqueeze(0), # vx
vy.unsqueeze(0) # vy
]) # [tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy]
else:
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0)
])
heatmaps.append(heatmap) # append (N_cls, H, W)
anno_boxes.append(anno_box) # append (max_objs, 10)
masks.append(mask) # append (max_objs, )
inds.append(ind) # append (max_objs, )
return heatmaps, anno_boxes, inds, masks
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead.
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes. # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
gt_labels_3d (list[torch.Tensor]): Labels of boxes. # List[(N_gt0, ), (N_gt1, ), ...]
preds_dicts (dict): Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
), len = SeparateHead的数量, 负责预测指定类别的目标.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
Returns:
dict[str:torch.Tensor]: Loss of heatmap and bbox of each task.
"""
heatmaps, anno_boxes, inds, masks = self.get_targets(
gt_bboxes_3d, gt_labels_3d)
# heatmaps: # List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
# anno_boxes: # List[(B, max_objs, 10), (B, max_objs, 10), ...] len = num of SeparateHead
# inds: # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
# masks: # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
loss_dict = dict()
if not self.task_specific:
loss_dict['loss'] = 0
for task_id, preds_dict in enumerate(preds_dicts):
# task_id: SeparateHead idx
# preds_dict: List[dict0, ...] len = num levels, 对于center_point len = 1
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
# heatmap focal loss
preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
num_pos = heatmaps[task_id].eq(1).float().sum().item()
cls_avg_factor = torch.clamp(
reduce_mean(heatmaps[task_id].new_tensor(num_pos)),
min=1).item()
loss_heatmap = self.loss_cls(
preds_dict[0]['heatmap'], # (B, cur_N_cls, H, W)
heatmaps[task_id], # (B, cur_N_cls, H, W)
avg_factor=cls_avg_factor
)
# (B, max_objs, 10) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)
target_box = anno_boxes[task_id]
# reconstruct the anno_box from multiple reg heads
preds_dict[0]['anno_box'] = torch.cat(
(
preds_dict[0]['reg'],
preds_dict[0]['height'],
preds_dict[0]['dim'],
preds_dict[0]['rot'],
preds_dict[0]['vel'],
),
dim=1,
) # (B, 10, H, W) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)
# Regression loss for dimension, offset, height, rotation
num = masks[task_id].float().sum() # 正样本的数量
ind = inds[task_id] # (B, max_objs)
pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous() # (B, H, W, 10)
pred = pred.view(pred.size(0), -1, pred.size(3)) # (B, H*W, 10)
pred = self._gather_feat(pred, ind) # (B, max_objs, 10)
# (B, max_objs) --> (B, max_objs, 1) --> (B, max_objs, 10)
mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
num = torch.clamp(
reduce_mean(target_box.new_tensor(num)), min=1e-4).item()
isnotnan = (~torch.isnan(target_box)).float()
mask *= isnotnan # 只监督mask指定的reg预测.
code_weights = self.train_cfg['code_weights']
bbox_weights = mask * mask.new_tensor(code_weights) # 在mask基础上,设置box不同属性的权重. (B, max_objs, 10)
if self.task_specific:
name_list = ['xy', 'z', 'whl', 'yaw', 'vel']
clip_index = [0, 2, 3, 6, 8, 10]
for reg_task_id in range(len(name_list)):
pred_tmp = pred[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
target_box_tmp = target_box[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
bbox_weights_tmp = bbox_weights[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
loss_bbox_tmp = self.loss_bbox(
pred_tmp,
target_box_tmp,
bbox_weights_tmp,
avg_factor=(num + 1e-4))
loss_dict[f'task{task_id}.loss_%s' %
(name_list[reg_task_id])] = loss_bbox_tmp
loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap
else:
loss_bbox = self.loss_bbox(
pred, target_box, bbox_weights, avg_factor=num)
loss_dict['loss'] += loss_bbox
loss_dict['loss'] += loss_heatmap
return loss_dict
def get_bboxes(self, preds_dicts, img_metas, img=None, rescale=False):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
), len = SeparateHead的数量, 负责预测指定类别的目标.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
img_metas (list[dict]): Point cloud and image's meta info.
Returns:
list[dict]: Decoded bbox, scores and labels after nms.
ret_list: List[p_list0, p_list1, ...]
p_list: List[(N, 9), (N, ), (N, )]
"""
rets = []
for task_id, preds_dict in enumerate(preds_dicts):
# task_id: SeparateHead idx
# preds_dict: List[dict0, ...] len = num levels, 对于center_point len = 1
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
batch_size = preds_dict[0]['heatmap'].shape[0]
batch_heatmap = preds_dict[0]['heatmap'].sigmoid() # (B, n_cls, H, W)
batch_reg = preds_dict[0]['reg'] # (B, 2, H, W)
batch_hei = preds_dict[0]['height'] # (B, 1, H, W)
if self.norm_bbox:
batch_dim = torch.exp(preds_dict[0]['dim']) # (B, 3, H, W)
else:
batch_dim = preds_dict[0]['dim']
batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1) # (B, 1, H, W)
batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1) # (B, 1, H, W)
if 'vel' in preds_dict[0]:
batch_vel = preds_dict[0]['vel'] # (B, 2, H, W)
else:
batch_vel = None
temp = self.bbox_coder.decode(
batch_heatmap,
batch_rots,
batch_rotc,
batch_hei,
batch_dim,
batch_vel,
reg=batch_reg,
task_id=task_id)
# temp: List[p_dict0, p_dict1, ...] len=bs
# p_dict = {
# 'bboxes': boxes3d, # (K', 9)
# 'scores': scores, # (K', )
# 'labels': labels # (K', )
# }
batch_reg_preds = [box['bboxes'] for box in temp] # List[(K0, 9), (K1, 9), ...] len = bs
batch_cls_preds = [box['scores'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
batch_cls_labels = [box['labels'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
nms_type = self.test_cfg.get('nms_type')
if isinstance(nms_type, list):
nms_type = nms_type[task_id]
if nms_type == 'circle':
ret_task = []
for i in range(batch_size):
boxes3d = temp[i]['bboxes']
scores = temp[i]['scores']
labels = temp[i]['labels']
centers = boxes3d[:, [0, 1]]
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
keep = torch.tensor(
circle_nms(
boxes.detach().cpu().numpy(),
self.test_cfg['min_radius'][task_id],
post_max_size=self.test_cfg['post_max_size']),
dtype=torch.long,
device=boxes.device)
boxes3d = boxes3d[keep]
scores = scores[keep]
labels = labels[keep]
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
ret_task.append(ret)
rets.append(ret_task)
else:
rets.append(
self.get_task_detections(batch_cls_preds, batch_reg_preds,
batch_cls_labels, img_metas,
task_id))
# rets: List[ret_task0, ret_task1, ...], len = num_tasks
# ret_task: List[p_dict0, p_dict1, ...], len = batch_size
# p_dict: dict{
# bboxes: (K', 9)
# scores: (K', )
# labels: (K', )
# }
# Merge branches results
num_samples = len(rets[0]) # bs
ret_list = []
# 遍历batch, 然后汇总所有task的预测.
for i in range(num_samples):
for k in rets[0][i].keys():
if k == 'bboxes':
bboxes = torch.cat([ret[i][k] for ret in rets]) # 对于bboxes, 直接拼接即可.
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = img_metas[i]['box_type_3d'](
bboxes, self.bbox_coder.code_size)
elif k == 'scores':
scores = torch.cat([ret[i][k] for ret in rets]) # 对于scores, 直接拼接即可.
elif k == 'labels':
flag = 0
for j, num_class in enumerate(self.num_classes): # 对于labels, 要进行调整, 因为预测的label是task组内的.
rets[j][i][k] += flag
flag += num_class
labels = torch.cat([ret[i][k].int() for ret in rets])
ret_list.append([bboxes, scores, labels])
# ret_list: List[p_list0, p_list1, ...]
# p_list: List[(N, 9), (N, ), (N, )]
return ret_list
def get_task_detections(self, batch_cls_preds,
batch_reg_preds, batch_cls_labels, img_metas,
task_id):
"""Rotate nms for each task.
Args:
batch_cls_preds (list[torch.Tensor]): Prediction score with the
shape of [N]. # List[(K0, ), (K1, ), ...] len = bs
batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
shape of [N, 9]. # List[(K0, 9), (K1, 9), ...] len = bs
batch_cls_labels (list[torch.Tensor]): Prediction label with the
shape of [N]. # List[(K0, ), (K1, ), ...] len = bs
img_metas (list[dict]): Meta information of each sample.
Returns:
list[dict[str: torch.Tensor]]: contains the following keys:
-bboxes (torch.Tensor): Prediction bboxes after nms with the
shape of [N, 9].
-scores (torch.Tensor): Prediction scores after nms with the
shape of [N].
-labels (torch.Tensor): Prediction labels after nms with the
shape of [N].
List[p_dict0, p_dict1, ...] len = batch_size
p_dict: dict{
bboxes: (K', 9)
scores: (K', )
labels: (K', )
}
"""
predictions_dicts = []
# 遍历不同batch的topK预测输出.
for i, (box_preds, cls_preds, cls_labels) in enumerate(
zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)):
# box_preds: (K, 9)
# cls_preds: (K, )
# cls_labels: (K, )
default_val = [1.0 for _ in range(len(self.task_heads))]
factor = self.test_cfg.get('nms_rescale_factor',
default_val)[task_id]
if isinstance(factor, list):
# List[float, float, ..] len = 当前task负责预测的类别数.
# 对于box_preds, 使用其对应的factor进行缩放, 一般是放大小目标,缩小大目标.
for cid in range(len(factor)):
box_preds[cls_labels == cid, 3:6] = \
box_preds[cls_labels == cid, 3:6] * factor[cid]
else:
box_preds[:, 3:6] = box_preds[:, 3:6] * factor
# Apply NMS in birdeye view
top_labels = cls_labels.long() # (K, )
top_scores = cls_preds.squeeze(-1) if cls_preds.shape[0] > 1 \
else cls_preds # (K, )
if top_scores.shape[0] != 0:
boxes_for_nms = img_metas[i]['box_type_3d'](
box_preds[:, :], self.bbox_coder.code_size).bev # (K, 5) (x, y, dx, dy, yaw)
# the nms in 3d detection just remove overlap boxes.
if isinstance(self.test_cfg['nms_thr'], list):
nms_thresh = self.test_cfg['nms_thr'][task_id]
else:
nms_thresh = self.test_cfg['nms_thr']
selected = nms_bev(
boxes_for_nms,
top_scores,
thresh=nms_thresh,
pre_max_size=self.test_cfg['pre_max_size'],
post_max_size=self.test_cfg['post_max_size'],
xyxyr2xywhr=False,
)
else:
selected = []
# NMS后再根据factor缩放回原来的尺寸.
if isinstance(factor, list):
for cid in range(len(factor)):
box_preds[top_labels == cid, 3:6] = \
box_preds[top_labels == cid, 3:6] / factor[cid]
else:
box_preds[:, 3:6] = box_preds[:, 3:6] / factor
# if selected is not None:
selected_boxes = box_preds[selected] # (K', 9)
selected_labels = top_labels[selected] # (K', )
selected_scores = top_scores[selected] # (K', )
# finally generate predictions.
if selected_boxes.shape[0] != 0:
predictions_dict = dict(
bboxes=selected_boxes,
scores=selected_scores,
labels=selected_labels)
else:
dtype = batch_reg_preds[0].dtype
device = batch_reg_preds[0].device
predictions_dict = dict(
bboxes=torch.zeros([0, self.bbox_coder.code_size],
dtype=dtype,
device=device),
scores=torch.zeros([0], dtype=dtype, device=device),
labels=torch.zeros([0],
dtype=top_labels.dtype,
device=device))
predictions_dicts.append(predictions_dict)
return predictions_dicts
@HEADS.register_module()
class Centerness_Head(BaseModule):
"""CenterHead for CenterPoint.
Args:
in_channels (list[int] | int, optional): Channels of the input
feature map. Default: [128].
tasks (list[dict], optional): Task information including class number
and class names. Default: None.
train_cfg (dict, optional): Train-time configs. Default: None.
test_cfg (dict, optional): Test-time configs. Default: None.
bbox_coder (dict, optional): Bbox coder configs. Default: None.
common_heads (dict, optional): Conv information for common heads.
Default: dict().
loss_cls (dict, optional): Config of classification loss function.
Default: dict(type='GaussianFocalLoss', reduction='mean').
loss_bbox (dict, optional): Config of regression loss function.
Default: dict(type='L1Loss', reduction='none').
separate_head (dict, optional): Config of separate head. Default: dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3)
share_conv_channel (int, optional): Output channels for share_conv
layer. Default: 64.
num_heatmap_convs (int, optional): Number of conv layers for heatmap
conv layer. Default: 2.
conv_cfg (dict, optional): Config of conv layer.
Default: dict(type='Conv2d')
norm_cfg (dict, optional): Config of norm layer.
Default: dict(type='BN2d').
bias (str, optional): Type of bias. Default: 'auto'.
"""
def __init__(self,
in_channels=[128],
tasks=None,
train_cfg=None,
test_cfg=None,
bbox_coder=None,
common_heads=dict(),
loss_cls=dict(type='GaussianFocalLoss', reduction='mean'),
loss_bbox=dict(
type='L1Loss', reduction='none', loss_weight=0.25),
separate_head=dict(
type='SeparateHead', init_bias=-2.19, final_kernel=3),
share_conv_channel=64,
num_heatmap_convs=2,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d'),
bias='auto',
norm_bbox=True,
init_cfg=None,
task_specific=True,
task_specific_weight=[1, 1, 1, 1, 1]):
assert init_cfg is None, 'To prevent abnormal initialization ' \
'behavior, init_cfg is not allowed to be set'
super(Centerness_Head, self).__init__(init_cfg=init_cfg)
num_classes = [len(t['class_names']) for t in tasks] # 记录不同task(SeparateHead)负责检测的类别数.
self.class_names = [t['class_names'] for t in tasks] # 记录不同task(SeparateHead)负责检测的类别名.
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.in_channels = in_channels
self.num_classes = num_classes
self.norm_bbox = norm_bbox
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.bbox_coder = build_bbox_coder(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes]
self.fp16_enabled = False
# a shared convolution
self.shared_conv = ConvModule(
in_channels,
share_conv_channel,
kernel_size=3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=bias)
# 每个task建立对应的head.
self.task_heads = nn.ModuleList()
for num_cls in num_classes:
# common_heads = dict(
# reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)),
heads = copy.deepcopy(common_heads)
heads.update(dict(heatmap=(num_cls, num_heatmap_convs)))
separate_head.update(
in_channels=share_conv_channel, heads=heads, num_cls=num_cls)
self.task_heads.append(builder.build_head(separate_head))
self.with_velocity = 'vel' in common_heads.keys()
self.task_specific = task_specific
self.task_specific_weight = task_specific_weight # [1, 1, 0, 0, 0] # 'xy', 'z', 'whl', 'yaw', 'vel'
def forward_single(self, x):
"""Forward function for CenterPoint.
Args:
x (torch.Tensor): Input feature map with the shape of
[B, 512, 128, 128].
Returns:
list[dict]: Output results for tasks.
"""
ret_dicts = []
x = self.shared_conv(x) # (B, C'=share_conv_channel, H, W)
# 运行不同task_head,
for task in self.task_heads:
ret_dicts.append(task(x))
# ret_dicts: [dict0, dict1, ...] len = SeparateHead的数量
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
return ret_dicts
def forward(self, feats):
"""Forward pass.
Args:
feats (list[torch.Tensor]): Multi-level features, e.g.,
features produced by FPN.
Returns:
results: Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
), len = SeparateHead的数量, 负责预测指定类别的目标.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
"""
return multi_apply(self.forward_single, feats)
def _gather_feat(self, feat, ind, mask=None):
"""Gather feature map.
Given feature map and index, return indexed feature map.
Args:
feat (torch.tensor): Feature map with the shape of [B, H*W, 10].
ind (torch.Tensor): Index of the ground truth boxes with the
shape of [B, max_obj].
mask (torch.Tensor, optional): Mask of the feature map with the
shape of [B, max_obj]. Default: None.
Returns:
torch.Tensor: Feature map after gathering with the shape
of [B, max_obj, 10].
"""
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def get_targets(self, gt_bboxes_3d, gt_labels_3d):
"""Generate targets.
How each output is transformed:
Each nested list is transposed so that all same-index elements in
each sub-list (1, ..., N) become the new sub-lists.
[ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
The new transposed nested list is converted into a list of N
tensors generated by concatenating tensors in the new sub-lists.
[ tensor0, tensor1, tensor2, ... ]
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes. # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
gt_labels_3d (list[torch.Tensor]): Labels of boxes. # List[(N_gt0, ), (N_gt1, ), ...]
Returns:
Returns:
tuple[list[torch.Tensor]]: (
heatmaps: List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
anno_boxes:
inds:
masks:
)
"""
heatmaps, anno_boxes, inds, masks = multi_apply(
self.get_targets_single, gt_bboxes_3d, gt_labels_3d)
# heatmaps: # Tuple(List[(N_cls0, H, W), (N_cls1, H, W), ...], ...) len = batch_size
# anno_boxes: # Tuple(List[(max_objs, 10), (max_objs, 10), ...], ...) len = batch_size
# inds: # Tuple(List[(max_objs, ), (max_objs, ), ...], ...)
# masks: # Tuple(List[(max_objs, ), (max_objs, ), ...], ...)
# Transpose heatmaps
# List[List[(N_cls0, H, W), (N_cls0, H, W), ...], List[(N_cls1, H, W), (N_cls1, H, W), ...], ...] len = num of SeparateHead
heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps] # List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
# Transpose anno_boxes
anno_boxes = list(map(list, zip(*anno_boxes)))
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes] # List[(B, max_objs, 10), (B, max_objs, 10), ...] len = num of SeparateHead
# Transpose inds
inds = list(map(list, zip(*inds)))
inds = [torch.stack(inds_) for inds_ in inds] # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
# Transpose inds
masks = list(map(list, zip(*masks)))
masks = [torch.stack(masks_) for masks_ in masks] # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
return heatmaps, anno_boxes, inds, masks
def get_targets_single(self, gt_bboxes_3d, gt_labels_3d):
"""Generate training targets for a single sample.
Args:
gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): Ground truth gt boxes. # (N_gt, 7/9)
gt_labels_3d (torch.Tensor): Labels of boxes. # (N_gt, )
Returns:
tuple[list[torch.Tensor]]: Tuple of target including
the following results in order.
- heatmaps: list[torch.Tensor]: Heatmap scores. # List[(N_cls0, H, W), (N_cls1, H, W), ...]
len = num of tasks
- anno_boxes: list[torch.Tensor]: Ground truth boxes. # List[(max_objs, 10), (max_objs, 10), ...]
- inds: list[torch.Tensor]: Indexes indicating the position
of the valid boxes. # List[(max_objs, ), (max_objs, ), ...]
- masks: list[torch.Tensor]: Masks indicating which boxes
are valid. # List[(max_objs, ), (max_objs, ), ...]
"""
device = gt_labels_3d.device
gt_bboxes_3d = torch.cat(
(gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
dim=1).to(device) # (N_gt, 7/9)
max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
grid_size = torch.tensor(self.train_cfg['grid_size']) # (Dx, Dy, Dz)
pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
voxel_size = torch.tensor(self.train_cfg['voxel_size'])
feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor'] # (W, H)
# reorganize the gt_dict by tasks
task_masks = []
flag = 0
for class_name in self.class_names:
# class_name: 不同task(SeparateHead)负责检测的类别名.
task_masks.append([
torch.where(gt_labels_3d == class_name.index(i) + flag)
for i in class_name
])
flag += len(class_name)
# task_masks: List[task_mask0, task_mask1, ...] len = number of SeparateHeads
# task_mask: List[((N_gt0, ), ), ((N_gt1, ), ), ...] len = number of class
task_boxes = []
task_classes = []
flag2 = 0
for idx, mask in enumerate(task_masks):
# mask: 不同task(SeparateHead)的mask, 每个task负责检测一组不同类别的目标.
# List[((N_gt0, ), ), ((N_gt1, ), ), ...], # N_gt_task=N_gt0+N_gt1+..., 表示当前task负责检测的gt_boxes的数量.
task_box = []
task_class = []
for m in mask:
task_box.append(gt_bboxes_3d[m])
# 0 is background for each task, so we need to add 1 here.
task_class.append(gt_labels_3d[m] + 1 - flag2)
task_boxes.append(torch.cat(task_box, axis=0).to(device))
task_classes.append(torch.cat(task_class).long().to(device))
flag2 += len(mask)
# 记录不同task负责检测的gt_boxes和gt_classes:
# task_boxes: List[(N_gt_task0, 7/9), (N_gt_task1, 7/9), ...]
# task_classes: List[(N_gt_task0, ), (N_gt_task1, ), ...]
draw_gaussian = draw_heatmap_gaussian
heatmaps, anno_boxes, inds, masks = [], [], [], []
for idx, task_head in enumerate(self.task_heads):
heatmap = gt_bboxes_3d.new_zeros(
(len(self.class_names[idx]), feature_map_size[1],
feature_map_size[0])) # (N_cls, H, W) N_cls表示当前task_head负责检测的类别数目.
if self.with_velocity:
anno_box = gt_bboxes_3d.new_zeros((max_objs, 10),
dtype=torch.float32) # (max_objs, 10)
else:
anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
dtype=torch.float32)
ind = gt_labels_3d.new_zeros((max_objs, ), dtype=torch.int64) # (max_objs, )
mask = gt_bboxes_3d.new_zeros((max_objs, ), dtype=torch.uint8) # (max_objs, )
num_objs = min(task_boxes[idx].shape[0], max_objs) # 当前task_head负责检测的目标.
for k in range(num_objs):
cls_id = task_classes[idx][k] - 1 # 当前目标的cls_id, cls_id是相对task group内的.
width = task_boxes[idx][k][3] # dx
length = task_boxes[idx][k][4] # dy
# 当前目标在feature map上的width和length
width = width / voxel_size[0] / self.train_cfg[
'out_size_factor']
length = length / voxel_size[1] / self.train_cfg[
'out_size_factor']
if width > 0 and length > 0:
# 计算gaussian半径
radius = gaussian_radius(
(length, width),
min_overlap=self.train_cfg['gaussian_overlap'])
radius = max(self.train_cfg['min_radius'], int(radius))
# be really careful for the coordinate system of
# your box annotation.
x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][
1], task_boxes[idx][k][2] # 当前目标的中心坐标.
# 计算gt_box中心点在feature map中对应的位置.
coor_x = (
x - pc_range[0]
) / voxel_size[0] / self.train_cfg['out_size_factor']
coor_y = (
y - pc_range[1]
) / voxel_size[1] / self.train_cfg['out_size_factor']
center = torch.tensor([coor_x, coor_y],
dtype=torch.float32,
device=device)
center_int = center.to(torch.int32)
# throw out not in range objects to avoid out of array
# area when creating the heatmap
if not (0 <= center_int[0] < feature_map_size[0]
and 0 <= center_int[1] < feature_map_size[1]):
continue
# 根据目标中心点在feature map中对应的位置、高斯半径来设置heatmap.
draw_gaussian(heatmap[cls_id], center_int, radius)
new_idx = k
x, y = center_int[0], center_int[1]
assert (y * feature_map_size[0] + x <
feature_map_size[0] * feature_map_size[1])
# 记录正样本在feature map中的位置.
ind[new_idx] = y * feature_map_size[0] + x
mask[new_idx] = 1
# TODO: support other outdoor dataset
rot = task_boxes[idx][k][6]
box_dim = task_boxes[idx][k][3:6]
if self.norm_bbox:
box_dim = box_dim.log()
if self.with_velocity:
vx, vy = task_boxes[idx][k][7:]
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device), # tx, ty
z.unsqueeze(0), box_dim, # z, log(dx), log(dy), log(dz)
torch.sin(rot).unsqueeze(0), # sin(rot)
torch.cos(rot).unsqueeze(0), # cos(rot)
vx.unsqueeze(0), # vx
vy.unsqueeze(0) # vy
]) # [tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy]
else:
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.sin(rot).unsqueeze(0),
torch.cos(rot).unsqueeze(0)
])
heatmaps.append(heatmap) # append (N_cls, H, W)
anno_boxes.append(anno_box) # append (max_objs, 10)
masks.append(mask) # append (max_objs, )
inds.append(ind) # append (max_objs, )
return heatmaps, anno_boxes, inds, masks
def loss(self, gt_bboxes_3d, gt_labels_3d, preds_dicts, **kwargs):
"""Loss function for CenterHead.
Args:
gt_bboxes_3d (list[:obj:`LiDARInstance3DBoxes`]): Ground
truth gt boxes. # List[(N_gt0, 7/9), (N_gt1, 7/9), ...]
gt_labels_3d (list[torch.Tensor]): Labels of boxes. # List[(N_gt0, ), (N_gt1, ), ...]
preds_dicts (dict): Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
), len = SeparateHead的数量, 负责预测指定类别的目标.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
Returns:
dict[str:torch.Tensor]: Loss of heatmap and bbox of each task.
"""
heatmaps, anno_boxes, inds, masks = self.get_targets(
gt_bboxes_3d, gt_labels_3d)
# heatmaps: # List[(B, N_cls0, H, W), (B, N_cls1, H, W), ...] len = num of SeparateHead
# anno_boxes: # List[(B, max_objs, 10), (B, max_objs, 10), ...] len = num of SeparateHead
# inds: # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
# masks: # List[(B, max_objs), (B, max_objs), ...] len = num of SeparateHead
loss_dict = dict()
if not self.task_specific:
loss_dict['loss'] = 0
for task_id, preds_dict in enumerate(preds_dicts):
# task_id: SeparateHead idx
# preds_dict: List[dict0, ...] len = num levels, 对于center_point len = 1
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
# heatmap focal loss
preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
num_pos = heatmaps[task_id].eq(1).float().sum().item()
cls_avg_factor = torch.clamp(
reduce_mean(heatmaps[task_id].new_tensor(num_pos)),
min=1).item()
loss_heatmap = self.loss_cls(
preds_dict[0]['heatmap'], # (B, cur_N_cls, H, W)
heatmaps[task_id], # (B, cur_N_cls, H, W)
avg_factor=cls_avg_factor
)
# (B, max_objs, 10) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)
target_box = anno_boxes[task_id]
# reconstruct the anno_box from multiple reg heads
preds_dict[0]['anno_box'] = torch.cat(
(
preds_dict[0]['reg'],
preds_dict[0]['height'],
preds_dict[0]['dim'],
preds_dict[0]['rot'],
preds_dict[0]['vel'],
),
dim=1,
) # (B, 10, H, W) 10: (tx, ty, z, log(dx), log(dy), log(dz), sin(rot), cos(rot), vx, vy)
# Regression loss for dimension, offset, height, rotation
num = masks[task_id].float().sum() # 正样本的数量
ind = inds[task_id] # (B, max_objs)
pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous() # (B, H, W, 10)
pred = pred.view(pred.size(0), -1, pred.size(3)) # (B, H*W, 10)
pred = self._gather_feat(pred, ind) # (B, max_objs, 10)
# (B, max_objs) --> (B, max_objs, 1) --> (B, max_objs, 10)
mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
num = torch.clamp(
reduce_mean(target_box.new_tensor(num)), min=1e-4).item()
isnotnan = (~torch.isnan(target_box)).float()
mask *= isnotnan # 只监督mask指定的reg预测.
code_weights = self.train_cfg['code_weights']
bbox_weights = mask * mask.new_tensor(code_weights) # 在mask基础上,设置box不同属性的权重. (B, max_objs, 10)
if self.task_specific:
name_list = ['xy', 'z', 'whl', 'yaw', 'vel']
clip_index = [0, 2, 3, 6, 8, 10]
for reg_task_id in range(len(name_list)):
pred_tmp = pred[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
target_box_tmp = target_box[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
bbox_weights_tmp = bbox_weights[..., clip_index[reg_task_id]:clip_index[reg_task_id + 1]] # (B, max_objs, K)
loss_bbox_tmp = self.loss_bbox(
pred_tmp,
target_box_tmp,
bbox_weights_tmp,
avg_factor=(num + 1e-4))
loss_dict[f'task{task_id}.loss_%s' %
(name_list[reg_task_id])] = loss_bbox_tmp * self.task_specific_weight[reg_task_id]
loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap
else:
loss_bbox = self.loss_bbox(
pred, target_box, bbox_weights, avg_factor=num)
loss_dict['loss'] += loss_bbox
loss_dict['loss'] += loss_heatmap
return loss_dict
def get_bboxes(self, preds_dicts, img_metas, img=None, rescale=False):
"""Generate bboxes from bbox head predictions.
Args:
preds_dicts (tuple[list[dict]]): Prediction results.
Tuple(
List[ret_dict_task0_level0, ...], len = num_levels = 1
List[ret_dict_task1_level0, ...],
...
), len = SeparateHead的数量, 负责预测指定类别的目标.
ret_dict: {
reg: (B, 2, H, W)
height: (B, 1, H, W)
dim: (B, 3, H, W)
rot: (B, 2, H, W)
vel: (B, 2, H, W)
heatmap: (B, n_cls, H, W)
}
img_metas (list[dict]): Point cloud and image's meta info.
Returns:
list[dict]: Decoded bbox, scores and labels after nms.
ret_list: List[p_list0, p_list1, ...]
p_list: List[(N, 9), (N, ), (N, )]
"""
rets = []
for task_id, preds_dict in enumerate(preds_dicts):
# task_id: SeparateHead idx
# preds_dict: List[dict0, ...] len = num levels, 对于center_point len = 1
# dict: {
# reg: (B, 2, H, W)
# height: (B, 1, H, W)
# dim: (B, 3, H, W)
# rot: (B, 2, H, W)
# vel: (B, 2, H, W)
# heatmap: (B, n_cls, H, W)
# }
batch_size = preds_dict[0]['heatmap'].shape[0]
batch_heatmap = preds_dict[0]['heatmap'].sigmoid() # (B, n_cls, H, W)
batch_reg = preds_dict[0]['reg'] # (B, 2, H, W)
batch_hei = preds_dict[0]['height'] # (B, 1, H, W)
if self.norm_bbox:
batch_dim = torch.exp(preds_dict[0]['dim']) # (B, 3, H, W)
else:
batch_dim = preds_dict[0]['dim']
batch_rots = preds_dict[0]['rot'][:, 0].unsqueeze(1) # (B, 1, H, W)
batch_rotc = preds_dict[0]['rot'][:, 1].unsqueeze(1) # (B, 1, H, W)
if 'vel' in preds_dict[0]:
batch_vel = preds_dict[0]['vel'] # (B, 2, H, W)
else:
batch_vel = None
temp = self.bbox_coder.decode(
batch_heatmap,
batch_rots,
batch_rotc,
batch_hei,
batch_dim,
batch_vel,
reg=batch_reg,
task_id=task_id)
# temp: List[p_dict0, p_dict1, ...] len=bs
# p_dict = {
# 'bboxes': boxes3d, # (K', 9)
# 'scores': scores, # (K', )
# 'labels': labels # (K', )
# }
batch_reg_preds = [box['bboxes'] for box in temp] # List[(K0, 9), (K1, 9), ...] len = bs
batch_cls_preds = [box['scores'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
batch_cls_labels = [box['labels'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
nms_type = self.test_cfg.get('nms_type')
if isinstance(nms_type, list):
nms_type = nms_type[task_id]
if nms_type == 'circle':
ret_task = []
for i in range(batch_size):
boxes3d = temp[i]['bboxes']
scores = temp[i]['scores']
labels = temp[i]['labels']
centers = boxes3d[:, [0, 1]]
boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
keep = torch.tensor(
circle_nms(
boxes.detach().cpu().numpy(),
self.test_cfg['min_radius'][task_id],
post_max_size=self.test_cfg['post_max_size']),
dtype=torch.long,
device=boxes.device)
boxes3d = boxes3d[keep]
scores = scores[keep]
labels = labels[keep]
ret = dict(bboxes=boxes3d, scores=scores, labels=labels)
ret_task.append(ret)
rets.append(ret_task)
else:
rets.append(
self.get_task_detections(batch_cls_preds, batch_reg_preds,
batch_cls_labels, img_metas,
task_id))
# rets: List[ret_task0, ret_task1, ...], len = num_tasks
# ret_task: List[p_dict0, p_dict1, ...], len = batch_size
# p_dict: dict{
# bboxes: (K', 9)
# scores: (K', )
# labels: (K', )
# }
# Merge branches results
num_samples = len(rets[0]) # bs
ret_list = []
# 遍历batch, 然后汇总所有task的预测.
for i in range(num_samples):
for k in rets[0][i].keys():
if k == 'bboxes':
bboxes = torch.cat([ret[i][k] for ret in rets]) # 对于bboxes, 直接拼接即可.
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
bboxes = img_metas[i]['box_type_3d'](
bboxes, self.bbox_coder.code_size)
elif k == 'scores':
scores = torch.cat([ret[i][k] for ret in rets]) # 对于scores, 直接拼接即可.
elif k == 'labels':
flag = 0
for j, num_class in enumerate(self.num_classes): # 对于labels, 要进行调整, 因为预测的label是task组内的.
rets[j][i][k] += flag
flag += num_class
labels = torch.cat([ret[i][k].int() for ret in rets])
ret_list.append([bboxes, scores, labels])
# ret_list: List[p_list0, p_list1, ...]
# p_list: List[(N, 9), (N, ), (N, )]
return ret_list
def _nms(self, heat, kernel=3):
pad = (kernel - 1) // 2
hmax = nn.functional.max_pool2d(
heat, (kernel, kernel), stride=1, padding=pad)
keep = (hmax == heat).float()
return heat * keep
def get_centers(self, preds_dicts, img_metas, img=None, rescale=False):
rets = []
for task_id, preds_dict in enumerate(preds_dicts):
batch_size = preds_dict[0]['heatmap'].shape[0]
batch_heatmap = preds_dict[0]['heatmap'].sigmoid() # (B, n_cls, H, W)
batch_reg = preds_dict[0]['reg'] # (B, 2, H, W)
batch_hei = preds_dict[0]['height'] # (B, 1, H, W)
batch_heatmap = self._nms(batch_heatmap)
temp = self.bbox_coder.center_decode(
batch_heatmap,
batch_hei,
reg=batch_reg,
task_id=task_id)
batch_reg_preds = [box['centers'] for box in temp] # List[(K0, 9), (K1, 9), ...] len = bs
batch_cls_preds = [box['scores'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
batch_cls_labels = [box['labels'] for box in temp] # List[(K0, ), (K1, ), ...] len = bs
ret_list = [batch_reg_preds, batch_cls_preds, batch_cls_labels]
return ret_list
def get_task_detections(self, batch_cls_preds,
batch_reg_preds, batch_cls_labels, img_metas,
task_id):
"""Rotate nms for each task.
Args:
batch_cls_preds (list[torch.Tensor]): Prediction score with the
shape of [N]. # List[(K0, ), (K1, ), ...] len = bs
batch_reg_preds (list[torch.Tensor]): Prediction bbox with the
shape of [N, 9]. # List[(K0, 9), (K1, 9), ...] len = bs
batch_cls_labels (list[torch.Tensor]): Prediction label with the
shape of [N]. # List[(K0, ), (K1, ), ...] len = bs
img_metas (list[dict]): Meta information of each sample.
Returns:
list[dict[str: torch.Tensor]]: contains the following keys:
-bboxes (torch.Tensor): Prediction bboxes after nms with the
shape of [N, 9].
-scores (torch.Tensor): Prediction scores after nms with the
shape of [N].
-labels (torch.Tensor): Prediction labels after nms with the
shape of [N].
List[p_dict0, p_dict1, ...] len = batch_size
p_dict: dict{
bboxes: (K', 9)
scores: (K', )
labels: (K', )
}
"""
predictions_dicts = []
# 遍历不同batch的topK预测输出.
for i, (box_preds, cls_preds, cls_labels) in enumerate(
zip(batch_reg_preds, batch_cls_preds, batch_cls_labels)):
# box_preds: (K, 9)
# cls_preds: (K, )
# cls_labels: (K, )
default_val = [1.0 for _ in range(len(self.task_heads))]
factor = self.test_cfg.get('nms_rescale_factor',
default_val)[task_id]
if isinstance(factor, list):
# List[float, float, ..] len = 当前task负责预测的类别数.
# 对于box_preds, 使用其对应的factor进行缩放, 一般是放大小目标,缩小大目标.
for cid in range(len(factor)):
box_preds[cls_labels == cid, 3:6] = \
box_preds[cls_labels == cid, 3:6] * factor[cid]
else:
box_preds[:, 3:6] = box_preds[:, 3:6] * factor
# Apply NMS in birdeye view
top_labels = cls_labels.long() # (K, )
top_scores = cls_preds.squeeze(-1) if cls_preds.shape[0] > 1 \
else cls_preds # (K, )
if top_scores.shape[0] != 0:
boxes_for_nms = img_metas[i]['box_type_3d'](
box_preds[:, :], self.bbox_coder.code_size).bev # (K, 5) (x, y, dx, dy, yaw)
# the nms in 3d detection just remove overlap boxes.
if isinstance(self.test_cfg['nms_thr'], list):
nms_thresh = self.test_cfg['nms_thr'][task_id]
else:
nms_thresh = self.test_cfg['nms_thr']
selected = nms_bev(
boxes_for_nms,
top_scores,
thresh=nms_thresh,
pre_max_size=self.test_cfg['pre_max_size'],
post_max_size=self.test_cfg['post_max_size'],
xyxyr2xywhr=False,
)
else:
selected = []
# NMS后再根据factor缩放回原来的尺寸.
if isinstance(factor, list):
for cid in range(len(factor)):
box_preds[top_labels == cid, 3:6] = \
box_preds[top_labels == cid, 3:6] / factor[cid]
else:
box_preds[:, 3:6] = box_preds[:, 3:6] / factor
# if selected is not None:
selected_boxes = box_preds[selected] # (K', 9)
selected_labels = top_labels[selected] # (K', )
selected_scores = top_scores[selected] # (K', )
# finally generate predictions.
if selected_boxes.shape[0] != 0:
predictions_dict = dict(
bboxes=selected_boxes,
scores=selected_scores,
labels=selected_labels)
else:
dtype = batch_reg_preds[0].dtype
device = batch_reg_preds[0].device
predictions_dict = dict(
bboxes=torch.zeros([0, self.bbox_coder.code_size],
dtype=dtype,
device=device),
scores=torch.zeros([0], dtype=dtype, device=device),
labels=torch.zeros([0],
dtype=top_labels.dtype,
device=device))
predictions_dicts.append(predictions_dict)
return predictions_dicts
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
import numpy as np
from mmdet3d.models.builder import HEADS, build_loss
from ..losses.semkitti_loss import sem_scal_loss, geo_scal_loss
from ..losses.lovasz_softmax import lovasz_softmax
nusc_class_frequencies = np.array([
944004,
1897170,
152386,
2391677,
16957802,
724139,
189027,
2074468,
413451,
2384460,
5916653,
175883646,
4275424,
51393615,
61411620,
105975596,
116424404,
1892500630
])
@HEADS.register_module()
class BEVOCCHead3D(BaseModule):
def __init__(self,
in_dim=32,
out_dim=32,
use_mask=True,
num_classes=18,
use_predicter=True,
class_balance=False,
loss_occ=None
):
super(BEVOCCHead3D, self).__init__()
self.out_dim = 32
out_channels = out_dim if use_predicter else num_classes
self.final_conv = ConvModule(
in_dim,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True,
conv_cfg=dict(type='Conv3d')
)
self.use_predicter = use_predicter
if use_predicter:
self.predicter = nn.Sequential(
nn.Linear(self.out_dim, self.out_dim*2),
nn.Softplus(),
nn.Linear(self.out_dim*2, num_classes),
)
self.num_classes = num_classes
self.use_mask = use_mask
self.class_balance = class_balance
if self.class_balance:
class_weights = torch.from_numpy(1 / np.log(nusc_class_frequencies[:num_classes] + 0.001))
self.cls_weights = class_weights
loss_occ['class_weight'] = class_weights
self.loss_occ = build_loss(loss_occ)
def forward(self, img_feats):
"""
Args:
img_feats: (B, C, Dz, Dy, Dx)
Returns:
"""
# (B, C, Dz, Dy, Dx) --> (B, C, Dz, Dy, Dx) --> (B, Dx, Dy, Dz, C)
occ_pred = self.final_conv(img_feats).permute(0, 4, 3, 2, 1)
if self.use_predicter:
# (B, Dx, Dy, Dz, C) --> (B, Dx, Dy, Dz, 2*C) --> (B, Dx, Dy, Dz, n_cls)
occ_pred = self.predicter(occ_pred)
return occ_pred
def loss(self, occ_pred, voxel_semantics, mask_camera):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, n_cls)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
loss = dict()
voxel_semantics = voxel_semantics.long()
if self.use_mask:
mask_camera = mask_camera.to(torch.int32) # (B, Dx, Dy, Dz)
# (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
voxel_semantics = voxel_semantics.reshape(-1)
# (B, Dx, Dy, Dz, n_cls) --> (B*Dx*Dy*Dz, n_cls)
preds = occ_pred.reshape(-1, self.num_classes)
# (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
mask_camera = mask_camera.reshape(-1)
if self.class_balance:
valid_voxels = voxel_semantics[mask_camera.bool()]
num_total_samples = 0
for i in range(self.num_classes):
num_total_samples += (valid_voxels == i).sum() * self.cls_weights[i]
else:
num_total_samples = mask_camera.sum()
loss_occ = self.loss_occ(
preds, # (B*Dx*Dy*Dz, n_cls)
voxel_semantics, # (B*Dx*Dy*Dz, )
mask_camera, # (B*Dx*Dy*Dz, )
avg_factor=num_total_samples
)
else:
voxel_semantics = voxel_semantics.reshape(-1)
preds = occ_pred.reshape(-1, self.num_classes)
if self.class_balance:
num_total_samples = 0
for i in range(self.num_classes):
num_total_samples += (voxel_semantics == i).sum() * self.cls_weights[i]
else:
num_total_samples = len(voxel_semantics)
loss_occ = self.loss_occ(
preds,
voxel_semantics,
avg_factor=num_total_samples
)
loss['loss_occ'] = loss_occ
return loss
def get_occ(self, occ_pred, img_metas=None):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, C)
img_metas:
Returns:
List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
occ_score = occ_pred.softmax(-1) # (B, Dx, Dy, Dz, C)
occ_res = occ_score.argmax(-1) # (B, Dx, Dy, Dz)
occ_res = occ_res.cpu().numpy().astype(np.uint8) # (B, Dx, Dy, Dz)
return list(occ_res)
@HEADS.register_module()
class BEVOCCHead2D(BaseModule):
def __init__(self,
in_dim=256,
out_dim=256,
Dz=16,
use_mask=True,
num_classes=18,
use_predicter=True,
class_balance=False,
loss_occ=None,
):
super(BEVOCCHead2D, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.Dz = Dz
out_channels = out_dim if use_predicter else num_classes * Dz
self.final_conv = ConvModule(
self.in_dim,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True,
conv_cfg=dict(type='Conv2d')
)
self.use_predicter = use_predicter
if use_predicter:
self.predicter = nn.Sequential(
nn.Linear(self.out_dim, self.out_dim * 2),
nn.Softplus(),
nn.Linear(self.out_dim * 2, num_classes * Dz),
)
self.use_mask = use_mask
self.num_classes = num_classes
self.class_balance = class_balance
if self.class_balance:
class_weights = torch.from_numpy(1 / np.log(nusc_class_frequencies[:num_classes] + 0.001))
self.cls_weights = class_weights
loss_occ['class_weight'] = class_weights # ce loss
self.loss_occ = build_loss(loss_occ)
def forward(self, img_feats):
"""
Args:
img_feats: (B, C, Dy, Dx)
Returns:
"""
# (B, C, Dy, Dx) --> (B, C, Dy, Dx) --> (B, Dx, Dy, C)
occ_pred = self.final_conv(img_feats).permute(0, 3, 2, 1)
bs, Dx, Dy = occ_pred.shape[:3]
if self.use_predicter:
# (B, Dx, Dy, C) --> (B, Dx, Dy, 2*C) --> (B, Dx, Dy, Dz*n_cls)
occ_pred = self.predicter(occ_pred)
occ_pred = occ_pred.view(bs, Dx, Dy, self.Dz, self.num_classes)
return occ_pred
def loss(self, occ_pred, voxel_semantics, mask_camera):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, n_cls)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
loss = dict()
voxel_semantics = voxel_semantics.long()
if self.use_mask:
mask_camera = mask_camera.to(torch.int32) # (B, Dx, Dy, Dz)
# (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
voxel_semantics = voxel_semantics.reshape(-1)
# (B, Dx, Dy, Dz, n_cls) --> (B*Dx*Dy*Dz, n_cls)
preds = occ_pred.reshape(-1, self.num_classes)
# (B, Dx, Dy, Dz) --> (B*Dx*Dy*Dz, )
mask_camera = mask_camera.reshape(-1)
if self.class_balance:
valid_voxels = voxel_semantics[mask_camera.bool()]
num_total_samples = 0
for i in range(self.num_classes):
num_total_samples += (valid_voxels == i).sum() * self.cls_weights[i]
else:
num_total_samples = mask_camera.sum()
loss_occ = self.loss_occ(
preds, # (B*Dx*Dy*Dz, n_cls)
voxel_semantics, # (B*Dx*Dy*Dz, )
mask_camera, # (B*Dx*Dy*Dz, )
avg_factor=num_total_samples
)
loss['loss_occ'] = loss_occ
else:
voxel_semantics = voxel_semantics.reshape(-1)
preds = occ_pred.reshape(-1, self.num_classes)
if self.class_balance:
num_total_samples = 0
for i in range(self.num_classes):
num_total_samples += (voxel_semantics == i).sum() * self.cls_weights[i]
else:
num_total_samples = len(voxel_semantics)
loss_occ = self.loss_occ(
preds,
voxel_semantics,
avg_factor=num_total_samples
)
loss['loss_occ'] = loss_occ
return loss
def get_occ(self, occ_pred, img_metas=None):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, C)
img_metas:
Returns:
List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
occ_score = occ_pred.softmax(-1) # (B, Dx, Dy, Dz, C)
occ_res = occ_score.argmax(-1) # (B, Dx, Dy, Dz)
occ_res = occ_res.cpu().numpy().astype(np.uint8) # (B, Dx, Dy, Dz)
return list(occ_res)
@HEADS.register_module()
class BEVOCCHead2D_V2(BaseModule): # Use stronger loss setting
def __init__(self,
in_dim=256,
out_dim=256,
Dz=16,
use_mask=True,
num_classes=18,
use_predicter=True,
class_balance=False,
loss_occ=None,
):
super(BEVOCCHead2D_V2, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.Dz = Dz
# voxel-level prediction
self.occ_convs = nn.ModuleList()
self.final_conv = ConvModule(
in_dim,
self.out_dim,
kernel_size=3,
stride=1,
padding=1,
bias=True,
conv_cfg=dict(type='Conv2d')
)
self.use_predicter = use_predicter
if use_predicter:
self.predicter = nn.Sequential(
nn.Linear(self.out_dim, self.out_dim * 2),
nn.Softplus(),
nn.Linear(self.out_dim * 2, num_classes * Dz),
)
self.use_mask = use_mask
self.num_classes = num_classes
self.class_balance = class_balance
if self.class_balance:
class_weights = torch.from_numpy(1 / np.log(nusc_class_frequencies[:num_classes] + 0.001))
self.cls_weights = class_weights
self.loss_occ = build_loss(loss_occ)
def forward(self, img_feats):
"""
Args:
img_feats: (B, C, Dy=200, Dx=200)
img_feats: [(B, C, 100, 100), (B, C, 50, 50), (B, C, 25, 25)] if ms
Returns:
"""
# (B, C, Dy, Dx) --> (B, C, Dy, Dx) --> (B, Dx, Dy, C)
occ_pred = self.final_conv(img_feats).permute(0, 3, 2, 1)
bs, Dx, Dy = occ_pred.shape[:3]
if self.use_predicter:
# (B, Dx, Dy, C) --> (B, Dx, Dy, 2*C) --> (B, Dx, Dy, Dz*n_cls)
occ_pred = self.predicter(occ_pred)
occ_pred = occ_pred.view(bs, Dx, Dy, self.Dz, self.num_classes)
return occ_pred
def loss(self, occ_pred, voxel_semantics, mask_camera):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, n_cls)
voxel_semantics: (B, Dx, Dy, Dz)
mask_camera: (B, Dx, Dy, Dz)
Returns:
"""
loss = dict()
voxel_semantics = voxel_semantics.long() # (B, Dx, Dy, Dz)
preds = occ_pred.permute(0, 4, 1, 2, 3).contiguous() # (B, n_cls, Dx, Dy, Dz)
loss_occ = self.loss_occ(
preds,
voxel_semantics,
weight=self.cls_weights.to(preds),
) * 100.0
loss['loss_occ'] = loss_occ
loss['loss_voxel_sem_scal'] = sem_scal_loss(preds, voxel_semantics)
loss['loss_voxel_geo_scal'] = geo_scal_loss(preds, voxel_semantics, non_empty_idx=17)
loss['loss_voxel_lovasz'] = lovasz_softmax(torch.softmax(preds, dim=1), voxel_semantics)
return loss
def get_occ(self, occ_pred, img_metas=None):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, C)
img_metas:
Returns:
List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
occ_score = occ_pred.softmax(-1) # (B, Dx, Dy, Dz, C)
occ_res = occ_score.argmax(-1) # (B, Dx, Dy, Dz)
occ_res = occ_res.cpu().numpy().astype(np.uint8) # (B, Dx, Dy, Dz)
return list(occ_res)
def get_occ_gpu(self, occ_pred, img_metas=None):
"""
Args:
occ_pred: (B, Dx, Dy, Dz, C)
img_metas:
Returns:
List[(Dx, Dy, Dz), (Dx, Dy, Dz), ...]
"""
occ_score = occ_pred.softmax(-1) # (B, Dx, Dy, Dz, C)
occ_res = occ_score.argmax(-1).int() # (B, Dx, Dy, Dz)
return list(occ_res)
\ No newline at end of file
from .bevdet import BEVDet
from .bevdepth import BEVDepth
from .bevdet4d import BEVDet4D
from .bevdepth4d import BEVDepth4D
from .bevstereo4d import BEVStereo4D
from .bevdet_occ import BEVDetOCC, BEVDepthOCC, BEVDepth4DOCC, BEVStereo4DOCC, BEVDepth4DPano, BEVDepthPano, BEVDepthPanoTRT
__all__ = ['BEVDet', 'BEVDepth', 'BEVDet4D', 'BEVDepth4D', 'BEVStereo4D', 'BEVDetOCC', 'BEVDepthOCC',
'BEVDepth4DOCC', 'BEVStereo4DOCC', 'BEVDepthPano', 'BEVDepth4DPano', 'BEVDepthPanoTRT']
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment