Unverified Commit 52fe5baa authored by Sun Jiahao, committed by GitHub

[Feature] Add TPVFormer in `Projects` (#2399)

* fix polarmix UT

* init tpvformer

* add nus seg

* add nus seg

* test done

* Delete change_key.py

* Delete test_dcn.py

* remove seg eval

* fix encoder

* init train

* train ready

* remove asynctest

* change test.yml

* pr_stage_test.yml & merge_stage_test.yml

* pip install wheel

* pip install wheel all

* check type hint

* check comments

* remove Photo aug

* fix p2v

* fix docstring & fix config filepath
parent 106b17e7
 # Copyright (c) OpenMMLab. All rights reserved.
 from .cylinder3d_head import Cylinder3DHead
+from .decode_head import Base3DDecodeHead
 from .dgcnn_head import DGCNNHead
 from .minkunet_head import MinkUNetHead
 from .paconv_head import PAConvHead
@@ -7,5 +8,5 @@ from .pointnet2_head import PointNet2Head
 __all__ = [
     'PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead',
-    'MinkUNetHead'
+    'Base3DDecodeHead', 'MinkUNetHead'
 ]
-_base_ = ['mmdet3d::_base_/default_runtime.py']
+_base_ = ['../../../configs/_base_/default_runtime.py']
 custom_imports = dict(
     imports=['projects.CenterFormer.centerformer'], allow_failed_imports=False)
...
 _base_ = [
     # 'mmdet3d::_base_/datasets/nus-3d.py',
-    'mmdet3d::_base_/default_runtime.py'
+    '../../../configs/_base_/default_runtime.py'
 ]
 custom_imports = dict(imports=['projects.DETR3D.detr3d'])
...
 _base_ = [
-    'mmdet3d::_base_/datasets/nus-3d.py', 'mmdet3d::_base_/default_runtime.py',
-    'mmdet3d::_base_/schedules/cyclic-20e.py'
+    '../../../configs/_base_/datasets/nus-3d.py',
+    '../../../configs/_base_/default_runtime.py',
+    '../../../configs/_base_/schedules/cyclic-20e.py'
 ]
 backbone_norm_cfg = dict(type='LN', requires_grad=True)
 custom_imports = dict(imports=['projects.PETR.petr'])
...
_base_ = ['../../../configs/_base_/default_runtime.py']
custom_imports = dict(
imports=['projects.TPVFormer.tpvformer'], allow_failed_imports=False)
dataset_type = 'NuScenesSegDataset'
data_root = 'data/nuscenes/'
data_prefix = dict(
pts='samples/LIDAR_TOP',
pts_semantic_mask='lidarseg/v1.0-trainval',
CAM_FRONT='samples/CAM_FRONT',
CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT',
CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT',
CAM_BACK='samples/CAM_BACK',
CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT',
CAM_BACK_LEFT='samples/CAM_BACK_LEFT')
backend_args = None
train_pipeline = [
dict(
type='BEVLoadMultiViewImageFromFiles',
to_float32=False,
color_type='unchanged',
num_views=6,
backend_args=backend_args),
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=5,
use_dim=3,
backend_args=backend_args),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_seg_3d=True,
with_attr_label=False,
seg_3d_dtype='np.uint8'),
dict(
type='MultiViewWrapper',
transforms=dict(type='PhotoMetricDistortion3D')),
dict(type='SegLabelMapping'),
dict(
type='Pack3DDetInputs',
keys=['img', 'points', 'pts_semantic_mask'],
meta_keys=['lidar2img'])
]
val_pipeline = [
dict(
type='BEVLoadMultiViewImageFromFiles',
to_float32=False,
color_type='unchanged',
num_views=6,
backend_args=backend_args),
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=5,
use_dim=3,
backend_args=backend_args),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_seg_3d=True,
with_attr_label=False,
seg_3d_dtype='np.uint8'),
dict(type='SegLabelMapping'),
dict(
type='Pack3DDetInputs',
keys=['img', 'points', 'pts_semantic_mask'],
meta_keys=['lidar2img'])
]
test_pipeline = val_pipeline
train_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
drop_last=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=data_prefix,
ann_file='nuscenes_infos_train.pkl',
pipeline=train_pipeline,
test_mode=False))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=data_prefix,
ann_file='nuscenes_infos_val.pkl',
pipeline=val_pipeline,
test_mode=True))
test_dataloader = val_dataloader
val_evaluator = dict(type='SegMetric')
test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=2e-4, weight_decay=0.01),
paramwise_cfg=dict(custom_keys={
'backbone': dict(lr_mult=0.1),
}),
clip_grad=dict(max_norm=35, norm_type=2),
)
param_scheduler = [
dict(type='LinearLR', start_factor=1e-5, by_epoch=False, begin=0, end=500),
dict(
type='CosineAnnealingLR',
begin=0,
T_max=24,
by_epoch=True,
eta_min=1e-6,
convert_to_iter_based=True)
]
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
_dim_ = 128
num_heads = 8
_ffn_dim_ = _dim_ * 2
tpv_h_ = 200
tpv_w_ = 200
tpv_z_ = 16
scale_h = 1
scale_w = 1
scale_z = 1
num_points_in_pillar = [4, 32, 32]
num_points = [8, 64, 64]
hybrid_attn_anchors = 16
hybrid_attn_points = 32
hybrid_attn_init = 0
grid_shape = [tpv_h_ * scale_h, tpv_w_ * scale_w, tpv_z_ * scale_z]
self_cross_layer = dict(
type='TPVFormerLayer',
attn_cfgs=[
dict(
type='TPVCrossViewHybridAttention',
tpv_h=tpv_h_,
tpv_w=tpv_w_,
tpv_z=tpv_z_,
num_anchors=hybrid_attn_anchors,
embed_dims=_dim_,
num_heads=num_heads,
num_points=hybrid_attn_points,
init_mode=hybrid_attn_init,
dropout=0.1),
dict(
type='TPVImageCrossAttention',
pc_range=point_cloud_range,
num_cams=6,
dropout=0.1,
deformable_attention=dict(
type='TPVMSDeformableAttention3D',
embed_dims=_dim_,
num_heads=num_heads,
num_points=num_points,
num_z_anchors=num_points_in_pillar,
num_levels=4,
floor_sampling_offset=False,
tpv_h=tpv_h_,
tpv_w=tpv_w_,
tpv_z=tpv_z_),
embed_dims=_dim_,
tpv_h=tpv_h_,
tpv_w=tpv_w_,
tpv_z=tpv_z_)
],
feedforward_channels=_ffn_dim_,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))
self_layer = dict(
type='TPVFormerLayer',
attn_cfgs=[
dict(
type='TPVCrossViewHybridAttention',
tpv_h=tpv_h_,
tpv_w=tpv_w_,
tpv_z=tpv_z_,
num_anchors=hybrid_attn_anchors,
embed_dims=_dim_,
num_heads=num_heads,
num_points=hybrid_attn_points,
init_mode=hybrid_attn_init,
dropout=0.1)
],
feedforward_channels=_ffn_dim_,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'ffn', 'norm'))
model = dict(
type='TPVFormer',
data_preprocessor=dict(
type='TPVFormerDataPreprocessor',
pad_size_divisor=32,
mean=[103.530, 116.280, 123.675],
std=[1.0, 1.0, 1.0],
voxel=True,
voxel_type='cylindrical',
voxel_layer=dict(
grid_shape=grid_shape,
point_cloud_range=point_cloud_range,
max_num_points=-1,
max_voxels=-1,
),
batch_augments=[
dict(
type='GridMask',
use_h=True,
use_w=True,
rotate=1,
offset=False,
ratio=0.5,
mode=1,
prob=0.7)
]),
backbone=dict(
type='mmdet.ResNet',
depth=101,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN2d', requires_grad=False),
norm_eval=True,
style='caffe',
dcn=dict(
type='DCNv2', deform_groups=1, fallback_on_stride=False
        ),  # the original DCNv2 prints a log when load_state_dict is called
stage_with_dcn=(False, False, True, True),
init_cfg=dict(
type='Pretrained',
checkpoint='checkpoints/tpvformer_r101_dcn_fcos3d_pretrain.pth',
prefix='backbone.')),
neck=dict(
type='mmdet.FPN',
in_channels=[512, 1024, 2048],
out_channels=_dim_,
start_level=0,
add_extra_convs='on_output',
num_outs=4,
relu_before_extra_convs=True,
init_cfg=dict(
type='Pretrained',
checkpoint='checkpoints/tpvformer_r101_dcn_fcos3d_pretrain.pth',
prefix='neck.')),
encoder=dict(
type='TPVFormerEncoder',
tpv_h=tpv_h_,
tpv_w=tpv_w_,
tpv_z=tpv_z_,
num_layers=5,
pc_range=point_cloud_range,
num_points_in_pillar=num_points_in_pillar,
num_points_in_pillar_cross_view=[16, 16, 16],
return_intermediate=False,
transformerlayers=[
self_cross_layer, self_cross_layer, self_cross_layer, self_layer,
self_layer
],
embed_dims=_dim_,
positional_encoding=dict(
type='TPVFormerPositionalEncoding',
num_feats=[48, 48, 32],
h=tpv_h_,
w=tpv_w_,
z=tpv_z_)),
decode_head=dict(
type='TPVFormerDecoder',
tpv_h=tpv_h_,
tpv_w=tpv_w_,
tpv_z=tpv_z_,
num_classes=17,
in_dims=_dim_,
hidden_dims=2 * _dim_,
out_dims=_dim_,
scale_h=scale_h,
scale_w=scale_w,
scale_z=scale_z,
loss_ce=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=False,
class_weight=None,
avg_non_ignore=True,
loss_weight=1.0),
loss_lovasz=dict(type='LovaszLoss', loss_weight=1.0, reduction='none'),
lovasz_input='points',
ce_input='voxel',
ignore_index=0))
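# Illustrative sketch, not part of the PR diff: the grid resolution implied by
# the config above, assuming the voxel layer derives its voxel size as
# (range_max - range_min) / grid_shape per axis, which mirrors the
# floor((point - min_bound) / voxel_size) computation in the data preprocessor.
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
grid_shape = [200, 200, 16]  # tpv_h_*scale_h, tpv_w_*scale_w, tpv_z_*scale_z
voxel_size = [(point_cloud_range[i + 3] - point_cloud_range[i]) / grid_shape[i]
              for i in range(3)]
print(voxel_size)  # [0.512, 0.512, 0.5] -> roughly 0.5 m per TPV cell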
from .cross_view_hybrid_attention import TPVCrossViewHybridAttention
from .data_preprocessor import TPVFormerDataPreprocessor
from .image_cross_attention import TPVImageCrossAttention
from .loading import BEVLoadMultiViewImageFromFiles, SegLabelMapping
from .nuscenes_dataset import NuScenesSegDataset
from .positional_encoding import TPVFormerPositionalEncoding
from .tpvformer import TPVFormer
from .tpvformer_encoder import TPVFormerEncoder
from .tpvformer_head import TPVFormerDecoder
from .tpvformer_layer import TPVFormerLayer
__all__ = [
'TPVCrossViewHybridAttention', 'TPVImageCrossAttention',
'TPVFormerPositionalEncoding', 'TPVFormer', 'TPVFormerEncoder',
'TPVFormerLayer', 'NuScenesSegDataset', 'BEVLoadMultiViewImageFromFiles',
'SegLabelMapping', 'TPVFormerDecoder', 'TPVFormerDataPreprocessor'
]
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
from mmengine.model import BaseModule, constant_init, xavier_init
from torch import Tensor
from mmdet3d.registry import MODELS
@MODELS.register_module()
class TPVCrossViewHybridAttention(BaseModule):
"""TPVFormer Cross-view Hybrid Attention Module."""
def __init__(self,
tpv_h: int,
tpv_w: int,
tpv_z: int,
embed_dims: int = 256,
num_heads: int = 8,
num_points: int = 4,
num_anchors: int = 2,
init_mode: int = 0,
dropout: float = 0.1,
**kwargs):
super().__init__()
self.embed_dims = embed_dims
self.num_heads = num_heads
self.num_levels = 3
self.num_points = num_points
self.num_anchors = num_anchors
self.init_mode = init_mode
self.dropout = nn.ModuleList([nn.Dropout(dropout) for _ in range(3)])
self.output_proj = nn.ModuleList(
[nn.Linear(embed_dims, embed_dims) for _ in range(3)])
self.sampling_offsets = nn.ModuleList([
nn.Linear(embed_dims, num_heads * 3 * num_points * 2)
for _ in range(3)
])
self.attention_weights = nn.ModuleList([
nn.Linear(embed_dims, num_heads * 3 * (num_points + 1))
for _ in range(3)
])
self.value_proj = nn.ModuleList(
[nn.Linear(embed_dims, embed_dims) for _ in range(3)])
self.tpv_h, self.tpv_w, self.tpv_z = tpv_h, tpv_w, tpv_z
def init_weights(self):
"""Default initialization for Parameters of Module."""
device = next(self.parameters()).device
# self plane
theta_self = torch.arange(
self.num_heads, dtype=torch.float32,
device=device) * (2.0 * math.pi / self.num_heads)
grid_self = torch.stack(
[theta_self.cos(), theta_self.sin()], -1) # H, 2
grid_self = grid_self.view(self.num_heads, 1,
2).repeat(1, self.num_points, 1)
for j in range(self.num_points):
grid_self[:, j, :] *= (j + 1) / 2
if self.init_mode == 0:
# num_phi = 4
phi = torch.arange(
4, dtype=torch.float32, device=device) * (2.0 * math.pi / 4)
assert self.num_heads % 4 == 0
num_theta = int(self.num_heads / 4)
theta = torch.arange(
num_theta, dtype=torch.float32, device=device) * (
math.pi / num_theta) + (math.pi / num_theta / 2) # 3
x = torch.matmul(theta.sin().unsqueeze(-1),
phi.cos().unsqueeze(0)).flatten()
y = torch.matmul(theta.sin().unsqueeze(-1),
phi.sin().unsqueeze(0)).flatten()
z = theta.cos().unsqueeze(-1).repeat(1, 4).flatten()
xyz = torch.stack([x, y, z], dim=-1) # H, 3
elif self.init_mode == 1:
xyz = [[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0], [1, 0, 0],
[-1, 0, 0]]
xyz = torch.tensor(xyz, dtype=torch.float32, device=device)
grid_hw = xyz[:, [0, 1]] # H, 2
grid_zh = xyz[:, [2, 0]]
grid_wz = xyz[:, [1, 2]]
for i in range(3):
grid = torch.stack([grid_hw, grid_zh, grid_wz], dim=1) # H, 3, 2
grid = grid.unsqueeze(2).repeat(1, 1, self.num_points, 1)
grid = grid.reshape(self.num_heads, self.num_levels,
self.num_anchors, -1, 2)
for j in range(self.num_points // self.num_anchors):
grid[:, :, :, j, :] *= 2 * (j + 1)
grid = grid.flatten(2, 3)
grid[:, i, :, :] = grid_self
constant_init(self.sampling_offsets[i], 0.)
self.sampling_offsets[i].bias.data = grid.view(-1)
constant_init(self.attention_weights[i], val=0., bias=0.)
attn_bias = torch.zeros(
self.num_heads, 3, self.num_points + 1, device=device)
attn_bias[:, i, -1] = 10
self.attention_weights[i].bias.data = attn_bias.flatten()
xavier_init(self.value_proj[i], distribution='uniform', bias=0.)
xavier_init(self.output_proj[i], distribution='uniform', bias=0.)
def get_sampling_offsets_and_attention(
self, queries: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
offsets = []
attns = []
for i, (query, fc, attn) in enumerate(
zip(queries, self.sampling_offsets, self.attention_weights)):
bs, l, d = query.shape
offset = fc(query).reshape(bs, l, self.num_heads, self.num_levels,
self.num_points, 2)
offsets.append(offset)
attention = attn(query).reshape(bs, l, self.num_heads, 3, -1)
level_attention = attention[:, :, :, :,
-1:].softmax(-2) # bs, l, H, 3, 1
attention = attention[:, :, :, :, :-1]
attention = attention.softmax(-1) # bs, l, H, 3, p
attention = attention * level_attention
attns.append(attention)
offsets = torch.cat(offsets, dim=1)
attns = torch.cat(attns, dim=1)
return offsets, attns
def reshape_output(self, output: Tensor, lens: List[int]) -> List[Tensor]:
outputs = torch.split(output, [lens[0], lens[1], lens[2]], dim=1)
return outputs
def forward(self,
query: List[Tensor],
identity: Optional[List[Tensor]] = None,
query_pos: Optional[List[Tensor]] = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None):
identity = query if identity is None else identity
if query_pos is not None:
query = [q + p for q, p in zip(query, query_pos)]
# value proj
query_lens = [q.shape[1] for q in query]
value = [layer(q) for layer, q in zip(self.value_proj, query)]
value = torch.cat(value, dim=1)
bs, num_value, _ = value.shape
value = value.view(bs, num_value, self.num_heads, -1)
# sampling offsets and weights
sampling_offsets, attention_weights = \
self.get_sampling_offsets_and_attention(query)
if reference_points.shape[-1] == 2:
"""For each tpv query, it owns `num_Z_anchors` in 3D space that
having different heights. After projecting, each tpv query has
`num_Z_anchors` reference points in each 2D image. For each
referent point, we sample `num_points` sampling points.
For `num_Z_anchors` reference points,
it has overall `num_points * num_Z_anchors` sampling points.
"""
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
bs, num_query, _, num_Z_anchors, xy = reference_points.shape
reference_points = reference_points[:, :, None, :, :, None, :]
sampling_offsets = sampling_offsets / \
offset_normalizer[None, None, None, :, None, :]
bs, num_query, num_heads, num_levels, num_all_points, xy = sampling_offsets.shape # noqa
sampling_offsets = sampling_offsets.view(
bs, num_query, num_heads, num_levels, num_Z_anchors,
num_all_points // num_Z_anchors, xy)
sampling_locations = reference_points + sampling_offsets
bs, num_query, num_heads, num_levels, num_points, num_Z_anchors, xy = sampling_locations.shape # noqa
sampling_locations = sampling_locations.view(
bs, num_query, num_heads, num_levels, num_all_points, xy)
else:
raise ValueError(
                f'Last dim of reference_points must be'
                f' 2, but got {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, 64)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
outputs = self.reshape_output(output, query_lens)
results = []
for out, layer, drop, residual in zip(outputs, self.output_proj,
self.dropout, identity):
results.append(residual + drop(layer(out)))
return results
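# Illustrative shape check, not part of the PR diff, for the broadcasting in
# TPVCrossViewHybridAttention.forward above, using the config values
# num_heads=8, num_levels=3, hybrid_attn_points=32, hybrid_attn_anchors=16.
import torch

bs, num_query, heads, levels, points, anchors = 1, 5, 8, 3, 32, 16
reference_points = torch.rand(bs, num_query, levels, anchors, 2)
sampling_offsets = torch.rand(bs, num_query, heads, levels, points, 2)
ref = reference_points[:, :, None, :, :, None, :]            # (1, 5, 1, 3, 16, 1, 2)
off = sampling_offsets.view(bs, num_query, heads, levels,
                            anchors, points // anchors, 2)    # (1, 5, 8, 3, 16, 2, 2)
locs = (ref + off).view(bs, num_query, heads, levels, points, 2)
print(locs.shape)  # torch.Size([1, 5, 8, 3, 32, 2]): 2 offsets per anchor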
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from mmdet3d.models import Det3DDataPreprocessor
from mmdet3d.models.data_preprocessors.voxelize import dynamic_scatter_3d
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
@MODELS.register_module()
class TPVFormerDataPreprocessor(Det3DDataPreprocessor):
@torch.no_grad()
    def voxelize(self, points: List[Tensor],
                 data_samples: SampleList) -> None:
        """Apply voxelization to the point cloud. In TPVFormer, this computes
        the voxel-wise segmentation label and the voxel/point coordinates and
        stores them in the corresponding data sample.
        Args:
            points (List[Tensor]): Point cloud of each sample in the batch.
            data_samples (List[:obj:`Det3DDataSample`]): The annotation data
                of each sample, to which the voxel-wise segmentation
                annotations are added.
        """
for point, data_sample in zip(points, data_samples):
min_bound = point.new_tensor(
self.voxel_layer.point_cloud_range[:3])
max_bound = point.new_tensor(
self.voxel_layer.point_cloud_range[3:])
point_clamp = torch.clamp(point, min_bound, max_bound + 1e-6)
coors = torch.floor(
(point_clamp - min_bound) /
point_clamp.new_tensor(self.voxel_layer.voxel_size)).int()
self.get_voxel_seg(coors, data_sample)
data_sample.point_coors = coors
def get_voxel_seg(self, res_coors: Tensor, data_sample: SampleList):
"""Get voxel-wise segmentation label and point2voxel map.
Args:
res_coors (Tensor): The voxel coordinates of points, Nx3.
            data_sample (:obj:`Det3DDataSample`): The annotation data of
                one sample, to which the voxel-wise segmentation annotation
                is added.
"""
if self.training:
pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
pts_semantic_mask = F.one_hot(pts_semantic_mask.long()).float()
voxel_semantic_mask, voxel_coors, point2voxel_map = \
dynamic_scatter_3d(pts_semantic_mask, res_coors, 'mean', True)
voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1)
data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask
data_sample.point2voxel_map = point2voxel_map
data_sample.voxel_coors = voxel_coors
else:
pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float()
_, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
res_coors, 'mean', True)
data_sample.point2voxel_map = point2voxel_map
@MODELS.register_module()
class GridMask(nn.Module):
"""GridMask data augmentation.
Modified from https://github.com/dvlab-research/GridMask.
Args:
use_h (bool): Whether to mask on height dimension. Defaults to True.
use_w (bool): Whether to mask on width dimension. Defaults to True.
        rotate (int): Upper bound of the random rotation applied to the
            mask, in degrees. Defaults to 1.
offset (bool): Whether to mask offset. Defaults to False.
ratio (float): Mask ratio. Defaults to 0.5.
        mode (int): Mask mode. If mode == 0, the grid regions are zeroed
            out; if mode == 1, everything except the grid regions is zeroed
            out. Defaults to 0.
prob (float): Probability of applying the augmentation.
Defaults to 1.0.
"""
def __init__(self,
use_h: bool = True,
use_w: bool = True,
rotate: int = 1,
offset: bool = False,
ratio: float = 0.5,
mode: int = 0,
prob: float = 1.0):
super().__init__()
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.prob = prob
def forward(self, inputs: Tensor,
data_samples: SampleList) -> Tuple[Tensor, SampleList]:
if np.random.rand() > self.prob:
return inputs, data_samples
height, width = inputs.shape[-2:]
mask_height = int(1.5 * height)
mask_width = int(1.5 * width)
distance = np.random.randint(2, min(height, width))
length = min(max(int(distance * self.ratio + 0.5), 1), distance - 1)
mask = np.ones((mask_height, mask_width), np.float32)
stride_on_height = np.random.randint(distance)
stride_on_width = np.random.randint(distance)
if self.use_h:
for i in range(mask_height // distance):
start = distance * i + stride_on_height
end = min(start + length, mask_height)
mask[start:end, :] *= 0
if self.use_w:
for i in range(mask_width // distance):
start = distance * i + stride_on_width
end = min(start + length, mask_width)
mask[:, start:end] *= 0
# NOTE: r is the rotation radian, here is a random counterclockwise
# rotation of 1° or remain unchanged, which follows the implementation
# of the official detection version.
# https://github.com/dvlab-research/GridMask.
r = np.random.randint(self.rotate)
mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)
mask = np.array(mask)
mask = mask[int(0.25 * height):int(0.25 * height) + height,
int(0.25 * width):int(0.25 * width) + width]
mask = inputs.new_tensor(mask)
if self.mode == 1:
mask = 1 - mask
mask = mask.expand_as(inputs)
if self.offset:
offset = inputs.new_tensor(2 *
(np.random.rand(height, width) - 0.5))
inputs = inputs * mask + offset * (1 - mask)
else:
inputs = inputs * mask
return inputs, data_samples
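# Illustrative usage, not part of the PR diff, of the GridMask batch
# augmentation defined above (assuming the class is in scope); data_samples
# are passed through untouched and prob=1.0 forces the mask to be applied.
import torch

grid_mask = GridMask(use_h=True, use_w=True, rotate=1, offset=False,
                     ratio=0.5, mode=1, prob=1.0)
imgs = torch.rand(6, 3, 256, 448)          # six camera images of one sample
masked, _ = grid_mask(imgs, data_samples=[])
print(masked.shape)                         # torch.Size([6, 3, 256, 448])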
import math
import warnings
import torch
import torch.nn as nn
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
from mmengine.model import BaseModule, constant_init, xavier_init
from mmdet3d.registry import MODELS
@MODELS.register_module()
class TPVImageCrossAttention(BaseModule):
"""An attention module used in TPVFormer.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
        num_cams (int): The number of cameras. Default: 6.
dropout (float): A Dropout layer on `inp_residual`.
Default: 0.1.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
batch_first (bool): Whether the first dimension of the input is batch.
deformable_attention: (dict): The config for the deformable
attention used in SCA.
tpv_h (int): The height of the TPV.
tpv_w (int): The width of the TPV.
tpv_z (int): The depth of the TPV.
"""
def __init__(self,
embed_dims=256,
num_cams=6,
pc_range=None,
dropout=0.1,
init_cfg=None,
batch_first=True,
deformable_attention=dict(
type='MSDeformableAttention3D',
embed_dims=256,
num_levels=4),
tpv_h=None,
tpv_w=None,
tpv_z=None):
super().__init__(init_cfg)
self.init_cfg = init_cfg
self.dropout = nn.Dropout(dropout)
self.pc_range = pc_range
self.fp16_enabled = False
self.deformable_attention = MODELS.build(deformable_attention)
self.embed_dims = embed_dims
self.num_cams = num_cams
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.batch_first = batch_first
self.tpv_h, self.tpv_w, self.tpv_z = tpv_h, tpv_w, tpv_z
self.init_weight()
def init_weight(self):
"""Default initialization for Parameters of Module."""
xavier_init(self.output_proj, distribution='uniform', bias=0.)
def forward(self,
query,
key,
value,
residual=None,
spatial_shapes=None,
reference_points_cams=None,
tpv_masks=None,
level_start_index=None):
"""Forward Function of Detr3DCrossAtten.
Args:
query (Tensor): Query of Transformer with shape
(bs, num_query, embed_dims).
key (Tensor): The key tensor with shape
(bs, num_key, embed_dims).
value (Tensor): The value tensor with shape
(bs, num_key, embed_dims).
residual (Tensor): The tensor used for addition, with the
same shape as `x`. Default None. If None, `x` will be used.
spatial_shapes (Tensor): Spatial shape of features in
different level. With shape (num_levels, 2),
last dimension represent (h, w).
            reference_points_cams (List[Tensor]): The reference points in
                each camera view.
            tpv_masks (List[Tensor]): The visibility mask of each view.
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        Returns:
            Tensor: forwarded results with shape [bs, num_query, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if residual is None:
inp_residual = query
bs, _, _ = query.size()
queries = torch.split(
query, [
self.tpv_h * self.tpv_w, self.tpv_z * self.tpv_h,
self.tpv_w * self.tpv_z
],
dim=1)
if residual is None:
slots = [torch.zeros_like(q) for q in queries]
indexeses = []
max_lens = []
queries_rebatches = []
reference_points_rebatches = []
for tpv_idx, tpv_mask in enumerate(tpv_masks):
indexes = []
for _, mask_per_img in enumerate(tpv_mask):
index_query_per_img = mask_per_img[0].sum(
-1).nonzero().squeeze(-1)
indexes.append(index_query_per_img)
max_len = max([len(each) for each in indexes])
max_lens.append(max_len)
indexeses.append(indexes)
reference_points_cam = reference_points_cams[tpv_idx]
D = reference_points_cam.size(3)
queries_rebatch = queries[tpv_idx].new_zeros(
[bs * self.num_cams, max_len, self.embed_dims])
reference_points_rebatch = reference_points_cam.new_zeros(
[bs * self.num_cams, max_len, D, 2])
for i, reference_points_per_img in enumerate(reference_points_cam):
for j in range(bs):
index_query_per_img = indexes[i]
queries_rebatch[j * self.num_cams +
i, :len(index_query_per_img)] = queries[
tpv_idx][j, index_query_per_img]
reference_points_rebatch[j * self.num_cams + i, :len(
index_query_per_img)] = reference_points_per_img[
j, index_query_per_img]
queries_rebatches.append(queries_rebatch)
reference_points_rebatches.append(reference_points_rebatch)
num_cams, l, bs, embed_dims = key.shape
key = key.permute(0, 2, 1, 3).view(self.num_cams * bs, l,
self.embed_dims)
value = value.permute(0, 2, 1, 3).view(self.num_cams * bs, l,
self.embed_dims)
queries = self.deformable_attention(
query=queries_rebatches,
key=key,
value=value,
reference_points=reference_points_rebatches,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
)
for tpv_idx, indexes in enumerate(indexeses):
for i, index_query_per_img in enumerate(indexes):
for j in range(bs):
slots[tpv_idx][j, index_query_per_img] += queries[tpv_idx][
j * self.num_cams + i, :len(index_query_per_img)]
count = tpv_masks[tpv_idx].sum(-1) > 0
count = count.permute(1, 2, 0).sum(-1)
count = torch.clamp(count, min=1.0)
slots[tpv_idx] = slots[tpv_idx] / count[..., None]
slots = torch.cat(slots, dim=1)
slots = self.output_proj(slots)
return self.dropout(slots) + inp_residual
@MODELS.register_module()
class TPVMSDeformableAttention3D(BaseModule):
"""An attention module used in tpvFormer based on Deformable-Detr.
`Deformable DETR: Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
num_levels (int): The number of feature map used in
Attention. Default: 4.
        num_points (list[int]): The number of sampling points for
            each query in each head, one entry per TPV plane.
            Default: [8, 64, 64].
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim)
            or (n, batch, embed_dim). Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(
self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=[8, 64, 64],
num_z_anchors=[4, 32, 32],
pc_range=None,
im2col_step=64,
dropout=0.1,
batch_first=True,
norm_cfg=None,
init_cfg=None,
floor_sampling_offset=True,
tpv_h=None,
tpv_w=None,
tpv_z=None,
):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.batch_first = batch_first
self.output_proj = None
self.fp16_enabled = False
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0
if not _is_power_of_2(dim_per_head):
warnings.warn(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')
self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.num_z_anchors = num_z_anchors
self.base_num_points = num_points[0]
self.base_z_anchors = num_z_anchors[0]
self.points_multiplier = [
points // self.base_z_anchors for points in num_z_anchors
]
self.pc_range = pc_range
self.tpv_h, self.tpv_w, self.tpv_z = tpv_h, tpv_w, tpv_z
self.sampling_offsets = nn.ModuleList([
nn.Linear(embed_dims, num_heads * num_levels * num_points[i] * 2)
for i in range(3)
])
self.floor_sampling_offset = floor_sampling_offset
self.attention_weights = nn.ModuleList([
nn.Linear(embed_dims, num_heads * num_levels * num_points[i])
for i in range(3)
])
self.value_proj = nn.Linear(embed_dims, embed_dims)
def init_weights(self):
"""Default initialization for Parameters of Module."""
device = next(self.parameters()).device
for i in range(3):
constant_init(self.sampling_offsets[i], 0.)
thetas = torch.arange(
self.num_heads, dtype=torch.float32,
device=device) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels, self.num_points[i],
1)
grid_init = grid_init.reshape(self.num_heads, self.num_levels,
self.num_z_anchors[i], -1, 2)
for j in range(self.num_points[i] // self.num_z_anchors[i]):
grid_init[:, :, :, j, :] *= j + 1
self.sampling_offsets[i].bias.data = grid_init.view(-1)
constant_init(self.attention_weights[i], val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
self._is_init = True
def get_sampling_offsets_and_attention(self, queries):
offsets = []
attns = []
for i, (query, fc, attn) in enumerate(
zip(queries, self.sampling_offsets, self.attention_weights)):
bs, l, d = query.shape
offset = fc(query).reshape(bs, l, self.num_heads, self.num_levels,
self.points_multiplier[i], -1, 2)
offset = offset.permute(0, 1, 4, 2, 3, 5, 6).flatten(1, 2)
offsets.append(offset)
attention = attn(query).reshape(bs, l, self.num_heads, -1)
attention = attention.softmax(-1)
attention = attention.view(bs, l, self.num_heads, self.num_levels,
self.points_multiplier[i], -1)
attention = attention.permute(0, 1, 4, 2, 3, 5).flatten(1, 2)
attns.append(attention)
offsets = torch.cat(offsets, dim=1)
attns = torch.cat(attns, dim=1)
return offsets, attns
def reshape_reference_points(self, reference_points):
reference_point_list = []
for i, reference_point in enumerate(reference_points):
bs, l, z_anchors, _ = reference_point.shape
reference_point = reference_point.reshape(
bs, l, self.points_multiplier[i], -1, 2)
reference_point = reference_point.flatten(1, 2)
reference_point_list.append(reference_point)
return torch.cat(reference_point_list, dim=1)
def reshape_output(self, output, lens):
bs, _, d = output.shape
outputs = torch.split(
output, [
lens[0] * self.points_multiplier[0], lens[1] *
self.points_multiplier[1], lens[2] * self.points_multiplier[2]
],
dim=1)
outputs = [
o.reshape(bs, -1, self.points_multiplier[i], d).sum(dim=2)
for i, o in enumerate(outputs)
]
return outputs
def forward(self,
query,
key=None,
value=None,
identity=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
( bs, num_query, embed_dims).
key (Tensor): The key tensor with shape
`(bs, num_key, embed_dims)`.
value (Tensor): The value tensor with shape
`(bs, num_key, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [bs, num_query, embed_dims].
"""
if value is None:
value = query
if identity is None:
identity = query
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = [q.permute(1, 0, 2) for q in query]
value = value.permute(1, 0, 2)
# bs, num_query, _ = query.shape
query_lens = [q.shape[1] for q in query]
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
value = value.view(bs, num_value, self.num_heads, -1)
sampling_offsets, attention_weights = \
self.get_sampling_offsets_and_attention(query)
reference_points = self.reshape_reference_points(reference_points)
if reference_points.shape[-1] == 2:
"""For each tpv query, it owns `num_Z_anchors` in 3D space that
having different heights. After projecting, each tpv query has
`num_Z_anchors` reference points in each 2D image. For each
referent point, we sample `num_points` sampling points.
For `num_Z_anchors` reference points,
it has overall `num_points * num_Z_anchors` sampling points.
"""
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
bs, num_query, num_Z_anchors, xy = reference_points.shape
reference_points = reference_points[:, :, None, None, :, None, :]
sampling_offsets = sampling_offsets / \
offset_normalizer[None, None, None, :, None, :]
bs, num_query, num_heads, num_levels, num_all_points, xy = \
sampling_offsets.shape
sampling_offsets = sampling_offsets.view(
bs, num_query, num_heads, num_levels, num_Z_anchors,
num_all_points // num_Z_anchors, xy)
sampling_locations = reference_points + sampling_offsets
bs, num_query, num_heads, num_levels, num_points, num_Z_anchors, \
xy = sampling_locations.shape
assert num_all_points == num_points * num_Z_anchors
sampling_locations = sampling_locations.view(
bs, num_query, num_heads, num_levels, num_all_points, xy)
if self.floor_sampling_offset:
sampling_locations = sampling_locations - torch.floor(
sampling_locations)
elif reference_points.shape[-1] == 4:
assert False
else:
raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but got {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
output = self.reshape_output(output, query_lens)
if not self.batch_first:
output = [o.permute(1, 0, 2) for o in output]
return output
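# Illustrative arithmetic, not part of the PR diff, for how
# TPVMSDeformableAttention3D above expands the three TPV planes, using the
# config values num_points=[8, 64, 64] and num_z_anchors=[4, 32, 32].
num_points = [8, 64, 64]
num_z_anchors = [4, 32, 32]
base_z_anchors = num_z_anchors[0]                                       # 4
points_multiplier = [z // base_z_anchors for z in num_z_anchors]        # [1, 8, 8]
points_per_anchor = [p // z for p, z in zip(num_points, num_z_anchors)]  # [2, 2, 2]
tpv_h, tpv_w, tpv_z = 200, 200, 16
query_lens = [tpv_h * tpv_w, tpv_z * tpv_h, tpv_w * tpv_z]              # [40000, 3200, 3200]
# queries are duplicated per plane before the deformable-attention call and
# summed back over the multiplier dimension in reshape_output
expanded_lens = [l * m for l, m in zip(query_lens, points_multiplier)]
print(points_multiplier, points_per_anchor, expanded_lens)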
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Optional, Union
import mmcv
import numpy as np
from mmcv.transforms.base import BaseTransform
from mmengine.fileio import get
from mmdet3d.datasets.transforms import LoadMultiViewImageFromFiles
from mmdet3d.registry import TRANSFORMS
Number = Union[int, float]
@TRANSFORMS.register_module()
class BEVLoadMultiViewImageFromFiles(LoadMultiViewImageFromFiles):
"""Load multi channel images from a list of separate channel files.
``BEVLoadMultiViewImageFromFiles`` adds the following keys for the
convenience of view transforms in the forward:
- 'cam2lidar'
- 'lidar2img'
Args:
to_float32 (bool): Whether to convert the img to float32.
Defaults to False.
color_type (str): Color type of the file. Defaults to 'unchanged'.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
        num_views (int): Number of views in a frame. Defaults to 5.
        num_ref_frames (int): Number of reference frames to load.
            Defaults to -1.
test_mode (bool): Whether is test mode in loading. Defaults to False.
set_default_scale (bool): Whether to set default scale.
Defaults to True.
"""
def transform(self, results: dict) -> Optional[dict]:
"""Call function to load multi-view image from files.
Args:
results (dict): Result dict containing multi-view image filenames.
Returns:
dict: The result dict containing the multi-view image data.
Added keys and values are described below.
- filename (str): Multi-view image filenames.
- img (np.ndarray): Multi-view image arrays.
- img_shape (tuple[int]): Shape of multi-view image arrays.
- ori_shape (tuple[int]): Shape of original image arrays.
- pad_shape (tuple[int]): Shape of padded image arrays.
- scale_factor (float): Scale factor.
- img_norm_cfg (dict): Normalization configuration of images.
"""
filename, cam2img, lidar2cam, lidar2img = [], [], [], []
for _, cam_item in results['images'].items():
filename.append(cam_item['img_path'])
lidar2cam.append(cam_item['lidar2cam'])
lidar2cam_array = np.array(cam_item['lidar2cam'])
cam2img_array = np.eye(4).astype(np.float64)
cam2img_array[:3, :3] = np.array(cam_item['cam2img'])
cam2img.append(cam2img_array)
lidar2img.append(cam2img_array @ lidar2cam_array)
results['img_path'] = filename
results['cam2img'] = np.stack(cam2img, axis=0)
results['lidar2cam'] = np.stack(lidar2cam, axis=0)
results['lidar2img'] = np.stack(lidar2img, axis=0)
results['ori_cam2img'] = copy.deepcopy(results['cam2img'])
# img is of shape (h, w, c, num_views)
# h and w can be different for different views
img_bytes = [
get(name, backend_args=self.backend_args) for name in filename
]
        # keep BGR channel order, following the original TPVFormer
imgs = [
mmcv.imfrombytes(img_byte, flag=self.color_type)
for img_byte in img_bytes
]
# handle the image with different shape
img_shapes = np.stack([img.shape for img in imgs], axis=0)
img_shape_max = np.max(img_shapes, axis=0)
img_shape_min = np.min(img_shapes, axis=0)
assert img_shape_min[-1] == img_shape_max[-1]
if not np.all(img_shape_max == img_shape_min):
pad_shape = img_shape_max[:2]
else:
pad_shape = None
if pad_shape is not None:
imgs = [
mmcv.impad(img, shape=pad_shape, pad_val=0) for img in imgs
]
img = np.stack(imgs, axis=-1)
if self.to_float32:
img = img.astype(np.float32)
results['filename'] = filename
# unravel to list, see `DefaultFormatBundle` in formating.py
# which will transpose each image separately and then stack into array
results['img'] = [img[..., i] for i in range(img.shape[-1])]
results['img_shape'] = img.shape[:2]
results['ori_shape'] = img.shape[:2]
# Set initial values for default meta_keys
results['pad_shape'] = img.shape[:2]
if self.set_default_scale:
results['scale_factor'] = 1.0
num_channels = 1 if len(img.shape) < 3 else img.shape[2]
results['img_norm_cfg'] = dict(
mean=np.zeros(num_channels, dtype=np.float32),
std=np.ones(num_channels, dtype=np.float32),
to_rgb=False)
results['num_views'] = self.num_views
results['num_ref_frames'] = self.num_ref_frames
return results
@TRANSFORMS.register_module()
class SegLabelMapping(BaseTransform):
"""Map original semantic class to valid category ids.
Required Keys:
- seg_label_mapping (np.ndarray)
- pts_semantic_mask (np.ndarray)
    Updated Keys:
    - pts_semantic_mask (np.ndarray)
    Map the original semantic class ids to valid category ids according to
    ``seg_label_mapping``.
"""
def transform(self, results: dict) -> dict:
"""Call function to map original semantic class to valid category ids.
Args:
results (dict): Result dict containing point semantic masks.
Returns:
dict: The result dict containing the mapped category ids.
Updated key and value are described below.
- pts_semantic_mask (np.ndarray): Mapped semantic masks.
"""
assert 'pts_semantic_mask' in results
pts_semantic_mask = results['pts_semantic_mask']
assert 'seg_label_mapping' in results
label_mapping = results['seg_label_mapping']
converted_pts_sem_mask = np.vectorize(
label_mapping.__getitem__, otypes=[np.uint8])(
pts_semantic_mask)
results['pts_semantic_mask'] = converted_pts_sem_mask
# 'eval_ann_info' will be passed to evaluator
if 'eval_ann_info' in results:
assert 'pts_semantic_mask' in results['eval_ann_info']
results['eval_ann_info']['pts_semantic_mask'] = \
converted_pts_sem_mask
return results
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
return repr_str
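# Illustrative sketch, not part of the PR diff and with synthetic,
# nuScenes-like intrinsics, of the lidar2img composition performed in
# BEVLoadMultiViewImageFromFiles.transform above: the 3x3 intrinsics are
# embedded into a 4x4 matrix and right-multiplied by the lidar-to-camera
# extrinsics, so homogeneous lidar points project straight to pixels.
import numpy as np

cam2img = np.eye(4)
cam2img[:3, :3] = np.array([[1266.4, 0.0, 816.3],
                            [0.0, 1266.4, 491.5],
                            [0.0, 0.0, 1.0]])
lidar2cam = np.eye(4)                       # placeholder extrinsics
lidar2img = cam2img @ lidar2cam             # stored per camera view
point = np.array([1.0, 0.5, 20.0, 1.0])     # homogeneous point, 20 m ahead
u, v, depth, _ = lidar2img @ point
print(u / depth, v / depth, depth)          # pixel coordinates and depth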
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Callable, List, Union
from mmengine.dataset import BaseDataset
from mmdet3d.registry import DATASETS
@DATASETS.register_module()
class NuScenesSegDataset(BaseDataset):
r"""NuScenes Dataset.
This class serves as the API for experiments on the NuScenes Dataset.
Please refer to `NuScenes Dataset <https://www.nuscenes.org/download>`_
for data downloading.
Args:
data_root (str): Path of dataset root.
ann_file (str): Path of annotation file.
pipeline (list[dict]): Pipeline used for data processing.
Defaults to [].
test_mode (bool): Store `True` when building test or val dataset.
"""
METAINFO = {
'classes':
('noise', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
'driveable_surface', 'other_flat', 'sidewalk', 'terrain', 'manmade',
'vegetation'),
'ignore_index':
0,
'label_mapping':
dict([(1, 0), (5, 0), (7, 0), (8, 0), (10, 0), (11, 0), (13, 0),
(19, 0), (20, 0), (0, 0), (29, 0), (31, 0), (9, 1), (14, 2),
(15, 3), (16, 3), (17, 4), (18, 5), (21, 6), (2, 7), (3, 7),
(4, 7), (6, 7), (12, 8), (22, 9), (23, 10), (24, 11), (25, 12),
(26, 13), (27, 14), (28, 15), (30, 16)]),
'palette': [
[0, 0, 0], # noise
[255, 120, 50], # barrier orange
[255, 192, 203], # bicycle pink
[255, 255, 0], # bus yellow
[0, 150, 245], # car blue
[0, 255, 255], # construction_vehicle cyan
[255, 127, 0], # motorcycle dark orange
[255, 0, 0], # pedestrian red
[255, 240, 150], # traffic_cone light yellow
[135, 60, 0], # trailer brown
[160, 32, 240], # truck purple
[255, 0, 255], # driveable_surface dark pink
[139, 137, 137], # other_flat dark red
            [75, 0, 75],  # sidewalk dark purple
[150, 240, 80], # terrain light green
[230, 230, 250], # manmade white
[0, 175, 0], # vegetation green
]
}
def __init__(self,
data_root: str,
ann_file: str,
pipeline: List[Union[dict, Callable]] = [],
test_mode: bool = False,
**kwargs) -> None:
metainfo = dict(label2cat={
i: cat_name
for i, cat_name in enumerate(self.METAINFO['classes'])
})
super().__init__(
ann_file=ann_file,
data_root=data_root,
metainfo=metainfo,
pipeline=pipeline,
test_mode=test_mode,
**kwargs)
def parse_data_info(self, info: dict) -> Union[List[dict], dict]:
"""Process the raw data info.
The only difference with it in `Det3DDataset`
is the specific process for `plane`.
Args:
info (dict): Raw info dict.
Returns:
List[dict] or dict: Has `ann_info` in training stage. And
all path has been converted to absolute path.
"""
data_list = []
info['lidar_points']['lidar_path'] = \
osp.join(
self.data_prefix.get('pts', ''),
info['lidar_points']['lidar_path'])
for cam_id, img_info in info['images'].items():
if 'img_path' in img_info:
if cam_id in self.data_prefix:
cam_prefix = self.data_prefix[cam_id]
else:
cam_prefix = self.data_prefix.get('img', '')
img_info['img_path'] = osp.join(cam_prefix,
img_info['img_path'])
if 'pts_semantic_mask_path' in info:
info['pts_semantic_mask_path'] = \
osp.join(self.data_prefix.get('pts_semantic_mask', ''),
info['pts_semantic_mask_path'])
# only be used in `PointSegClassMapping` in pipeline
# to map original semantic class to valid category ids.
info['seg_label_mapping'] = self.metainfo['label_mapping']
# 'eval_ann_info' will be updated in loading transforms
if self.test_mode:
info['eval_ann_info'] = dict()
data_list.append(info)
return data_list
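# Illustrative mapping check, not part of the PR diff: the label_mapping above
# collapses the 32 raw nuScenes lidarseg ids to the 16 classes plus noise used
# for training, which is exactly what SegLabelMapping applies per point.
import numpy as np

label_mapping = NuScenesSegDataset.METAINFO['label_mapping']
raw_mask = np.array([17, 24, 30, 1], dtype=np.uint8)
mapped = np.vectorize(label_mapping.__getitem__, otypes=[np.uint8])(raw_mask)
print(mapped)  # [ 4 11 16  0] -> car, driveable_surface, vegetation, noise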
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmdet3d.registry import MODELS
@MODELS.register_module()
class TPVFormerPositionalEncoding(BaseModule):
def __init__(self,
num_feats,
h,
w,
z,
init_cfg=dict(type='Uniform', layer='Embedding')):
super().__init__(init_cfg)
if not isinstance(num_feats, list):
num_feats = [num_feats] * 3
self.h_embed = nn.Embedding(h, num_feats[0])
self.w_embed = nn.Embedding(w, num_feats[1])
self.z_embed = nn.Embedding(z, num_feats[2])
self.num_feats = num_feats
self.h, self.w, self.z = h, w, z
def forward(self, bs, device, ignore_axis='z'):
if ignore_axis == 'h':
h_embed = torch.zeros(
1, 1, self.num_feats[0],
device=device).repeat(self.w, self.z, 1) # w, z, d
w_embed = self.w_embed(torch.arange(self.w, device=device))
w_embed = w_embed.reshape(self.w, 1, -1).repeat(1, self.z, 1)
z_embed = self.z_embed(torch.arange(self.z, device=device))
z_embed = z_embed.reshape(1, self.z, -1).repeat(self.w, 1, 1)
elif ignore_axis == 'w':
h_embed = self.h_embed(torch.arange(self.h, device=device))
h_embed = h_embed.reshape(1, self.h, -1).repeat(self.z, 1, 1)
w_embed = torch.zeros(
1, 1, self.num_feats[1],
device=device).repeat(self.z, self.h, 1)
z_embed = self.z_embed(torch.arange(self.z, device=device))
z_embed = z_embed.reshape(self.z, 1, -1).repeat(1, self.h, 1)
elif ignore_axis == 'z':
h_embed = self.h_embed(torch.arange(self.h, device=device))
h_embed = h_embed.reshape(self.h, 1, -1).repeat(1, self.w, 1)
w_embed = self.w_embed(torch.arange(self.w, device=device))
w_embed = w_embed.reshape(1, self.w, -1).repeat(self.h, 1, 1)
z_embed = torch.zeros(
1, 1, self.num_feats[2],
device=device).repeat(self.h, self.w, 1)
pos = torch.cat((h_embed, w_embed, z_embed),
dim=-1).flatten(0, 1).unsqueeze(0).repeat(bs, 1, 1)
return pos
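# Illustrative usage, not part of the PR diff, of TPVFormerPositionalEncoding
# above with the config values num_feats=[48, 48, 32] (summing to _dim_=128)
# and h=w=200, z=16; each call returns the embedding of one TPV plane.
import torch

pos_enc = TPVFormerPositionalEncoding(num_feats=[48, 48, 32], h=200, w=200, z=16)
pos_hw = pos_enc(bs=1, device='cpu', ignore_axis='z')
pos_zh = pos_enc(bs=1, device='cpu', ignore_axis='w')
pos_wz = pos_enc(bs=1, device='cpu', ignore_axis='h')
print(pos_hw.shape, pos_zh.shape, pos_wz.shape)
# torch.Size([1, 40000, 128]) torch.Size([1, 3200, 128]) torch.Size([1, 3200, 128])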
from typing import Optional, Union
from torch import nn
from mmdet3d.models import Base3DSegmentor
from mmdet3d.registry import MODELS
from mmdet3d.structures.det3d_data_sample import SampleList
@MODELS.register_module()
class TPVFormer(Base3DSegmentor):
def __init__(self,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
backbone=None,
neck=None,
encoder=None,
decode_head=None):
super().__init__(data_preprocessor=data_preprocessor)
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = MODELS.build(neck)
self.encoder = MODELS.build(encoder)
self.decode_head = MODELS.build(decode_head)
def extract_feat(self, img):
"""Extract features of images."""
B, N, C, H, W = img.size()
img = img.view(B * N, C, H, W)
img_feats = self.backbone(img)
if hasattr(self, 'neck'):
img_feats = self.neck(img_feats)
img_feats_reshaped = []
for img_feat in img_feats:
_, C, H, W = img_feat.size()
img_feats_reshaped.append(img_feat.view(B, N, C, H, W))
return img_feats_reshaped
def _forward(self, batch_inputs, batch_data_samples):
"""Forward training function."""
img_feats = self.extract_feat(batch_inputs['imgs'])
outs = self.encoder(img_feats, batch_data_samples)
outs = self.decode_head(outs, batch_inputs['voxels']['coors'])
return outs
def loss(self, batch_inputs: dict,
batch_data_samples: SampleList) -> SampleList:
img_feats = self.extract_feat(batch_inputs['imgs'])
queries = self.encoder(img_feats, batch_data_samples)
losses = self.decode_head.loss(queries, batch_data_samples)
return losses
def predict(self, batch_inputs: dict,
batch_data_samples: SampleList) -> SampleList:
"""Forward predict function."""
img_feats = self.extract_feat(batch_inputs['imgs'])
tpv_queries = self.encoder(img_feats, batch_data_samples)
seg_logits = self.decode_head.predict(tpv_queries, batch_data_samples)
seg_preds = [seg_logit.argmax(dim=1) for seg_logit in seg_logits]
return self.postprocess_result(seg_preds, batch_data_samples)
def aug_test(self, batch_inputs, batch_data_samples):
pass
def encode_decode(self, batch_inputs: dict,
batch_data_samples: SampleList) -> SampleList:
pass
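# Illustrative shape walkthrough, not part of the PR diff and using dummy
# tensors instead of the real ResNet + FPN, of TPVFormer.extract_feat above:
# the six camera views are folded into the batch dimension and unfolded back.
import torch

B, N = 1, 6
imgs = torch.rand(B, N, 3, 928, 1600)
flat = imgs.view(B * N, 3, 928, 1600)
# stand-in for backbone + neck outputs: 4 FPN levels with _dim_=128 channels
fpn_outs = [torch.rand(B * N, 128, 928 // s, 1600 // s) for s in (8, 16, 32, 64)]
reshaped = [f.view(B, N, *f.shape[1:]) for f in fpn_outs]
print([tuple(f.shape) for f in reshaped])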
import numpy as np
import torch
from mmcv.cnn.bricks.transformer import TransformerLayerSequence
from mmengine.registry import MODELS
from torch import nn
from torch.nn.init import normal_
from .cross_view_hybrid_attention import TPVCrossViewHybridAttention
from .image_cross_attention import TPVMSDeformableAttention3D
@MODELS.register_module()
class TPVFormerEncoder(TransformerLayerSequence):
def __init__(self,
tpv_h=200,
tpv_w=200,
tpv_z=16,
pc_range=[-51.2, -51.2, -5, 51.2, 51.2, 3],
num_feature_levels=4,
num_cams=6,
embed_dims=256,
num_points_in_pillar=[4, 32, 32],
num_points_in_pillar_cross_view=[32, 32, 32],
num_layers=5,
transformerlayers=None,
positional_encoding=None,
return_intermediate=False):
super().__init__(transformerlayers, num_layers)
self.tpv_h = tpv_h
self.tpv_w = tpv_w
self.tpv_z = tpv_z
self.pc_range = pc_range
self.real_w = pc_range[3] - pc_range[0]
self.real_h = pc_range[4] - pc_range[1]
self.real_z = pc_range[5] - pc_range[2]
self.level_embeds = nn.Parameter(
torch.Tensor(num_feature_levels, embed_dims))
self.cams_embeds = nn.Parameter(torch.Tensor(num_cams, embed_dims))
self.tpv_embedding_hw = nn.Embedding(tpv_h * tpv_w, embed_dims)
self.tpv_embedding_zh = nn.Embedding(tpv_z * tpv_h, embed_dims)
self.tpv_embedding_wz = nn.Embedding(tpv_w * tpv_z, embed_dims)
ref_3d_hw = self.get_reference_points(tpv_h, tpv_w, self.real_z,
num_points_in_pillar[0])
ref_3d_zh = self.get_reference_points(tpv_z, tpv_h, self.real_w,
num_points_in_pillar[1])
ref_3d_zh = ref_3d_zh.permute(3, 0, 1, 2)[[2, 0, 1]] # change to x,y,z
ref_3d_zh = ref_3d_zh.permute(1, 2, 3, 0)
ref_3d_wz = self.get_reference_points(tpv_w, tpv_z, self.real_h,
num_points_in_pillar[2])
ref_3d_wz = ref_3d_wz.permute(3, 0, 1, 2)[[1, 2, 0]] # change to x,y,z
ref_3d_wz = ref_3d_wz.permute(1, 2, 3, 0)
self.register_buffer('ref_3d_hw', ref_3d_hw)
self.register_buffer('ref_3d_zh', ref_3d_zh)
self.register_buffer('ref_3d_wz', ref_3d_wz)
cross_view_ref_points = self.get_cross_view_ref_points(
tpv_h, tpv_w, tpv_z, num_points_in_pillar_cross_view)
self.register_buffer('cross_view_ref_points', cross_view_ref_points)
# positional encoding
self.positional_encoding = MODELS.build(positional_encoding)
self.return_intermediate = return_intermediate
def init_weights(self):
"""Initialize the transformer weights."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, TPVMSDeformableAttention3D) or isinstance(
m, TPVCrossViewHybridAttention):
m.init_weights()
normal_(self.level_embeds)
normal_(self.cams_embeds)
@staticmethod
def get_cross_view_ref_points(tpv_h, tpv_w, tpv_z, num_points_in_pillar):
# ref points generating target: (#query)hw+zh+wz, (#level)3, #p, 2
# generate points for hw and level 1
h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
h_ranges = h_ranges.unsqueeze(-1).expand(-1, tpv_w).flatten()
w_ranges = w_ranges.unsqueeze(0).expand(tpv_h, -1).flatten()
hw_hw = torch.stack([w_ranges, h_ranges], dim=-1) # hw, 2
hw_hw = hw_hw.unsqueeze(1).expand(-1, num_points_in_pillar[2],
-1) # hw, #p, 2
# generate points for hw and level 2
z_ranges = torch.linspace(0.5, tpv_z - 0.5,
num_points_in_pillar[2]) / tpv_z # #p
z_ranges = z_ranges.unsqueeze(0).expand(tpv_h * tpv_w, -1) # hw, #p
h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
h_ranges = h_ranges.reshape(-1, 1, 1).expand(
-1, tpv_w, num_points_in_pillar[2]).flatten(0, 1)
hw_zh = torch.stack([h_ranges, z_ranges], dim=-1) # hw, #p, 2
# generate points for hw and level 3
z_ranges = torch.linspace(0.5, tpv_z - 0.5,
num_points_in_pillar[2]) / tpv_z # #p
z_ranges = z_ranges.unsqueeze(0).expand(tpv_h * tpv_w, -1) # hw, #p
w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
w_ranges = w_ranges.reshape(1, -1, 1).expand(
tpv_h, -1, num_points_in_pillar[2]).flatten(0, 1)
hw_wz = torch.stack([z_ranges, w_ranges], dim=-1) # hw, #p, 2
# generate points for zh and level 1
w_ranges = torch.linspace(0.5, tpv_w - 0.5,
num_points_in_pillar[1]) / tpv_w
w_ranges = w_ranges.unsqueeze(0).expand(tpv_z * tpv_h, -1)
h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
h_ranges = h_ranges.reshape(1, -1, 1).expand(
tpv_z, -1, num_points_in_pillar[1]).flatten(0, 1)
zh_hw = torch.stack([w_ranges, h_ranges], dim=-1)
# generate points for zh and level 2
z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
z_ranges = z_ranges.reshape(-1, 1, 1).expand(
-1, tpv_h, num_points_in_pillar[1]).flatten(0, 1)
h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h
h_ranges = h_ranges.reshape(1, -1, 1).expand(
tpv_z, -1, num_points_in_pillar[1]).flatten(0, 1)
zh_zh = torch.stack([h_ranges, z_ranges], dim=-1) # zh, #p, 2
# generate points for zh and level 3
w_ranges = torch.linspace(0.5, tpv_w - 0.5,
num_points_in_pillar[1]) / tpv_w
w_ranges = w_ranges.unsqueeze(0).expand(tpv_z * tpv_h, -1)
z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
z_ranges = z_ranges.reshape(-1, 1, 1).expand(
-1, tpv_h, num_points_in_pillar[1]).flatten(0, 1)
zh_wz = torch.stack([z_ranges, w_ranges], dim=-1)
# generate points for wz and level 1
h_ranges = torch.linspace(0.5, tpv_h - 0.5,
num_points_in_pillar[0]) / tpv_h
h_ranges = h_ranges.unsqueeze(0).expand(tpv_w * tpv_z, -1)
w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
w_ranges = w_ranges.reshape(-1, 1, 1).expand(
-1, tpv_z, num_points_in_pillar[0]).flatten(0, 1)
wz_hw = torch.stack([w_ranges, h_ranges], dim=-1)
# generate points for wz and level 2
h_ranges = torch.linspace(0.5, tpv_h - 0.5,
num_points_in_pillar[0]) / tpv_h
h_ranges = h_ranges.unsqueeze(0).expand(tpv_w * tpv_z, -1)
z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
z_ranges = z_ranges.reshape(1, -1, 1).expand(
tpv_w, -1, num_points_in_pillar[0]).flatten(0, 1)
wz_zh = torch.stack([h_ranges, z_ranges], dim=-1)
# generate points for wz and level 3
w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w
w_ranges = w_ranges.reshape(-1, 1, 1).expand(
-1, tpv_z, num_points_in_pillar[0]).flatten(0, 1)
z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z
z_ranges = z_ranges.reshape(1, -1, 1).expand(
tpv_w, -1, num_points_in_pillar[0]).flatten(0, 1)
wz_wz = torch.stack([z_ranges, w_ranges], dim=-1)
reference_points = torch.cat([
torch.stack([hw_hw, hw_zh, hw_wz], dim=1),
torch.stack([zh_hw, zh_zh, zh_wz], dim=1),
torch.stack([wz_hw, wz_zh, wz_wz], dim=1)
],
dim=0) # hw+zh+wz, 3, #p, 2
return reference_points
@staticmethod
def get_reference_points(H,
W,
Z=8,
num_points_in_pillar=4,
dim='3d',
bs=1,
device='cuda',
dtype=torch.float):
"""Get the reference points used in SCA and TSA.
Args:
H, W: spatial shape of tpv.
Z: height of pillar.
device (obj:`device`): The device where
reference_points should be.
Returns:
Tensor: reference points used in decoder, has \
shape (bs, num_keys, num_levels, 2).
"""
# reference points in 3D space, used in spatial cross-attention (SCA)
zs = torch.linspace(
0.5, Z - 0.5, num_points_in_pillar,
dtype=dtype, device=device).view(-1, 1, 1).expand(
num_points_in_pillar, H, W) / Z
xs = torch.linspace(
0.5, W - 0.5, W, dtype=dtype, device=device).view(1, 1, -1).expand(
num_points_in_pillar, H, W) / W
ys = torch.linspace(
0.5, H - 0.5, H, dtype=dtype, device=device).view(1, -1, 1).expand(
num_points_in_pillar, H, W) / H
ref_3d = torch.stack((xs, ys, zs), -1)
ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
return ref_3d
    def point_sampling(self, reference_points, pc_range, batch_data_samples):
        lidar2img = []
        for data_sample in batch_data_samples:
lidar2img.append(data_sample.lidar2img)
lidar2img = np.asarray(lidar2img)
lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4)
reference_points = reference_points.clone()
reference_points[..., 0:1] = reference_points[..., 0:1] * \
(pc_range[3] - pc_range[0]) + pc_range[0]
reference_points[..., 1:2] = reference_points[..., 1:2] * \
(pc_range[4] - pc_range[1]) + pc_range[1]
reference_points[..., 2:3] = reference_points[..., 2:3] * \
(pc_range[5] - pc_range[2]) + pc_range[2]
reference_points = torch.cat(
(reference_points, torch.ones_like(reference_points[..., :1])), -1)
reference_points = reference_points.permute(1, 0, 2, 3)
D, B, num_query = reference_points.size()[:3]
num_cam = lidar2img.size(1)
reference_points = reference_points.view(D, B, 1, num_query, 4).repeat(
1, 1, num_cam, 1, 1).unsqueeze(-1)
lidar2img = lidar2img.view(1, B, num_cam, 1, 4,
4).repeat(D, 1, 1, num_query, 1, 1)
reference_points_cam = torch.matmul(
lidar2img.to(torch.float32),
reference_points.to(torch.float32)).squeeze(-1)
eps = 1e-5
tpv_mask = (reference_points_cam[..., 2:3] > eps)
reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
reference_points_cam[..., 2:3],
torch.ones_like(reference_points_cam[..., 2:3]) * eps)
        # normalize by the padded input shape (assumed to be identical for
        # all samples in the batch, hence taken from the last data sample)
        reference_points_cam[..., 0] /= data_sample.batch_input_shape[1]
        reference_points_cam[..., 1] /= data_sample.batch_input_shape[0]
tpv_mask = (
tpv_mask & (reference_points_cam[..., 1:2] > 0.0)
& (reference_points_cam[..., 1:2] < 1.0)
& (reference_points_cam[..., 0:1] < 1.0)
& (reference_points_cam[..., 0:1] > 0.0))
tpv_mask = torch.nan_to_num(tpv_mask)
reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
tpv_mask = tpv_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
return reference_points_cam, tpv_mask
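    # Shape note (added for illustration): each normalized reference point is
    # rescaled to metric coordinates with `pc_range`, lifted to homogeneous
    # form [x, y, z, 1] and projected with `lidar2img`; the pixel is
    # (u / d, v / d) for [u, v, d, _] = lidar2img @ [x, y, z, 1].
    # `reference_points_cam` has shape (num_cam, bs, num_query, D, 2), with
    # coordinates normalized by the padded input shape, and `tpv_mask` has
    # shape (num_cam, bs, num_query, D), marking points that lie in front of
    # a camera (depth > eps) and inside its image (D = points per pillar).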
def forward(self, mlvl_feats, batch_data_samples):
"""Forward function.
Args:
mlvl_feats (tuple[Tensor]): Features from the upstream
network, each is a 5D-tensor with shape
(B, N, C, H, W).
"""
bs = mlvl_feats[0].shape[0]
dtype = mlvl_feats[0].dtype
device = mlvl_feats[0].device
# tpv queries and pos embeds
tpv_queries_hw = self.tpv_embedding_hw.weight.to(dtype)
tpv_queries_zh = self.tpv_embedding_zh.weight.to(dtype)
tpv_queries_wz = self.tpv_embedding_wz.weight.to(dtype)
tpv_queries_hw = tpv_queries_hw.unsqueeze(0).repeat(bs, 1, 1)
tpv_queries_zh = tpv_queries_zh.unsqueeze(0).repeat(bs, 1, 1)
tpv_queries_wz = tpv_queries_wz.unsqueeze(0).repeat(bs, 1, 1)
tpv_query = [tpv_queries_hw, tpv_queries_zh, tpv_queries_wz]
tpv_pos_hw = self.positional_encoding(bs, device, 'z')
tpv_pos_zh = self.positional_encoding(bs, device, 'w')
tpv_pos_wz = self.positional_encoding(bs, device, 'h')
tpv_pos = [tpv_pos_hw, tpv_pos_zh, tpv_pos_wz]
# flatten image features of different scales
feat_flatten = []
spatial_shapes = []
for lvl, feat in enumerate(mlvl_feats):
bs, num_cam, c, h, w = feat.shape
spatial_shape = (h, w)
feat = feat.flatten(3).permute(1, 0, 3, 2) # num_cam, bs, hw, c
feat = feat + self.cams_embeds[:, None, None, :].to(dtype)
feat = feat + self.level_embeds[None, None,
lvl:lvl + 1, :].to(dtype)
spatial_shapes.append(spatial_shape)
feat_flatten.append(feat)
feat_flatten = torch.cat(feat_flatten, 2) # num_cam, bs, hw++, c
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=device)
level_start_index = torch.cat((spatial_shapes.new_zeros(
(1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
feat_flatten = feat_flatten.permute(
0, 2, 1, 3) # (num_cam, H*W, bs, embed_dims)
reference_points_cams, tpv_masks = [], []
ref_3ds = [self.ref_3d_hw, self.ref_3d_zh, self.ref_3d_wz]
for ref_3d in ref_3ds:
reference_points_cam, tpv_mask = self.point_sampling(
ref_3d, self.pc_range,
batch_data_samples) # num_cam, bs, hw++, #p, 2
reference_points_cams.append(reference_points_cam)
tpv_masks.append(tpv_mask)
ref_cross_view = self.cross_view_ref_points.clone().unsqueeze(
0).expand(bs, -1, -1, -1, -1)
intermediate = []
for layer in self.layers:
output = layer(
tpv_query,
feat_flatten,
feat_flatten,
tpv_pos=tpv_pos,
ref_2d=ref_cross_view,
tpv_h=self.tpv_h,
tpv_w=self.tpv_w,
tpv_z=self.tpv_z,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
reference_points_cams=reference_points_cams,
tpv_masks=tpv_masks)
tpv_query = output
if self.return_intermediate:
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output
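# Output note (added for illustration): each `TPVFormerLayer` returns the
# three plane queries as a tuple, so `output` is (hw_feat, zh_feat, wz_feat)
# with shapes (bs, tpv_h*tpv_w, C), (bs, tpv_z*tpv_h, C) and
# (bs, tpv_w*tpv_z, C), matching the `tpv_list` layout expected by
# `TPVFormerDecoder` below.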
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from mmengine.model import BaseModule
from mmdet3d.registry import MODELS
@MODELS.register_module()
class TPVFormerDecoder(BaseModule):
def __init__(self,
tpv_h,
tpv_w,
tpv_z,
num_classes=20,
in_dims=64,
hidden_dims=128,
out_dims=None,
scale_h=2,
scale_w=2,
scale_z=2,
ignore_index=0,
loss_lovasz=None,
loss_ce=None,
                 lovasz_input='points',
                 ce_input='voxel',
                 use_checkpoint=False):
super().__init__()
self.tpv_h = tpv_h
self.tpv_w = tpv_w
self.tpv_z = tpv_z
self.scale_h = scale_h
self.scale_w = scale_w
self.scale_z = scale_z
out_dims = in_dims if out_dims is None else out_dims
self.in_dims = in_dims
self.decoder = nn.Sequential(
nn.Linear(in_dims, hidden_dims), nn.Softplus(),
nn.Linear(hidden_dims, out_dims))
        self.classifier = nn.Linear(out_dims, num_classes)
        # used by forward() when reshaping the logits and (optionally)
        # gradient-checkpointing the decoder
        self.classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.loss_lovasz = MODELS.build(loss_lovasz)
        self.loss_ce = MODELS.build(loss_ce)
        self.ignore_index = ignore_index
        self.lovasz_input = lovasz_input
        self.ce_input = ce_input
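    # Shape sketch (illustrative, not from the original code): for a fused
    # feature tensor `x` of shape (N, in_dims), `self.decoder(x)` is a
    # Linear-Softplus-Linear MLP producing (N, out_dims), and
    # `self.classifier(self.decoder(x))` yields logits of shape
    # (N, num_classes).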
    def forward(self, tpv_list, points=None):
        """Fuse the three tpv planes into voxel (and optional point) logits.

        Args:
            tpv_list (list[Tensor]): Plane features with shapes
                (bs, h*w, c), (bs, z*h, c) and (bs, w*z, c).
            points (Tensor, optional): Point coordinates in the scaled tpv
                grid, with shape (bs, n, 3). Defaults to None.
        """
tpv_hw, tpv_zh, tpv_wz = tpv_list[0], tpv_list[1], tpv_list[2]
bs, _, c = tpv_hw.shape
tpv_hw = tpv_hw.permute(0, 2, 1).reshape(bs, c, self.tpv_h, self.tpv_w)
tpv_zh = tpv_zh.permute(0, 2, 1).reshape(bs, c, self.tpv_z, self.tpv_h)
tpv_wz = tpv_wz.permute(0, 2, 1).reshape(bs, c, self.tpv_w, self.tpv_z)
if self.scale_h != 1 or self.scale_w != 1:
tpv_hw = F.interpolate(
tpv_hw,
size=(self.tpv_h * self.scale_h, self.tpv_w * self.scale_w),
mode='bilinear')
if self.scale_z != 1 or self.scale_h != 1:
tpv_zh = F.interpolate(
tpv_zh,
size=(self.tpv_z * self.scale_z, self.tpv_h * self.scale_h),
mode='bilinear')
if self.scale_w != 1 or self.scale_z != 1:
tpv_wz = F.interpolate(
tpv_wz,
size=(self.tpv_w * self.scale_w, self.tpv_z * self.scale_z),
mode='bilinear')
if points is not None:
# points: bs, n, 3
_, n, _ = points.shape
points = points.reshape(bs, 1, n, 3).float()
points[...,
0] = points[..., 0] / (self.tpv_w * self.scale_w) * 2 - 1
points[...,
1] = points[..., 1] / (self.tpv_h * self.scale_h) * 2 - 1
points[...,
2] = points[..., 2] / (self.tpv_z * self.scale_z) * 2 - 1
sample_loc = points[:, :, :, [0, 1]]
tpv_hw_pts = F.grid_sample(tpv_hw,
sample_loc).squeeze(2) # bs, c, n
sample_loc = points[:, :, :, [1, 2]]
tpv_zh_pts = F.grid_sample(tpv_zh, sample_loc).squeeze(2)
sample_loc = points[:, :, :, [2, 0]]
tpv_wz_pts = F.grid_sample(tpv_wz, sample_loc).squeeze(2)
tpv_hw_vox = tpv_hw.unsqueeze(-1).permute(0, 1, 3, 2, 4).expand(
-1, -1, -1, -1, self.scale_z * self.tpv_z)
tpv_zh_vox = tpv_zh.unsqueeze(-1).permute(0, 1, 4, 3, 2).expand(
-1, -1, self.scale_w * self.tpv_w, -1, -1)
tpv_wz_vox = tpv_wz.unsqueeze(-1).permute(0, 1, 2, 4, 3).expand(
-1, -1, -1, self.scale_h * self.tpv_h, -1)
fused_vox = (tpv_hw_vox + tpv_zh_vox + tpv_wz_vox).flatten(2)
fused_pts = tpv_hw_pts + tpv_zh_pts + tpv_wz_pts
fused = torch.cat([fused_vox, fused_pts], dim=-1) # bs, c, whz+n
fused = fused.permute(0, 2, 1)
if self.use_checkpoint:
fused = torch.utils.checkpoint.checkpoint(self.decoder, fused)
logits = torch.utils.checkpoint.checkpoint(
self.classifier, fused)
else:
fused = self.decoder(fused)
logits = self.classifier(fused)
logits = logits.permute(0, 2, 1)
logits_vox = logits[:, :, :(-n)].reshape(bs, self.classes,
self.scale_w * self.tpv_w,
self.scale_h * self.tpv_h,
self.scale_z * self.tpv_z)
logits_pts = logits[:, :, (-n):].reshape(bs, self.classes, n, 1, 1)
return logits_vox, logits_pts
else:
tpv_hw = tpv_hw.unsqueeze(-1).permute(0, 1, 3, 2, 4).expand(
-1, -1, -1, -1, self.scale_z * self.tpv_z)
tpv_zh = tpv_zh.unsqueeze(-1).permute(0, 1, 4, 3, 2).expand(
-1, -1, self.scale_w * self.tpv_w, -1, -1)
tpv_wz = tpv_wz.unsqueeze(-1).permute(0, 1, 2, 4, 3).expand(
-1, -1, -1, self.scale_h * self.tpv_h, -1)
fused = tpv_hw + tpv_zh + tpv_wz
fused = fused.permute(0, 2, 3, 4, 1)
if self.use_checkpoint:
fused = torch.utils.checkpoint.checkpoint(self.decoder, fused)
logits = torch.utils.checkpoint.checkpoint(
self.classifier, fused)
else:
fused = self.decoder(fused)
logits = self.classifier(fused)
logits = logits.permute(0, 4, 1, 2, 3)
return logits
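    # Aggregation note (added for illustration): the fused voxel grid is
    # indexed as (w, h, z) and its feature is
    # tpv_hw[..., h, w] + tpv_zh[..., z, h] + tpv_wz[..., w, z], i.e. the sum
    # of the three plane features. Point features are gathered with
    # F.grid_sample, whose grid expects (x, y) in [-1, 1] with x indexing the
    # last (width) axis of the sampled feature map.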
    def predict(self, tpv_list, batch_data_samples):
        """Sample per-point features from the three planes and classify them.

        Args:
            tpv_list (list[Tensor]): Plane features with shapes
                (bs, h*w, c), (bs, z*h, c) and (bs, w*z, c).

        Returns:
            list[Tensor]: Per-sample point logits, each of shape
                (num_points, num_classes).
        """
tpv_hw, tpv_zh, tpv_wz = tpv_list
bs, _, c = tpv_hw.shape
tpv_hw = tpv_hw.permute(0, 2, 1).reshape(bs, c, self.tpv_h, self.tpv_w)
tpv_zh = tpv_zh.permute(0, 2, 1).reshape(bs, c, self.tpv_z, self.tpv_h)
tpv_wz = tpv_wz.permute(0, 2, 1).reshape(bs, c, self.tpv_w, self.tpv_z)
if self.scale_h != 1 or self.scale_w != 1:
tpv_hw = F.interpolate(
tpv_hw,
size=(self.tpv_h * self.scale_h, self.tpv_w * self.scale_w),
mode='bilinear')
if self.scale_z != 1 or self.scale_h != 1:
tpv_zh = F.interpolate(
tpv_zh,
size=(self.tpv_z * self.scale_z, self.tpv_h * self.scale_h),
mode='bilinear')
if self.scale_w != 1 or self.scale_z != 1:
tpv_wz = F.interpolate(
tpv_wz,
size=(self.tpv_w * self.scale_w, self.tpv_z * self.scale_z),
mode='bilinear')
logits = []
for i, data_sample in enumerate(batch_data_samples):
point_coors = data_sample.point_coors.reshape(1, 1, -1, 3).float()
point_coors[
...,
0] = point_coors[..., 0] / (self.tpv_w * self.scale_w) * 2 - 1
point_coors[
...,
1] = point_coors[..., 1] / (self.tpv_h * self.scale_h) * 2 - 1
point_coors[
...,
2] = point_coors[..., 2] / (self.tpv_z * self.scale_z) * 2 - 1
sample_loc = point_coors[..., [0, 1]]
tpv_hw_pts = F.grid_sample(
tpv_hw[i:i + 1], sample_loc, align_corners=False)
sample_loc = point_coors[..., [1, 2]]
tpv_zh_pts = F.grid_sample(
tpv_zh[i:i + 1], sample_loc, align_corners=False)
sample_loc = point_coors[..., [2, 0]]
tpv_wz_pts = F.grid_sample(
tpv_wz[i:i + 1], sample_loc, align_corners=False)
fused_pts = tpv_hw_pts + tpv_zh_pts + tpv_wz_pts
fused_pts = fused_pts.squeeze(0).squeeze(1).transpose(0, 1)
fused_pts = self.decoder(fused_pts)
logit = self.classifier(fused_pts)
logits.append(logit)
return logits
    def loss(self, tpv_list, batch_data_samples):
        """Compute cross-entropy and Lovasz losses from the tpv planes."""
        tpv_hw, tpv_zh, tpv_wz = tpv_list
bs, _, c = tpv_hw.shape
tpv_hw = tpv_hw.permute(0, 2, 1).reshape(bs, c, self.tpv_h, self.tpv_w)
tpv_zh = tpv_zh.permute(0, 2, 1).reshape(bs, c, self.tpv_z, self.tpv_h)
tpv_wz = tpv_wz.permute(0, 2, 1).reshape(bs, c, self.tpv_w, self.tpv_z)
if self.scale_h != 1 or self.scale_w != 1:
tpv_hw = F.interpolate(
tpv_hw,
size=(self.tpv_h * self.scale_h, self.tpv_w * self.scale_w),
mode='bilinear')
if self.scale_z != 1 or self.scale_h != 1:
tpv_zh = F.interpolate(
tpv_zh,
size=(self.tpv_z * self.scale_z, self.tpv_h * self.scale_h),
mode='bilinear')
if self.scale_w != 1 or self.scale_z != 1:
tpv_wz = F.interpolate(
tpv_wz,
size=(self.tpv_w * self.scale_w, self.tpv_z * self.scale_z),
mode='bilinear')
batch_pts, batch_vox = [], []
for i, data_sample in enumerate(batch_data_samples):
point_coors = data_sample.point_coors.reshape(1, 1, -1, 3).float()
point_coors[
...,
0] = point_coors[..., 0] / (self.tpv_w * self.scale_w) * 2 - 1
point_coors[
...,
1] = point_coors[..., 1] / (self.tpv_h * self.scale_h) * 2 - 1
point_coors[
...,
2] = point_coors[..., 2] / (self.tpv_z * self.scale_z) * 2 - 1
sample_loc = point_coors[..., [0, 1]]
tpv_hw_pts = F.grid_sample(
tpv_hw[i:i + 1], sample_loc, align_corners=False)
sample_loc = point_coors[..., [1, 2]]
tpv_zh_pts = F.grid_sample(
tpv_zh[i:i + 1], sample_loc, align_corners=False)
sample_loc = point_coors[..., [2, 0]]
tpv_wz_pts = F.grid_sample(
tpv_wz[i:i + 1], sample_loc, align_corners=False)
fused_pts = (tpv_hw_pts + tpv_zh_pts +
tpv_wz_pts).squeeze(0).squeeze(1)
batch_pts.append(fused_pts)
tpv_hw_vox = tpv_hw.unsqueeze(-1).permute(0, 1, 3, 2, 4).expand(
-1, -1, -1, -1, self.scale_z * self.tpv_z)
tpv_zh_vox = tpv_zh.unsqueeze(-1).permute(0, 1, 4, 3, 2).expand(
-1, -1, self.scale_w * self.tpv_w, -1, -1)
tpv_wz_vox = tpv_wz.unsqueeze(-1).permute(0, 1, 2, 4, 3).expand(
-1, -1, -1, self.scale_h * self.tpv_h, -1)
fused_vox = tpv_hw_vox + tpv_zh_vox + tpv_wz_vox
voxel_coors = data_sample.voxel_coors.long()
fused_vox = fused_vox[:, :, voxel_coors[:, 0], voxel_coors[:, 1],
voxel_coors[:, 2]]
fused_vox = fused_vox.squeeze(0)
batch_vox.append(fused_vox)
batch_pts = torch.cat(batch_pts, dim=1)
batch_vox = torch.cat(batch_vox, dim=1)
num_points = batch_pts.shape[1]
logits = self.decoder(
torch.cat([batch_pts, batch_vox], dim=1).transpose(0, 1))
logits = self.classifier(logits)
pts_logits = logits[:num_points, :]
vox_logits = logits[num_points:, :]
pts_seg_label = torch.cat([
data_sample.gt_pts_seg.pts_semantic_mask
for data_sample in batch_data_samples
])
voxel_seg_label = torch.cat([
data_sample.gt_pts_seg.voxel_semantic_mask
for data_sample in batch_data_samples
])
if self.ce_input == 'voxel':
ce_input = vox_logits
ce_label = voxel_seg_label
else:
ce_input = pts_logits
ce_label = pts_seg_label
if self.lovasz_input == 'voxel':
lovasz_input = vox_logits
lovasz_label = voxel_seg_label
else:
lovasz_input = pts_logits
lovasz_label = pts_seg_label
loss = dict()
loss['loss_ce'] = self.loss_ce(
ce_input, ce_label, ignore_index=self.ignore_index)
loss['loss_lovasz'] = self.loss_lovasz(
lovasz_input, lovasz_label, ignore_index=self.ignore_index)
return loss
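# Loss routing note (added for illustration): with the defaults
# ce_input='voxel' and lovasz_input='points', the cross-entropy loss is
# computed on the voxel logits against `voxel_semantic_mask`, while the
# Lovasz loss is computed on the point logits against `pts_semantic_mask`;
# both losses receive `ignore_index=self.ignore_index`.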
import copy
import warnings
import torch
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import (build_attention,
build_feedforward_network)
from mmengine.config import ConfigDict
from mmengine.model import BaseModule, ModuleList
from mmengine.registry import MODELS
@MODELS.register_module()
class TPVFormerLayer(BaseModule):
"""Base `TPVFormerLayer` for vision transformer.
It can be built from `mmcv.ConfigDict` and support more flexible
customization, for example, using any number of `FFN or LN ` and
use different kinds of `attention` by specifying a list of `ConfigDict`
named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
when you specifying `norm` as the first element of `operation_order`.
More details about the `prenorm`: `On Layer Normalization in the
Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
Args:
attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
Configs for `self_attention` or `cross_attention` modules,
The order of the configs in the list should be consistent with
corresponding attentions in operation_order.
If it is a dict, all of the attention modules in operation_order
will be built with this config. Default: None.
ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
Configs for FFN, The order of the configs in the list should be
consistent with corresponding ffn in operation_order.
If it is a dict, all of the attention modules in operation_order
will be built with this config.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
Support `prenorm` when you specifying first element as `norm`.
Default: None.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
batch_first (bool): Key, Query and Value are shape
of (batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
"""
def __init__(self,
attn_cfgs=None,
ffn_cfgs=dict(
type='FFN',
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.,
act_cfg=dict(type='ReLU', inplace=True),
),
operation_order=None,
norm_cfg=dict(type='LN'),
init_cfg=None,
batch_first=True,
**kwargs):
deprecated_args = dict(
feedforward_channels='feedforward_channels',
ffn_dropout='ffn_drop',
ffn_num_fcs='num_fcs')
for ori_name, new_name in deprecated_args.items():
if ori_name in kwargs:
warnings.warn(
                    f'The argument `{ori_name}` in TPVFormerLayer '
                    f'has been deprecated; please set `{new_name}` '
                    f'and other FFN related arguments '
                    f'in a dict named `ffn_cfgs`.')
ffn_cfgs[new_name] = kwargs[ori_name]
super().__init__(init_cfg)
self.batch_first = batch_first
num_attn = operation_order.count('self_attn') + operation_order.count(
'cross_attn')
if isinstance(attn_cfgs, dict):
attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
else:
            assert num_attn == len(attn_cfgs), \
                f'The length of attn_cfgs {len(attn_cfgs)} is not ' \
                f'consistent with the number of attentions {num_attn} ' \
                f'in operation_order {operation_order}.'
self.num_attn = num_attn
self.operation_order = operation_order
self.norm_cfg = norm_cfg
self.pre_norm = operation_order[0] == 'norm'
self.attentions = ModuleList()
index = 0
for operation_name in operation_order:
if operation_name in ['self_attn', 'cross_attn']:
if 'batch_first' in attn_cfgs[index]:
assert self.batch_first == attn_cfgs[index]['batch_first']
else:
attn_cfgs[index]['batch_first'] = self.batch_first
attention = build_attention(attn_cfgs[index])
# Some custom attentions used as `self_attn`
# or `cross_attn` can have different behavior.
attention.operation_name = operation_name
self.attentions.append(attention)
index += 1
self.embed_dims = self.attentions[0].embed_dims
self.ffns = ModuleList()
num_ffns = operation_order.count('ffn')
if isinstance(ffn_cfgs, dict):
ffn_cfgs = ConfigDict(ffn_cfgs)
if isinstance(ffn_cfgs, dict):
ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
assert len(ffn_cfgs) == num_ffns
for ffn_index in range(num_ffns):
if 'embed_dims' not in ffn_cfgs[ffn_index]:
ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
else:
assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
self.ffns.append(build_feedforward_network(ffn_cfgs[ffn_index]))
self.norms = ModuleList()
num_norms = operation_order.count('norm')
for _ in range(num_norms):
self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
def forward(self,
query,
key=None,
value=None,
tpv_pos=None,
ref_2d=None,
tpv_h=None,
tpv_w=None,
tpv_z=None,
reference_points_cams=None,
tpv_masks=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): The input query with shape
[num_queries, bs, embed_dims] if
self.batch_first is False, else
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
value (Tensor): The value tensor with same shape as `key`.
tpv_pos (Tensor): The positional encoding for self attn.
Returns:
Tensor: forwarded results with shape
[[bs, num_queries, embed_dims] * 3] for 3 tpv planes.
"""
norm_index = 0
attn_index = 0
ffn_index = 0
if self.operation_order[0] == 'cross_attn':
query = torch.cat(query, dim=1)
identity = query
for layer in self.operation_order:
# cross view hybrid-attention
if layer == 'self_attn':
ss = torch.tensor(
[[tpv_h, tpv_w], [tpv_z, tpv_h], [tpv_w, tpv_z]],
device=query[0].device)
lsi = torch.tensor(
[0, tpv_h * tpv_w, tpv_h * tpv_w + tpv_z * tpv_h],
device=query[0].device)
if not isinstance(query, (list, tuple)):
query = torch.split(
query, [tpv_h * tpv_w, tpv_z * tpv_h, tpv_w * tpv_z],
dim=1)
query = self.attentions[attn_index](
query,
identity if self.pre_norm else None,
query_pos=tpv_pos,
reference_points=ref_2d,
spatial_shapes=ss,
level_start_index=lsi,
**kwargs)
attn_index += 1
query = torch.cat(query, dim=1)
identity = query
elif layer == 'norm':
query = self.norms[norm_index](query)
norm_index += 1
# image cross attention
elif layer == 'cross_attn':
query = self.attentions[attn_index](
query,
key,
value,
identity if self.pre_norm else None,
reference_points_cams=reference_points_cams,
tpv_masks=tpv_masks,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
**kwargs)
attn_index += 1
identity = query
elif layer == 'ffn':
query = self.ffns[ffn_index](
query, identity if self.pre_norm else None)
ffn_index += 1
query = torch.split(
query, [tpv_h * tpv_w, tpv_z * tpv_h, tpv_w * tpv_z], dim=1)
return query
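    # Flow note (added for illustration): the three plane queries are
    # concatenated along the query dimension for 'norm', 'ffn' and
    # 'cross_attn', and split back into chunks of size
    # [tpv_h*tpv_w, tpv_z*tpv_h, tpv_w*tpv_z] for 'self_attn' (the
    # cross-view hybrid attention) and for the final return value.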
_base_ = ['mmdet3d::_base_/default_runtime.py'] _base_ = ['../../../configs/_base_/default_runtime.py']
custom_imports = dict(imports=['projects.TR3D.tr3d']) custom_imports = dict(imports=['projects.TR3D.tr3d'])
model = dict( model = dict(
......
asynctest
codecov codecov
flake8 flake8
interrogate interrogate
......