Unverified Commit 762e3b53 authored by Sun Jiahao, committed by GitHub

[Feature] Support DSVT training (#2738)


Co-authored-by: JingweiZhang12 <zjw18@mails.tsinghua.edu.cn>
Co-authored-by: sjh <sunjiahao1999>
parent 5b88c7b8
......@@ -101,7 +101,7 @@ class SeparateHead(BaseModule):
Returns:
dict[str: torch.Tensor]: contains the following keys:
-reg torch.Tensor): 2D regression value with the
-reg (torch.Tensor): 2D regression value with the
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the
shape of [B, 1, H, W].
......@@ -217,7 +217,7 @@ class DCNSeparateHead(BaseModule):
Returns:
dict[str: torch.Tensor]: contains the following keys:
-reg torch.Tensor): 2D regression value with the
-reg (torch.Tensor): 2D regression value with the
shape of [B, 2, H, W].
-height (torch.Tensor): Height value with the
shape of [B, 1, H, W].
......
......@@ -21,6 +21,10 @@ class SECONDFPN(BaseModule):
upsample_cfg (dict): Config dict of upsample layers.
conv_cfg (dict): Config dict of conv layers.
use_conv_for_no_stride (bool): Whether to use conv when stride is 1.
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
optional): Initialization config dict. Defaults to
[dict(type='Kaiming', layer='ConvTranspose2d'),
dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)].
"""
def __init__(self,
......@@ -31,7 +35,13 @@ class SECONDFPN(BaseModule):
upsample_cfg=dict(type='deconv', bias=False),
conv_cfg=dict(type='Conv2d', bias=False),
use_conv_for_no_stride=False,
init_cfg=None):
init_cfg=[
dict(type='Kaiming', layer='ConvTranspose2d'),
dict(
type='Constant',
layer='NaiveSyncBatchNorm2d',
val=1.0)
]):
# for GroupNorm, use
# cfg = dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
super(SECONDFPN, self).__init__(init_cfg=init_cfg)
......@@ -64,12 +74,6 @@ class SECONDFPN(BaseModule):
deblocks.append(deblock)
self.deblocks = nn.ModuleList(deblocks)
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='ConvTranspose2d'),
dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)
]
def forward(self, x):
"""Forward function.
......
......@@ -275,12 +275,13 @@ class BaseInstance3DBoxes:
Tensor: A binary vector indicating whether each box is inside the
reference range.
"""
in_range_flags = ((self.tensor[:, 0] > box_range[0])
& (self.tensor[:, 1] > box_range[1])
& (self.tensor[:, 2] > box_range[2])
& (self.tensor[:, 0] < box_range[3])
& (self.tensor[:, 1] < box_range[4])
& (self.tensor[:, 2] < box_range[5]))
gravity_center = self.gravity_center
in_range_flags = ((gravity_center[:, 0] > box_range[0])
& (gravity_center[:, 1] > box_range[1])
& (gravity_center[:, 2] > box_range[2])
& (gravity_center[:, 0] < box_range[3])
& (gravity_center[:, 1] < box_range[4])
& (gravity_center[:, 2] < box_range[5]))
return in_range_flags
@abstractmethod
......
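This change makes `in_range_3d` test each box's gravity center instead of the raw tensor's bottom center (LiDAR boxes store `z` at the bottom face). A minimal sketch of the difference, assuming `LiDARInstance3DBoxes` with its default bottom-center origin:

```python
import torch

from mmdet3d.structures import LiDARInstance3DBoxes

# one box: bottom-center z = -0.1, height 1.5 -> gravity-center z = 0.65
boxes = LiDARInstance3DBoxes(torch.tensor([[5., 5., -0.1, 4., 2., 1.5, 0.]]))
print(boxes.tensor[:, 2])          # tensor([-0.1000]), bottom z
print(boxes.gravity_center[:, 2])  # tensor([0.6500])

# with z_min = 0, the old bottom-center check would fail (-0.1 > 0 is False),
# while the patched gravity-center check passes
print(boxes.in_range_3d([0., 0., 0., 10., 10., 4.]))  # tensor([True])
```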
......@@ -57,17 +57,25 @@ python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-
### Training commands
The support of training DSVT is on the way.
In MMDetection3D's root directory, run the following command to train the model:
```bash
tools/dist_train.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py 8 --sync_bn torch
```
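Note that `--sync_bn torch` converts all BatchNorm layers in the model to SyncBatchNorm; the flag itself is added to `tools/train.py` by this PR (see the diff near the end of this commit).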
## Results and models
### Waymo
| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | | | 75.2 | 72.2 | 68.9 | 66.1 | |
| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :----: | :-----: | :----: | :---------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | 75.5 | 72.4 | 69.2 | 66.3 | [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/dsvt/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class_20230917_102130.log) |
**Note**:
- `ResSECOND` denotes that the base block in SECOND contains residual layers.
**Note** that `ResSECOND` denotes that the base block in SECOND contains residual layers.
- Regrettably, we are unable to provide the pre-trained model weights due to [Waymo Dataset License Agreement](https://waymo.com/open/terms/), so we only provide the training logs as shown above.
## Citation
......
......@@ -88,25 +88,28 @@ model = dict(
loss_cls=dict(
type='mmdet.GaussianFocalLoss', reduction='mean', loss_weight=1.0),
loss_bbox=dict(type='mmdet.L1Loss', reduction='mean', loss_weight=2.0),
loss_iou=dict(type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
loss_reg_iou=dict(
type='mmdet3d.DIoU3DLoss', reduction='mean', loss_weight=2.0),
norm_bbox=True),
# model training and testing settings
train_cfg=dict(
pts=dict(
grid_size=grid_size,
voxel_size=voxel_size,
out_size_factor=4,
point_cloud_range=point_cloud_range,
out_size_factor=1,
dense_reg=1,
gaussian_overlap=0.1,
max_objs=500,
min_radius=2,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
test_cfg=dict(
max_per_img=500,
max_pool_nms=False,
min_radius=[4, 12, 10, 1, 0.85, 0.175],
iou_rectifier=[[0.68, 0.71, 0.65]],
pc_range=[-80, -80],
out_size_factor=4,
out_size_factor=1,
voxel_size=voxel_size[:2],
nms_type='rotate',
multi_class_nms=True,
......@@ -128,6 +131,8 @@ db_sampler = dict(
coord_type='LIDAR',
load_dim=6,
use_dim=[0, 1, 2, 3, 4],
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
backend_args=backend_args)
......@@ -138,25 +143,22 @@ train_pipeline = [
load_dim=6,
use_dim=5,
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
# Add this if using `MultiFrameDeformableDecoderRPN`
# dict(
# type='LoadPointsFromMultiSweeps',
# sweeps_num=9,
# load_dim=6,
# use_dim=[0, 1, 2, 3, 4],
# pad_empty_sweeps=True,
# remove_close=True),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.5, 0.5, 0]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names),
translation_std=[0.5, 0.5, 0.5]),
dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter3D', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
......@@ -172,25 +174,34 @@ test_pipeline = [
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
type='Pack3DDetInputs',
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
]
dataset_type = 'WaymoDataset'
train_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='waymo_infos_train.pkl',
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
metainfo=metainfo,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR',
# load one frame every five frames
load_interval=5,
backend_args=backend_args))
val_dataloader = dict(
batch_size=4,
num_workers=4,
......@@ -212,18 +223,59 @@ test_dataloader = val_dataloader
val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
backend_args=backend_args,
convert_kitti_format=False,
idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
result_prefix='./dsvt_pred')
test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
# schedules
lr = 1e-5
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=lr, weight_decay=0.05, betas=(0.9, 0.99)),
clip_grad=dict(max_norm=10, norm_type=2))
param_scheduler = [
dict(
type='CosineAnnealingLR',
T_max=1.2,
eta_min=lr * 100,
begin=0,
end=1.2,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=10.8,
eta_min=lr * 1e-4,
begin=1.2,
end=12,
by_epoch=True,
convert_to_iter_based=True),
# momentum scheduler
dict(
type='CosineAnnealingMomentum',
T_max=1.2,
eta_min=0.85,
begin=0,
end=1.2,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingMomentum',
T_max=10.8,
eta_min=0.95,
begin=1.2,
end=12,
by_epoch=True,
convert_to_iter_based=True)
]
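For intuition, the two cosine phases form a one-cycle-style curve: the learning rate rises from `lr` toward `lr * 100` over the first 1.2 epochs (note `eta_min > lr`, so the "annealing" goes up), then decays to `lr * 1e-4` by epoch 12. A rough sketch, assuming the standard closed-form cosine and ignoring the per-iteration conversion:

```python
import math

lr = 1e-5

def cosine(start, eta_min, t, period):
    # closed-form cosine annealing from `start` to `eta_min` over `period`
    return eta_min + 0.5 * (start - eta_min) * (1 + math.cos(math.pi * t / period))

for epoch in [0.0, 0.6, 1.2, 6.0, 12.0]:
    if epoch <= 1.2:
        cur = cosine(lr, lr * 100, epoch, 1.2)
    else:
        cur = cosine(lr * 100, lr * 1e-4, epoch - 1.2, 10.8)
    print(f'epoch {epoch:>4}: lr ~ {cur:.2e}')
```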
# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=12, val_interval=1)
# runtime settings
val_cfg = dict()
test_cfg = dict()
......@@ -236,4 +288,12 @@ test_cfg = dict()
default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
checkpoint=dict(type='CheckpointHook', interval=5))
checkpoint=dict(type='CheckpointHook', interval=1))
custom_hooks = [
dict(
type='DisableAugHook',
disable_after_epoch=11,
disable_aug_list=[
'GlobalRotScaleTrans', 'RandomFlip3D', 'ObjectSample'
])
]
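With `disable_after_epoch=11` and `max_epochs=12`, the listed augmentations are switched off for the final epoch only; disabling GT sampling and geometric augmentation near the end of training is a common fade-out strategy for outdoor detectors.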
from .disable_aug_hook import DisableAugHook
from .dsvt import DSVT
from .dsvt_head import DSVTCenterHead
from .dsvt_transformer import DSVTMiddleEncoder
from .dynamic_pillar_vfe import DynamicPillarVFE3D
from .map2bev import PointPillarsScatter3D
from .res_second import ResSECOND
from .transforms_3d import ObjectRangeFilter3D, PointsRangeFilter3D
from .utils import DSVTBBoxCoder
__all__ = [
'DSVTCenterHead', 'DSVT', 'DSVTMiddleEncoder', 'DynamicPillarVFE3D',
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder'
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder',
'ObjectRangeFilter3D', 'PointsRangeFilter3D', 'DisableAugHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmengine.dataset import BaseDataset
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner
from mmdet3d.registry import HOOKS
@HOOKS.register_module()
class DisableAugHook(Hook):
"""The hook of disabling augmentations during training.
Args:
disable_after_epoch (int): The epoch after which the data
augmentations will be disabled during training.
Defaults to 15.
disable_aug_list (list): The list of data augmentations to be
disabled during training. Defaults to [].
"""
def __init__(self,
disable_after_epoch: int = 15,
disable_aug_list: List = []):
self.disable_after_epoch = disable_after_epoch
self.disable_aug_list = disable_aug_list
self._restart_dataloader = False
def before_train_epoch(self, runner: Runner):
"""Close augmentation.
Args:
runner (Runner): The runner.
"""
epoch = runner.epoch
train_loader = runner.train_dataloader
model = runner.model
# TODO: refactor after mmengine using model wrapper
if is_model_wrapper(model):
model = model.module
if epoch == self.disable_after_epoch:
dataset = runner.train_dataloader.dataset
# handle dataset wrapper
if not isinstance(dataset, BaseDataset):
dataset = dataset.dataset
new_transforms = []
for transform in dataset.pipeline.transforms: # noqa: E501
if transform.__class__.__name__ not in self.disable_aug_list:
new_transforms.append(transform)
else:
runner.logger.info(
f'Disable {transform.__class__.__name__}')
dataset.pipeline.transforms = new_transforms
# The dataset pipeline cannot be updated when persistent_workers
# is True, so we need to force the dataloader's multi-process
# restart. This is a very hacky approach.
if hasattr(train_loader, 'persistent_workers'
) and train_loader.persistent_workers is True:
train_loader._DataLoader__initialized = False
train_loader._iterator = None
self._restart_dataloader = True
else:
# Once the restart is complete, we need to restore
# the initialization flag.
if self._restart_dataloader:
train_loader._DataLoader__initialized = True
......@@ -103,7 +103,11 @@ class DSVT(Base3DDetector):
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
pass
pts_feats = self.extract_feat(batch_inputs_dict)
losses = dict()
loss = self.bbox_head.loss(pts_feats, batch_data_samples)
losses.update(loss)
return losses
def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
......
import math
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from mmcv.ops import boxes_iou3d
from mmdet.models.utils import multi_apply
from mmengine.model import kaiming_init
from mmengine.structures import InstanceData
from torch import Tensor
from torch.nn.init import constant_
from mmdet3d.models import CenterHead
from mmdet3d.models.layers import circle_nms, nms_bev
from mmdet3d.models.utils import (clip_sigmoid, draw_heatmap_gaussian,
gaussian_radius)
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample, xywhr2xyxyr
......@@ -18,8 +25,33 @@ class DSVTCenterHead(CenterHead):
This head adds an IoU prediction branch on top of the original CenterHead.
"""
def __init__(self, *args, **kwargs):
def __init__(self,
loss_iou=dict(
type='mmdet.L1Loss', reduction='mean', loss_weight=1),
loss_reg_iou=None,
*args,
**kwargs):
super(DSVTCenterHead, self).__init__(*args, **kwargs)
self.loss_iou = MODELS.build(loss_iou)
self.loss_iou_reg = MODELS.build(
loss_reg_iou) if loss_reg_iou is not None else None
def init_weights(self):
kaiming_init(
self.shared_conv.conv,
a=math.sqrt(5),
mode='fan_in',
nonlinearity='leaky_relu',
distribution='uniform')
for head in self.task_heads[0].heads:
if head == 'heatmap':
constant_(self.task_heads[0].__getattr__(head)[-1].bias,
self.task_heads[0].init_bias)
else:
for m in self.task_heads[0].__getattr__(head).modules():
if isinstance(m, nn.Conv2d):
kaiming_init(
m, mode='fan_in', nonlinearity='leaky_relu')
def forward_single(self, x: Tensor) -> dict:
"""Forward function for CenterPoint.
......@@ -66,7 +98,298 @@ class DSVTCenterHead(CenterHead):
Returns:
dict: Losses of each branch.
"""
pass
outs = self(pts_feats)
batch_gt_instance_3d = []
for data_sample in batch_data_samples:
batch_gt_instance_3d.append(data_sample.gt_instances_3d)
losses = self.loss_by_feat(outs, batch_gt_instance_3d)
return losses
def _decode_all_preds(self,
pred_dict,
point_cloud_range=None,
voxel_size=None):
batch_size, _, H, W = pred_dict['reg'].shape
batch_center = pred_dict['reg'].permute(0, 2, 3, 1).contiguous().view(
batch_size, H * W, 2)  # (B, H*W, 2)
batch_center_z = pred_dict['height'].permute(
0, 2, 3, 1).contiguous().view(batch_size, H * W, 1)  # (B, H*W, 1)
batch_dim = pred_dict['dim'].exp().permute(
0, 2, 3, 1).contiguous().view(batch_size, H * W, 3)  # (B, H*W, 3)
batch_rot_cos = pred_dict['rot'][:, 0].unsqueeze(dim=1).permute(
0, 2, 3, 1).contiguous().view(batch_size, H * W, 1)  # (B, H*W, 1)
batch_rot_sin = pred_dict['rot'][:, 1].unsqueeze(dim=1).permute(
0, 2, 3, 1).contiguous().view(batch_size, H * W, 1)  # (B, H*W, 1)
batch_vel = pred_dict['vel'].permute(0, 2, 3, 1).contiguous().view(
batch_size, H * W, 2) if 'vel' in pred_dict.keys() else None  # (B, H*W, 2)
angle = torch.atan2(batch_rot_sin, batch_rot_cos) # (B, H*W, 1)
ys, xs = torch.meshgrid([
torch.arange(
0, H, device=batch_center.device, dtype=batch_center.dtype),
torch.arange(
0, W, device=batch_center.device, dtype=batch_center.dtype)
])
ys = ys.view(1, H, W).repeat(batch_size, 1, 1)
xs = xs.view(1, H, W).repeat(batch_size, 1, 1)
xs = xs.view(batch_size, -1, 1) + batch_center[:, :, 0:1]
ys = ys.view(batch_size, -1, 1) + batch_center[:, :, 1:2]
xs = xs * voxel_size[0] + point_cloud_range[0]
ys = ys * voxel_size[1] + point_cloud_range[1]
box_part_list = [xs, ys, batch_center_z, batch_dim, angle]
if batch_vel is not None:
box_part_list.append(batch_vel)
box_preds = torch.cat((box_part_list),
dim=-1).view(batch_size, H, W, -1)
return box_preds
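`_decode_all_preds` maps per-pixel regressions back to world coordinates: the predicted sub-voxel offset is added to the integer grid index, then scaled by the voxel size and shifted by the range origin. A minimal numeric sketch of that arithmetic (the 0.32 m voxel size matches this config; the range origin is an assumption in the style of the project's Waymo setup):

```python
# hypothetical single-cell example of the decode arithmetic above
voxel_size = (0.32, 0.32)
pc_range_min = (-74.88, -74.88)  # assumed Waymo range origin
xs, ys = 10, 20                  # integer grid indices of a heatmap cell
reg = (0.3, 0.7)                 # predicted sub-voxel offset at that cell

x = (xs + reg[0]) * voxel_size[0] + pc_range_min[0]  # -> -71.584 m
y = (ys + reg[1]) * voxel_size[1] + pc_range_min[1]  # -> -68.256 m
print(x, y)
```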
def _transpose_and_gather_feat(self, feat, ind):
feat = feat.permute(0, 2, 3, 1).contiguous()
feat = feat.view(feat.size(0), -1, feat.size(3))
feat = self._gather_feat(feat, ind)
return feat
def calc_iou_loss(self, iou_preds, batch_box_preds, mask, ind, gt_boxes):
"""
Args:
iou_preds: (batch x 1 x h x w)
batch_box_preds: (batch x (7 or 9) x h x w)
mask: (batch x max_objects)
ind: (batch x max_objects)
gt_boxes (list[Tensor]): Ground truth boxes of each sample in the batch.
Returns:
Tensor: IoU Loss.
"""
if mask.sum() == 0:
return iou_preds.new_zeros((1))
mask = mask.bool()
selected_iou_preds = self._transpose_and_gather_feat(iou_preds,
ind)[mask]
selected_box_preds = self._transpose_and_gather_feat(
batch_box_preds, ind)[mask]
gt_boxes = torch.cat(gt_boxes)
assert gt_boxes.size(0) == selected_box_preds.size(0)
iou_target = boxes_iou3d(selected_box_preds[:, 0:7], gt_boxes[:, 0:7])
iou_target = torch.diag(iou_target).view(-1)
iou_target = iou_target * 2 - 1 # [0, 1] ==> [-1, 1]
loss = self.loss_iou(selected_iou_preds.view(-1), iou_target)
loss = loss / torch.clamp(mask.sum(), min=1e-4)
return loss
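`calc_iou_loss` supervises the IoU branch with the true 3D IoU between the (detached) decoded boxes and their assigned ground truths, remapped from [0, 1] to [-1, 1]. A tiny sketch of the encoding and its inverse (the `(iou + 1) * 0.5` inverse is an assumption consistent with the truncated `batch_iou` decode further below):

```python
import torch

iou = torch.tensor([0.0, 0.5, 1.0])  # true 3D IoU in [0, 1]
target = iou * 2 - 1                 # [0, 1] -> [-1, 1], as in calc_iou_loss
recovered = (target + 1) * 0.5       # assumed inverse applied at inference
assert torch.allclose(recovered, iou)
print(target)  # tensor([-1.,  0.,  1.])
```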
def calc_iou_reg_loss(self, batch_box_preds, mask, ind, gt_boxes):
if mask.sum() == 0:
return batch_box_preds.new_zeros((1))
mask = mask.bool()
selected_box_preds = self._transpose_and_gather_feat(
batch_box_preds, ind)[mask]
gt_boxes = torch.cat(gt_boxes)
assert gt_boxes.size(0) == selected_box_preds.size(0)
loss = self.loss_iou_reg(selected_box_preds[:, 0:7], gt_boxes[:, 0:7])
return loss
def get_targets(
self,
batch_gt_instances_3d: List[InstanceData],
) -> Tuple[List[Tensor]]:
"""Generate targets.
How each output is transformed:
Each nested list is transposed so that all same-index elements in
each sub-list (1, ..., N) become the new sub-lists.
[ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
The new transposed nested list is converted into a list of N
tensors generated by concatenating tensors in the new sub-lists.
[ tensor0, tensor1, tensor2, ... ]
Args:
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
tuple[list[torch.Tensor]]: Tuple of target including
the following results in order.
- list[torch.Tensor]: Heatmap scores.
- list[torch.Tensor]: Ground truth boxes.
- list[torch.Tensor]: Indexes indicating the
position of the valid boxes.
- list[torch.Tensor]: Masks indicating which
boxes are valid.
"""
heatmaps, anno_boxes, inds, masks, task_gt_bboxes = multi_apply(
self.get_targets_single, batch_gt_instances_3d)
# Transpose heatmaps
heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
# Transpose anno_boxes
anno_boxes = list(map(list, zip(*anno_boxes)))
anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
# Transpose inds
inds = list(map(list, zip(*inds)))
inds = [torch.stack(inds_) for inds_ in inds]
# Transpose masks
masks = list(map(list, zip(*masks)))
masks = [torch.stack(masks_) for masks_ in masks]
# Transpose task_gt_bboxes
task_gt_bboxes = list(map(list, zip(*task_gt_bboxes)))
return heatmaps, anno_boxes, inds, masks, task_gt_bboxes
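The transpose-then-stack step described in the docstring, as a standalone sketch with toy tensors:

```python
import torch

# per-sample outputs: each sample yields one heatmap per task (toy shapes)
heatmaps = [[torch.zeros(3, 4, 4), torch.zeros(2, 4, 4)],  # sample a
            [torch.ones(3, 4, 4), torch.ones(2, 4, 4)]]    # sample b
# transpose: task -> [sample a, sample b], then stack along a new batch dim
heatmaps = list(map(list, zip(*heatmaps)))
heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
print([tuple(h.shape) for h in heatmaps])  # [(2, 3, 4, 4), (2, 2, 4, 4)]
```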
def get_targets_single(self,
gt_instances_3d: InstanceData) -> Tuple[Tensor]:
"""Generate training targets for a single sample.
Args:
gt_instances_3d (:obj:`InstanceData`): Gt_instances_3d of
single data sample. It usually includes
``bboxes_3d`` and ``labels_3d`` attributes.
Returns:
tuple[list[torch.Tensor]]: Tuple of target including
the following results in order.
- list[torch.Tensor]: Heatmap scores.
- list[torch.Tensor]: Ground truth boxes.
- list[torch.Tensor]: Indexes indicating the position
of the valid boxes.
- list[torch.Tensor]: Masks indicating which boxes
are valid.
"""
gt_labels_3d = gt_instances_3d.labels_3d
gt_bboxes_3d = gt_instances_3d.bboxes_3d
device = gt_labels_3d.device
gt_bboxes_3d = torch.cat(
(gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
dim=1).to(device)
max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
grid_size = torch.tensor(self.train_cfg['grid_size']).to(device)
pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
voxel_size = torch.tensor(self.train_cfg['voxel_size'])
feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor']
# reorganize the gt_dict by tasks
task_masks = []
flag = 0
for class_name in self.class_names:
task_masks.append([
torch.where(gt_labels_3d == class_name.index(i) + flag)
for i in class_name
])
flag += len(class_name)
task_boxes = []
task_classes = []
flag2 = 0
for idx, mask in enumerate(task_masks):
task_box = []
task_class = []
for m in mask:
task_box.append(gt_bboxes_3d[m])
# 0 is background for each task, so we need to add 1 here.
task_class.append(gt_labels_3d[m] + 1 - flag2)
task_boxes.append(torch.cat(task_box, axis=0).to(device))
task_classes.append(torch.cat(task_class).long().to(device))
flag2 += len(mask)
draw_gaussian = draw_heatmap_gaussian
heatmaps, anno_boxes, inds, masks = [], [], [], []
for idx, task_head in enumerate(self.task_heads):
heatmap = gt_bboxes_3d.new_zeros(
(len(self.class_names[idx]), feature_map_size[1],
feature_map_size[0]))
anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
dtype=torch.float32)
ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)
num_objs = min(task_boxes[idx].shape[0], max_objs)
for k in range(num_objs):
cls_id = task_classes[idx][k] - 1
length = task_boxes[idx][k][3]
width = task_boxes[idx][k][4]
length = length / voxel_size[0] / self.train_cfg[
'out_size_factor']
width = width / voxel_size[1] / self.train_cfg[
'out_size_factor']
if width > 0 and length > 0:
radius = gaussian_radius(
(width, length),
min_overlap=self.train_cfg['gaussian_overlap'])
radius = max(self.train_cfg['min_radius'], int(radius))
# be really careful for the coordinate system of
# your box annotation.
x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][
1], task_boxes[idx][k][2]
coor_x = (
x - pc_range[0]
) / voxel_size[0] / self.train_cfg['out_size_factor']
coor_y = (
y - pc_range[1]
) / voxel_size[1] / self.train_cfg['out_size_factor']
center = torch.tensor([coor_x, coor_y],
dtype=torch.float32,
device=device)
center_int = center.to(torch.int32)
# throw out not in range objects to avoid out of array
# area when creating the heatmap
if not (0 <= center_int[0] < feature_map_size[0]
and 0 <= center_int[1] < feature_map_size[1]):
continue
draw_gaussian(heatmap[cls_id], center_int, radius)
new_idx = k
x, y = center_int[0], center_int[1]
assert (y * feature_map_size[0] + x <
feature_map_size[0] * feature_map_size[1])
ind[new_idx] = y * feature_map_size[0] + x
mask[new_idx] = 1
# TODO: support other outdoor dataset
rot = task_boxes[idx][k][6]
box_dim = task_boxes[idx][k][3:6]
if self.norm_bbox:
box_dim = box_dim.log()
anno_box[new_idx] = torch.cat([
center - torch.tensor([x, y], device=device),
z.unsqueeze(0), box_dim,
torch.cos(rot).unsqueeze(0),
torch.sin(rot).unsqueeze(0)
])
heatmaps.append(heatmap)
anno_boxes.append(anno_box)
masks.append(mask)
inds.append(ind)
return heatmaps, anno_boxes, inds, masks, task_boxes
def loss_by_feat(self, preds_dicts: Tuple[List[dict]],
batch_gt_instances_3d: List[InstanceData], *args,
......@@ -79,13 +402,72 @@ class DSVTCenterHead(CenterHead):
tasks head, and the internal list indicate different
FPN level.
batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
gt_instances. It usually includes ``bboxes_3d`` and\
gt_instances_3d. It usually includes ``bboxes_3d`` and
``labels_3d`` attributes.
Returns:
dict[str,torch.Tensor]: Loss of heatmap and bbox of each task.
"""
pass
heatmaps, anno_boxes, inds, masks, task_gt_bboxes = self.get_targets(
batch_gt_instances_3d)
loss_dict = dict()
for task_id, preds_dict in enumerate(preds_dicts):
# heatmap focal loss
preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
num_pos = heatmaps[task_id].eq(1).float().sum().item()
loss_heatmap = self.loss_cls(
preds_dict[0]['heatmap'],
heatmaps[task_id],
avg_factor=max(num_pos, 1))
target_box = anno_boxes[task_id]
# reconstruct the anno_box from multiple reg heads
preds_dict[0]['anno_box'] = torch.cat(
(preds_dict[0]['reg'], preds_dict[0]['height'],
preds_dict[0]['dim'], preds_dict[0]['rot']),
dim=1)
# Regression loss for dimension, offset, height, rotation
ind = inds[task_id]
num = masks[task_id].float().sum()
pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous()
pred = pred.view(pred.size(0), -1, pred.size(3))
pred = self._gather_feat(pred, ind)
mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
isnotnan = (~torch.isnan(target_box)).float()
mask *= isnotnan
code_weights = self.train_cfg.get('code_weights', None)
bbox_weights = mask * mask.new_tensor(code_weights)
loss_bbox = self.loss_bbox(
pred, target_box, bbox_weights, avg_factor=(num + 1e-4))
loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap
loss_dict[f'task{task_id}.loss_bbox'] = loss_bbox
if 'iou' in preds_dict[0]:
batch_box_preds = self._decode_all_preds(
pred_dict=preds_dict[0],
point_cloud_range=self.train_cfg['point_cloud_range'],
voxel_size=self.train_cfg['voxel_size']
) # (B, H, W, 7 or 9)
batch_box_preds_for_iou = batch_box_preds.permute(
0, 3, 1, 2) # (B, 7 or 9, H, W)
loss_dict[f'task{task_id}.loss_iou'] = self.calc_iou_loss(
iou_preds=preds_dict[0]['iou'],
batch_box_preds=batch_box_preds_for_iou.clone().detach(),
mask=masks[task_id],
ind=ind,
gt_boxes=task_gt_bboxes[task_id])
if self.loss_iou_reg is not None:
loss_dict[f'task{task_id}.loss_reg_iou'] = \
self.calc_iou_reg_loss(
batch_box_preds=batch_box_preds_for_iou,
mask=masks[task_id],
ind=ind,
gt_boxes=task_gt_bboxes[task_id])
return loss_dict
def predict(self,
pts_feats: Tuple[torch.Tensor],
......@@ -158,6 +540,7 @@ class DSVTCenterHead(CenterHead):
else:
batch_dim = preds_dict[0]['dim']
# It's different from CenterHead
batch_rotc = preds_dict[0]['rot'][:, 0].unsqueeze(1)
batch_rots = preds_dict[0]['rot'][:, 1].unsqueeze(1)
batch_iou = (preds_dict[0]['iou'] +
......
# modified from https://github.com/Haiyang-W/DSVT
import numpy as np
import torch
import torch.nn as nn
import torch_scatter
......@@ -76,6 +77,7 @@ class DynamicPillarVFE3D(nn.Module):
self.voxel_x = voxel_size[0]
self.voxel_y = voxel_size[1]
self.voxel_z = voxel_size[2]
point_cloud_range = np.array(point_cloud_range).astype(np.float32)
self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
self.z_offset = self.voxel_z / 2 + point_cloud_range[2]
......
# modified from https://github.com/Haiyang-W/DSVT
import warnings
from typing import Optional, Sequence, Tuple
from typing import Sequence, Tuple
from mmengine.model import BaseModule
from torch import Tensor
......@@ -78,8 +76,8 @@ class ResSECOND(BaseModule):
out_channels (list[int]): Output channels for multi-scale feature maps.
blocks_nums (list[int]): Number of blocks in each stage.
layer_strides (list[int]): Strides of each stage.
norm_cfg (dict): Config dict of normalization layers.
conv_cfg (dict): Config dict of convolutional layers.
init_cfg (dict, optional): Config for weight initialization.
Defaults to None.
"""
def __init__(self,
......@@ -87,8 +85,7 @@ class ResSECOND(BaseModule):
out_channels: Sequence[int] = [128, 128, 256],
blocks_nums: Sequence[int] = [1, 2, 2],
layer_strides: Sequence[int] = [2, 2, 2],
init_cfg: OptMultiConfig = None,
pretrained: Optional[str] = None) -> None:
init_cfg: OptMultiConfig = None) -> None:
super(ResSECOND, self).__init__(init_cfg=init_cfg)
assert len(layer_strides) == len(blocks_nums)
assert len(out_channels) == len(blocks_nums)
......@@ -108,14 +105,6 @@ class ResSECOND(BaseModule):
BasicResBlock(out_channels[i], out_channels[i]))
blocks.append(nn.Sequential(*cur_layers))
self.blocks = nn.Sequential(*blocks)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
else:
self.init_cfg = dict(type='Kaiming', layer='Conv2d')
def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
"""Forward function.
......
from typing import List
import numpy as np
from mmcv import BaseTransform
from mmdet3d.registry import TRANSFORMS
@TRANSFORMS.register_module()
class ObjectRangeFilter3D(BaseTransform):
"""Filter objects by the range. It differs from `ObjectRangeFilter` by
using `in_range_3d` instead of `in_range_bev`.
Required Keys:
- gt_bboxes_3d
Modified Keys:
- gt_bboxes_3d
Args:
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range: List[float]) -> None:
self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def transform(self, input_dict: dict) -> dict:
"""Transform function to filter objects by the range.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
keys are updated in the result dict.
"""
gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_labels_3d = input_dict['gt_labels_3d']
mask = gt_bboxes_3d.in_range_3d(self.pcd_range)
gt_bboxes_3d = gt_bboxes_3d[mask]
# mask is a torch tensor but gt_labels_3d is still numpy array
# using mask to index gt_labels_3d will cause bug when
# len(gt_labels_3d) == 1, where mask=1 will be interpreted
# as gt_labels_3d[1] and cause out of index error
gt_labels_3d = gt_labels_3d[mask.numpy().astype(bool)]
# limit rad to [-pi, pi]
gt_bboxes_3d.limit_yaw(offset=0.5, period=2 * np.pi)
input_dict['gt_bboxes_3d'] = gt_bboxes_3d
input_dict['gt_labels_3d'] = gt_labels_3d
return input_dict
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
return repr_str
@TRANSFORMS.register_module()
class PointsRangeFilter3D(BaseTransform):
"""Filter points by the range. It differs from `PointRangeFilter` by using
`in_range_bev` instead of `in_range_3d`.
Required Keys:
- points
- pts_instance_mask (optional)
Modified Keys:
- points
- pts_instance_mask (optional)
Args:
point_cloud_range (list[float]): Point cloud range.
"""
def __init__(self, point_cloud_range: List[float]) -> None:
self.pcd_range = np.array(point_cloud_range, dtype=np.float32)
def transform(self, input_dict: dict) -> dict:
"""Transform function to filter points by the range.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict.
"""
points = input_dict['points']
points_mask = points.in_range_bev(self.pcd_range[[0, 1, 3, 4]])
clean_points = points[points_mask]
input_dict['points'] = clean_points
points_mask = points_mask.numpy()
pts_instance_mask = input_dict.get('pts_instance_mask', None)
pts_semantic_mask = input_dict.get('pts_semantic_mask', None)
if pts_instance_mask is not None:
input_dict['pts_instance_mask'] = pts_instance_mask[points_mask]
if pts_semantic_mask is not None:
input_dict['pts_semantic_mask'] = pts_semantic_mask[points_mask]
return input_dict
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
return repr_str
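`PointsRangeFilter3D` keeps only the BEV extent of the 6-dim range when calling `in_range_bev`; the fancy index selects `[x_min, y_min, x_max, y_max]`. A one-line sketch (the range values are an assumption in the style of this project's Waymo config):

```python
import numpy as np

pcd_range = np.array([-74.88, -74.88, -2.0, 74.88, 74.88, 4.0], dtype=np.float32)
bev_range = pcd_range[[0, 1, 3, 4]]  # -> [x_min, y_min, x_max, y_max]
print(bev_range)  # [-74.88 -74.88  74.88  74.88]
```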
......@@ -3,10 +3,11 @@ from typing import Dict, List, Optional
import numpy as np
import torch
import torch.nn as nn
from mmdet.models.losses.utils import weighted_loss
from torch import Tensor
from mmdet3d.models.task_modules import CenterPointBBoxCoder
from mmdet3d.registry import TASK_UTILS
from mmdet3d.registry import MODELS, TASK_UTILS
from .ops.ingroup_inds.ingroup_inds_op import ingroup_inds
get_inner_win_inds_cuda = ingroup_inds
......@@ -266,7 +267,7 @@ class DSVTBBoxCoder(CenterPointBBoxCoder):
thresh_mask = final_scores > self.score_threshold
if self.post_center_range is not None:
self.post_center_range = torch.tensor(
self.post_center_range = torch.as_tensor(
self.post_center_range, device=heat.device)
mask = (final_box_preds[..., :3] >=
self.post_center_range[:3]).all(2)
......@@ -298,3 +299,142 @@ class DSVTBBoxCoder(CenterPointBBoxCoder):
'support post_center_range is not None for now!')
return predictions_dicts
def center_to_corner2d(center, dim):
corners_norm = torch.tensor(
[[-0.5, -0.5], [-0.5, 0.5], [0.5, 0.5], [0.5, -0.5]],
device=dim.device).type_as(center) # (4, 2)
corners = dim.view([-1, 1, 2]) * corners_norm.view([1, 4, 2]) # (N, 4, 2)
corners = corners + center.view(-1, 1, 2)
return corners
@weighted_loss
def diou3d_loss(pred_boxes, gt_boxes, eps: float = 1e-7):
"""
modified from https://github.com/agent-sgs/PillarNet/blob/master/det3d/core/utils/center_utils.py # noqa
Args:
pred_boxes (N, 7):
gt_boxes (N, 7):
Returns:
Tensor: Distance-IoU Loss.
"""
assert pred_boxes.shape[0] == gt_boxes.shape[0]
qcorners = center_to_corner2d(pred_boxes[:, :2],
pred_boxes[:, 3:5]) # (N, 4, 2)
gcorners = center_to_corner2d(gt_boxes[:, :2], gt_boxes[:,
3:5]) # (N, 4, 2)
inter_max_xy = torch.minimum(qcorners[:, 2], gcorners[:, 2])
inter_min_xy = torch.maximum(qcorners[:, 0], gcorners[:, 0])
out_max_xy = torch.maximum(qcorners[:, 2], gcorners[:, 2])
out_min_xy = torch.minimum(qcorners[:, 0], gcorners[:, 0])
# calculate area
volume_pred_boxes = pred_boxes[:, 3] * pred_boxes[:, 4] * pred_boxes[:, 5]
volume_gt_boxes = gt_boxes[:, 3] * gt_boxes[:, 4] * gt_boxes[:, 5]
inter_h = torch.minimum(
pred_boxes[:, 2] + 0.5 * pred_boxes[:, 5],
gt_boxes[:, 2] + 0.5 * gt_boxes[:, 5]) - torch.maximum(
pred_boxes[:, 2] - 0.5 * pred_boxes[:, 5],
gt_boxes[:, 2] - 0.5 * gt_boxes[:, 5])
inter_h = torch.clamp(inter_h, min=0)
inter = torch.clamp((inter_max_xy - inter_min_xy), min=0)
volume_inter = inter[:, 0] * inter[:, 1] * inter_h
volume_union = volume_gt_boxes + volume_pred_boxes - volume_inter + eps
# boxes_iou3d_gpu(pred_boxes, gt_boxes)
inter_diag = torch.pow(gt_boxes[:, 0:3] - pred_boxes[:, 0:3], 2).sum(-1)
outer_h = torch.maximum(
gt_boxes[:, 2] + 0.5 * gt_boxes[:, 5],
pred_boxes[:, 2] + 0.5 * pred_boxes[:, 5]) - torch.minimum(
gt_boxes[:, 2] - 0.5 * gt_boxes[:, 5],
pred_boxes[:, 2] - 0.5 * pred_boxes[:, 5])
outer_h = torch.clamp(outer_h, min=0)
outer = torch.clamp((out_max_xy - out_min_xy), min=0)
outer_diag = outer[:, 0]**2 + outer[:, 1]**2 + outer_h**2 + eps
dious = volume_inter / volume_union - inter_diag / outer_diag
dious = torch.clamp(dious, min=-1.0, max=1.0)
loss = 1 - dious
return loss
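A self-contained sanity check of the formula above for axis-aligned boxes (note that `center_to_corner2d` ignores yaw when forming BEV corners, so axis-aligned boxes exercise the same math; this re-implementation is a sketch, not the project's API):

```python
import torch

def diou3d_axis_aligned(p, g, eps=1e-7):
    # p, g: (x, y, z, dx, dy, dz) tensors for a single box each
    inter_wh = torch.clamp(
        torch.minimum(p[:2] + p[3:5] / 2, g[:2] + g[3:5] / 2) -
        torch.maximum(p[:2] - p[3:5] / 2, g[:2] - g[3:5] / 2), min=0)
    inter_h = torch.clamp(
        torch.minimum(p[2] + p[5] / 2, g[2] + g[5] / 2) -
        torch.maximum(p[2] - p[5] / 2, g[2] - g[5] / 2), min=0)
    volume_inter = inter_wh[0] * inter_wh[1] * inter_h
    volume_union = p[3:6].prod() + g[3:6].prod() - volume_inter + eps
    inter_diag = ((p[:3] - g[:3]) ** 2).sum()
    outer_wh = torch.maximum(p[:2] + p[3:5] / 2, g[:2] + g[3:5] / 2) - \
        torch.minimum(p[:2] - p[3:5] / 2, g[:2] - g[3:5] / 2)
    outer_h = torch.maximum(p[2] + p[5] / 2, g[2] + g[5] / 2) - \
        torch.minimum(p[2] - p[5] / 2, g[2] - g[5] / 2)
    outer_diag = outer_wh[0] ** 2 + outer_wh[1] ** 2 + outer_h ** 2 + eps
    return 1 - (volume_inter / volume_union - inter_diag / outer_diag)

box = torch.tensor([0., 0., 0., 4., 2., 1.5])
print(diou3d_axis_aligned(box, box))  # ~0: perfect overlap
shifted = box + torch.tensor([2., 0., 0., 0., 0., 0.])
print(diou3d_axis_aligned(box, shifted))  # > 0: partial overlap penalized
```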
@MODELS.register_module()
class DIoU3DLoss(nn.Module):
r"""3D bboxes Implementation of `Distance-IoU Loss: Faster and Better
Learning for Bounding Box Regression <https://arxiv.org/abs/1911.08287>`_.
Code is modified from https://github.com/Zzh-tju/DIoU.
Args:
eps (float): Epsilon to avoid division by zero. Defaults to 1e-6.
reduction (str): Options are "none", "mean" and "sum".
Defaults to "mean".
loss_weight (float): Weight of loss. Defaults to 1.0.
"""
def __init__(self,
eps: float = 1e-6,
reduction: str = 'mean',
loss_weight: float = 1.0) -> None:
super().__init__()
self.eps = eps
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred: Tensor,
target: Tensor,
weight: Optional[Tensor] = None,
avg_factor: Optional[int] = None,
reduction_override: Optional[str] = None,
**kwargs) -> Tensor:
"""Forward function.
Args:
pred (Tensor): Predicted 3D boxes in (x, y, z, dx, dy, dz, yaw)
format, shape (n, 7).
target (Tensor): The learning target of the prediction,
shape (n, 7).
weight (Optional[Tensor], optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (Optional[int], optional): Average factor that is used
to average the loss. Defaults to None.
reduction_override (Optional[str], optional): The reduction method
used to override the original reduction method of the loss.
Defaults to None. Options are "none", "mean" and "sum".
Returns:
Tensor: Loss tensor.
"""
if weight is not None and not torch.any(weight > 0):
if pred.dim() == weight.dim() + 1:
weight = weight.unsqueeze(1)
return (pred * weight).sum() # 0
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if weight is not None and weight.dim() > 1:
# TODO: remove this in the future
# reduce the weight of shape (n, 7) to (n,) to match the
# diou3d_loss of shape (n,)
assert weight.shape == pred.shape
weight = weight.mean(-1)
loss = self.loss_weight * diou3d_loss(
pred,
target,
weight,
eps=self.eps,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss
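A hedged usage sketch of the loss module defined above (direct instantiation; in configs it is referenced as `mmdet3d.DIoU3DLoss`, as in the `loss_reg_iou` entry earlier in this commit):

```python
import torch

# assumes DIoU3DLoss from above is importable in the current scope
loss_fn = DIoU3DLoss(reduction='mean', loss_weight=2.0)
pred = torch.tensor([[0., 0., 0., 4., 2., 1.5, 0.]], requires_grad=True)
target = torch.tensor([[0.5, 0., 0., 4., 2., 1.5, 0.]])
loss = loss_fn(pred, target)  # scalar, differentiable w.r.t. pred
loss.backward()
print(loss.item())
```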
......@@ -21,6 +21,12 @@ def parse_args():
action='store_true',
default=False,
help='enable automatic-mixed-precision training')
parser.add_argument(
'--sync_bn',
choices=['none', 'torch', 'mmcv'],
default='none',
help='convert all BatchNorm layers in the model to SyncBatchNorm '
'(SyncBN) or mmcv.ops.sync_bn.SyncBatchNorm (MMSyncBN) layers.')
parser.add_argument(
'--auto-scale-lr',
action='store_true',
......@@ -98,6 +104,10 @@ def main():
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.loss_scale = 'dynamic'
# convert BatchNorm layers
if args.sync_bn != 'none':
cfg.sync_bn = args.sync_bn
# enable automatically scaling LR
if args.auto_scale_lr:
if 'auto_scale_lr' in cfg and \
......
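For reference, the `torch` option of `--sync_bn` boils down to PyTorch's built-in converter; a minimal sketch of the conversion (assuming MMEngine dispatches to `convert_sync_batchnorm` for this mode):

```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(model[1]).__name__)  # SyncBatchNorm (needs a process group to run)
```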