Unverified commit 762e3b53 authored by Sun Jiahao, committed by GitHub

[Feature] Support DSVT training (#2738)


Co-authored-by: JingweiZhang12 <zjw18@mails.tsinghua.edu.cn>
Co-authored-by: sjh <sunjiahao1999>
parent 5b88c7b8
@@ -101,7 +101,7 @@ class SeparateHead(BaseModule):
         Returns:
             dict[str: torch.Tensor]: contains the following keys:
-                -reg torch.Tensor): 2D regression value with the
+                -reg (torch.Tensor): 2D regression value with the
                     shape of [B, 2, H, W].
                 -height (torch.Tensor): Height value with the
                     shape of [B, 1, H, W].
@@ -217,7 +217,7 @@ class DCNSeparateHead(BaseModule):
         Returns:
             dict[str: torch.Tensor]: contains the following keys:
-                -reg torch.Tensor): 2D regression value with the
+                -reg (torch.Tensor): 2D regression value with the
                     shape of [B, 2, H, W].
                 -height (torch.Tensor): Height value with the
                     shape of [B, 1, H, W].
......
@@ -21,6 +21,10 @@ class SECONDFPN(BaseModule):
         upsample_cfg (dict): Config dict of upsample layers.
         conv_cfg (dict): Config dict of conv layers.
         use_conv_for_no_stride (bool): Whether to use conv when stride is 1.
+        init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
+            optional): Initialization config dict. Defaults to
+            [dict(type='Kaiming', layer='ConvTranspose2d'),
+            dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)].
     """

     def __init__(self,
@@ -31,7 +35,13 @@ class SECONDFPN(BaseModule):
                  upsample_cfg=dict(type='deconv', bias=False),
                  conv_cfg=dict(type='Conv2d', bias=False),
                  use_conv_for_no_stride=False,
-                 init_cfg=None):
+                 init_cfg=[
+                     dict(type='Kaiming', layer='ConvTranspose2d'),
+                     dict(
+                         type='Constant',
+                         layer='NaiveSyncBatchNorm2d',
+                         val=1.0)
+                 ]):
         # if for GroupNorm,
         # cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
         super(SECONDFPN, self).__init__(init_cfg=init_cfg)
@@ -64,12 +74,6 @@ class SECONDFPN(BaseModule):
             deblocks.append(deblock)
         self.deblocks = nn.ModuleList(deblocks)

-        if init_cfg is None:
-            self.init_cfg = [
-                dict(type='Kaiming', layer='ConvTranspose2d'),
-                dict(type='Constant', layer='NaiveSyncBatchNorm2d', val=1.0)
-            ]
-
     def forward(self, x):
         """Forward function.
......
@@ -275,12 +275,13 @@ class BaseInstance3DBoxes:
             Tensor: A binary vector indicating whether each point is inside the
                 reference range.
         """
-        in_range_flags = ((self.tensor[:, 0] > box_range[0])
-                          & (self.tensor[:, 1] > box_range[1])
-                          & (self.tensor[:, 2] > box_range[2])
-                          & (self.tensor[:, 0] < box_range[3])
-                          & (self.tensor[:, 1] < box_range[4])
-                          & (self.tensor[:, 2] < box_range[5]))
+        gravity_center = self.gravity_center
+        in_range_flags = ((gravity_center[:, 0] > box_range[0])
+                          & (gravity_center[:, 1] > box_range[1])
+                          & (gravity_center[:, 2] > box_range[2])
+                          & (gravity_center[:, 0] < box_range[3])
+                          & (gravity_center[:, 1] < box_range[4])
+                          & (gravity_center[:, 2] < box_range[5]))
         return in_range_flags

     @abstractmethod
......
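As a quick illustration (not part of this commit) of what the switch to `gravity_center` changes: a tall box whose bottom center sits just below the range floor now passes `in_range_3d`, because the check uses the box's mid-height center instead of the bottom center stored in `self.tensor`.

```python
import torch

box_range = [0., 0., -2., 100., 100., 4.]
# (x, y, z_bottom, dx, dy, dz, yaw): bottom center at z=-3, height 4
box = torch.tensor([[10., 10., -3., 1., 1., 4., 0.]])

bottom_center = box[:, :3]
gravity_center = bottom_center.clone()
gravity_center[:, 2] += box[:, 5] / 2  # z = -1, inside (-2, 4)

def in_range(c):
    return ((c[:, 0] > box_range[0]) & (c[:, 1] > box_range[1])
            & (c[:, 2] > box_range[2]) & (c[:, 0] < box_range[3])
            & (c[:, 1] < box_range[4]) & (c[:, 2] < box_range[5]))

print(in_range(bottom_center))   # tensor([False])
print(in_range(gravity_center))  # tensor([True])
```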
@@ -57,17 +57,25 @@ python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-

 ### Training commands
-The support of training DSVT is on the way.
+In MMDetection3D's root directory, run the following command to train the model:
+
+```bash
+bash tools/dist_train.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py 8 --sync_bn torch
+```
 ## Results and models

 ### Waymo

-| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
-| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
-| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | | | 75.2 | 72.2 | 68.9 | 66.1 | |
+| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
+| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
+| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✓ | × | 75.5 | 72.4 | 69.2 | 66.3 | [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/dsvt/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class_20230917_102130.log) |
-**Note** that `ResSECOND` denotes the base block in SECOND has residual layers.
+**Note**:
+
+- `ResSECOND` denotes that the base block in SECOND has residual layers.
+- Regrettably, we are unable to provide the pre-trained model weights due to the [Waymo Dataset License Agreement](https://waymo.com/open/terms/), so we only provide the training logs as shown above.

 ## Citation
......
@@ -88,25 +88,28 @@ model = dict(
         loss_cls=dict(
             type='mmdet.GaussianFocalLoss', reduction='mean', loss_weight=1.0),
         loss_bbox=dict(type='mmdet.L1Loss', reduction='mean', loss_weight=2.0),
+        loss_iou=dict(type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
+        loss_reg_iou=dict(
+            type='mmdet3d.DIoU3DLoss', reduction='mean', loss_weight=2.0),
         norm_bbox=True),
     # model training and testing settings
     train_cfg=dict(
-        pts=dict(
-            grid_size=grid_size,
-            voxel_size=voxel_size,
-            out_size_factor=4,
-            dense_reg=1,
-            gaussian_overlap=0.1,
-            max_objs=500,
-            min_radius=2,
-            code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
+        grid_size=grid_size,
+        voxel_size=voxel_size,
+        point_cloud_range=point_cloud_range,
+        out_size_factor=1,
+        dense_reg=1,
+        gaussian_overlap=0.1,
+        max_objs=500,
+        min_radius=2,
+        code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
     test_cfg=dict(
         max_per_img=500,
         max_pool_nms=False,
         min_radius=[4, 12, 10, 1, 0.85, 0.175],
         iou_rectifier=[[0.68, 0.71, 0.65]],
         pc_range=[-80, -80],
-        out_size_factor=4,
+        out_size_factor=1,
         voxel_size=voxel_size[:2],
         nms_type='rotate',
         multi_class_nms=True,
@@ -128,6 +131,8 @@ db_sampler = dict(
         coord_type='LIDAR',
         load_dim=6,
         use_dim=[0, 1, 2, 3, 4],
+        norm_intensity=True,
+        norm_elongation=True,
         backend_args=backend_args),
     backend_args=backend_args)
@@ -138,25 +143,22 @@ train_pipeline = [
         load_dim=6,
         use_dim=5,
         norm_intensity=True,
+        norm_elongation=True,
         backend_args=backend_args),
-    # Add this if using `MultiFrameDeformableDecoderRPN`
-    # dict(
-    #     type='LoadPointsFromMultiSweeps',
-    #     sweeps_num=9,
-    #     load_dim=6,
-    #     use_dim=[0, 1, 2, 3, 4],
-    #     pad_empty_sweeps=True,
-    #     remove_close=True),
     dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
     dict(type='ObjectSample', db_sampler=db_sampler),
+    dict(
+        type='RandomFlip3D',
+        sync_2d=False,
+        flip_ratio_bev_horizontal=0.5,
+        flip_ratio_bev_vertical=0.5),
     dict(
         type='GlobalRotScaleTrans',
         rot_range=[-0.78539816, 0.78539816],
         scale_ratio_range=[0.95, 1.05],
-        translation_std=[0.5, 0.5, 0]),
-    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
-    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
-    dict(type='ObjectNameFilter', classes=class_names),
+        translation_std=[0.5, 0.5, 0.5]),
+    dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
+    dict(type='ObjectRangeFilter3D', point_cloud_range=point_cloud_range),
     dict(type='PointShuffle'),
     dict(
         type='Pack3DDetInputs',
@@ -172,25 +174,34 @@ test_pipeline = [
         norm_intensity=True,
         norm_elongation=True,
         backend_args=backend_args),
+    dict(type='PointsRangeFilter3D', point_cloud_range=point_cloud_range),
     dict(
-        type='MultiScaleFlipAug3D',
-        img_scale=(1333, 800),
-        pts_scale_ratio=1,
-        flip=False,
-        transforms=[
-            dict(
-                type='GlobalRotScaleTrans',
-                rot_range=[0, 0],
-                scale_ratio_range=[1., 1.],
-                translation_std=[0, 0, 0]),
-            dict(type='RandomFlip3D'),
-            dict(
-                type='PointsRangeFilter', point_cloud_range=point_cloud_range)
-        ]),
-    dict(type='Pack3DDetInputs', keys=['points'])
+        type='Pack3DDetInputs',
+        keys=['points'],
+        meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
 ]
 dataset_type = 'WaymoDataset'
+train_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='waymo_infos_train.pkl',
+        data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
+        pipeline=train_pipeline,
+        modality=input_modality,
+        test_mode=False,
+        metainfo=metainfo,
+        # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
+        # and box_type_3d='Depth' in sunrgbd and scannet dataset.
+        box_type_3d='LiDAR',
+        # load one frame every five frames
+        load_interval=5,
+        backend_args=backend_args))
 val_dataloader = dict(
     batch_size=4,
     num_workers=4,
@@ -212,18 +223,59 @@ test_dataloader = val_dataloader

 val_evaluator = dict(
     type='WaymoMetric',
-    ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
     waymo_bin_file='./data/waymo/waymo_format/gt.bin',
-    data_root='./data/waymo/waymo_format',
-    backend_args=backend_args,
-    convert_kitti_format=False,
-    idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
+    result_prefix='./dsvt_pred')
 test_evaluator = val_evaluator

 vis_backends = [dict(type='LocalVisBackend')]
 visualizer = dict(
     type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+# schedules
+lr = 1e-5
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.05, betas=(0.9, 0.99)),
+    clip_grad=dict(max_norm=10, norm_type=2))
+param_scheduler = [
+    dict(
+        type='CosineAnnealingLR',
+        T_max=1.2,
+        eta_min=lr * 100,
+        begin=0,
+        end=1.2,
+        by_epoch=True,
+        convert_to_iter_based=True),
+    dict(
+        type='CosineAnnealingLR',
+        T_max=10.8,
+        eta_min=lr * 1e-4,
+        begin=1.2,
+        end=12,
+        by_epoch=True,
+        convert_to_iter_based=True),
+    # momentum scheduler
+    dict(
+        type='CosineAnnealingMomentum',
+        T_max=1.2,
+        eta_min=0.85,
+        begin=0,
+        end=1.2,
+        by_epoch=True,
+        convert_to_iter_based=True),
+    dict(
+        type='CosineAnnealingMomentum',
+        T_max=10.8,
+        eta_min=0.95,
+        begin=1.2,
+        end=12,
+        by_epoch=True,
+        convert_to_iter_based=True)
+]
+
+# runtime settings
+train_cfg = dict(by_epoch=True, max_epochs=12, val_interval=1)
 # runtime settings
 val_cfg = dict()
 test_cfg = dict()
@@ -236,4 +288,12 @@ test_cfg = dict()
 default_hooks = dict(
     logger=dict(type='LoggerHook', interval=50),
-    checkpoint=dict(type='CheckpointHook', interval=5))
+    checkpoint=dict(type='CheckpointHook', interval=1))
+custom_hooks = [
+    dict(
+        type='DisableAugHook',
+        disable_after_epoch=11,
+        disable_aug_list=[
+            'GlobalRotScaleTrans', 'RandomFlip3D', 'ObjectSample'
+        ])
+]
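For reference, the two `CosineAnnealingLR` phases above form a warmup-then-decay pair: the first raises the LR from `lr` toward `lr * 100`, the second decays it toward `lr * 1e-4`. A minimal sketch (illustrative only; it assumes mmengine's usual closed-form cosine rule, and that the second phase starts near where the first ends):

```python
import math

def cosine_anneal(base: float, eta_min: float, t: float, t_max: float) -> float:
    # Standard cosine interpolation from `base` toward `eta_min`.
    return eta_min + 0.5 * (base - eta_min) * (1 + math.cos(math.pi * t / t_max))

lr = 1e-5
# Phase 1 (epochs 0-1.2): eta_min = lr * 100 > lr, so the LR rises to ~1e-3.
print(cosine_anneal(lr, lr * 100, t=1.2, t_max=1.2))           # ~1e-3
# Phase 2 (epochs 1.2-12): decays from ~1e-3 toward lr * 1e-4.
print(cosine_anneal(lr * 100, lr * 1e-4, t=10.8, t_max=10.8))  # ~1e-9
```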
+from .disable_aug_hook import DisableAugHook
 from .dsvt import DSVT
 from .dsvt_head import DSVTCenterHead
 from .dsvt_transformer import DSVTMiddleEncoder
 from .dynamic_pillar_vfe import DynamicPillarVFE3D
 from .map2bev import PointPillarsScatter3D
 from .res_second import ResSECOND
+from .transforms_3d import ObjectRangeFilter3D, PointsRangeFilter3D
 from .utils import DSVTBBoxCoder

 __all__ = [
     'DSVTCenterHead', 'DSVT', 'DSVTMiddleEncoder', 'DynamicPillarVFE3D',
-    'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder'
+    'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder',
+    'ObjectRangeFilter3D', 'PointsRangeFilter3D', 'DisableAugHook'
 ]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine.dataset import BaseDataset
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner

from mmdet3d.registry import HOOKS


@HOOKS.register_module()
class DisableAugHook(Hook):
    """The hook that disables augmentations during training.

    Args:
        disable_after_epoch (int): The epoch at which the listed data
            augmentations are disabled during training. Defaults to 15.
        disable_aug_list (list): The list of data augmentations that will
            be disabled during training. Defaults to [].
    """

    def __init__(self,
                 disable_after_epoch: int = 15,
                 disable_aug_list: List = []):
        self.disable_after_epoch = disable_after_epoch
        self.disable_aug_list = disable_aug_list
        self._restart_dataloader = False

    def before_train_epoch(self, runner: Runner):
        """Close augmentation.

        Args:
            runner (Runner): The runner.
        """
        epoch = runner.epoch
        train_loader = runner.train_dataloader
        model = runner.model
        # TODO: refactor after mmengine using model wrapper
        if is_model_wrapper(model):
            model = model.module
        if epoch == self.disable_after_epoch:
            dataset = runner.train_dataloader.dataset
            # handle dataset wrapper
            if not isinstance(dataset, BaseDataset):
                dataset = dataset.dataset
            new_transforms = []
            for transform in dataset.pipeline.transforms:  # noqa: E501
                if transform.__class__.__name__ not in self.disable_aug_list:
                    new_transforms.append(transform)
                else:
                    runner.logger.info(
                        f'Disable {transform.__class__.__name__}')
            dataset.pipeline.transforms = new_transforms
            # The dataset pipeline cannot be updated when persistent_workers
            # is True, so we need to force the dataloader's multi-process
            # restart. This is a very hacky approach.
            if hasattr(train_loader, 'persistent_workers'
                       ) and train_loader.persistent_workers is True:
                train_loader._DataLoader__initialized = False
                train_loader._iterator = None
                self._restart_dataloader = True
        else:
            # Once the restart is complete, we need to restore
            # the initialization flag.
            if self._restart_dataloader:
                train_loader._DataLoader__initialized = True
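The heart of the hook is the transform filtering above. A self-contained sketch of just that step (illustrative only, using hypothetical stand-in transform classes):

```python
# Hypothetical stand-ins for pipeline transforms.
class RandomFlip3D:
    pass

class PointShuffle:
    pass

pipeline = [RandomFlip3D(), PointShuffle()]
disable_aug_list = ['RandomFlip3D']

# Keep only the transforms whose class name is not in the disable list.
kept = [
    t for t in pipeline
    if t.__class__.__name__ not in disable_aug_list
]
print([t.__class__.__name__ for t in kept])  # ['PointShuffle']
```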
@@ -103,7 +103,11 @@ class DSVT(Base3DDetector):
         Returns:
             dict[str, Tensor]: A dictionary of loss components.
         """
-        pass
+        pts_feats = self.extract_feat(batch_inputs_dict)
+        losses = dict()
+        loss = self.bbox_head.loss(pts_feats, batch_data_samples)
+        losses.update(loss)
+        return losses

     def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
                 batch_data_samples: List[Det3DDataSample],
......
+import math
 from typing import Dict, List, Tuple

 import torch
+import torch.nn as nn
+from mmcv.ops import boxes_iou3d
 from mmdet.models.utils import multi_apply
+from mmengine.model import kaiming_init
 from mmengine.structures import InstanceData
 from torch import Tensor
+from torch.nn.init import constant_

 from mmdet3d.models import CenterHead
 from mmdet3d.models.layers import circle_nms, nms_bev
+from mmdet3d.models.utils import (clip_sigmoid, draw_heatmap_gaussian,
+                                  gaussian_radius)
 from mmdet3d.registry import MODELS
 from mmdet3d.structures import Det3DDataSample, xywhr2xyxyr
@@ -18,8 +25,33 @@ class DSVTCenterHead(CenterHead):
     This head adds IoU prediction branch based on the original CenterHead.
     """

-    def __init__(self, *args, **kwargs):
+    def __init__(self,
+                 loss_iou=dict(
+                     type='mmdet.L1Loss', reduction='mean', loss_weight=1),
+                 loss_reg_iou=None,
+                 *args,
+                 **kwargs):
         super(DSVTCenterHead, self).__init__(*args, **kwargs)
+        self.loss_iou = MODELS.build(loss_iou)
+        self.loss_iou_reg = MODELS.build(
+            loss_reg_iou) if loss_reg_iou is not None else None
+
+    def init_weights(self):
+        kaiming_init(
+            self.shared_conv.conv,
+            a=math.sqrt(5),
+            mode='fan_in',
+            nonlinearity='leaky_relu',
+            distribution='uniform')
+        for head in self.task_heads[0].heads:
+            if head == 'heatmap':
+                constant_(self.task_heads[0].__getattr__(head)[-1].bias,
+                          self.task_heads[0].init_bias)
+            else:
+                for m in self.task_heads[0].__getattr__(head).modules():
+                    if isinstance(m, nn.Conv2d):
+                        kaiming_init(
+                            m, mode='fan_in', nonlinearity='leaky_relu')

     def forward_single(self, x: Tensor) -> dict:
         """Forward function for CenterPoint.
@@ -66,7 +98,298 @@ class DSVTCenterHead(CenterHead):
         Returns:
             dict: Losses of each branch.
         """
-        pass
+        outs = self(pts_feats)
+        batch_gt_instance_3d = []
+        for data_sample in batch_data_samples:
+            batch_gt_instance_3d.append(data_sample.gt_instances_3d)
+        losses = self.loss_by_feat(outs, batch_gt_instance_3d)
+        return losses
+
+    def _decode_all_preds(self,
+                          pred_dict,
+                          point_cloud_range=None,
+                          voxel_size=None):
+        batch_size, _, H, W = pred_dict['reg'].shape
+
+        batch_center = pred_dict['reg'].permute(0, 2, 3, 1).contiguous().view(
+            batch_size, H * W, 2)  # (B, H, W, 2)
+        batch_center_z = pred_dict['height'].permute(
+            0, 2, 3, 1).contiguous().view(batch_size, H * W, 1)  # (B, H, W, 1)
+        batch_dim = pred_dict['dim'].exp().permute(
+            0, 2, 3, 1).contiguous().view(batch_size, H * W, 3)  # (B, H, W, 3)
+        batch_rot_cos = pred_dict['rot'][:, 0].unsqueeze(dim=1).permute(
+            0, 2, 3, 1).contiguous().view(batch_size, H * W, 1)  # (B, H, W, 1)
+        batch_rot_sin = pred_dict['rot'][:, 1].unsqueeze(dim=1).permute(
+            0, 2, 3, 1).contiguous().view(batch_size, H * W, 1)  # (B, H, W, 1)
+        batch_vel = pred_dict['vel'].permute(0, 2, 3, 1).contiguous().view(
+            batch_size, H * W, 2) if 'vel' in pred_dict.keys() else None
+
+        angle = torch.atan2(batch_rot_sin, batch_rot_cos)  # (B, H*W, 1)
+
+        ys, xs = torch.meshgrid([
+            torch.arange(
+                0, H, device=batch_center.device, dtype=batch_center.dtype),
+            torch.arange(
+                0, W, device=batch_center.device, dtype=batch_center.dtype)
+        ])
+        ys = ys.view(1, H, W).repeat(batch_size, 1, 1)
+        xs = xs.view(1, H, W).repeat(batch_size, 1, 1)
+        xs = xs.view(batch_size, -1, 1) + batch_center[:, :, 0:1]
+        ys = ys.view(batch_size, -1, 1) + batch_center[:, :, 1:2]
+
+        xs = xs * voxel_size[0] + point_cloud_range[0]
+        ys = ys * voxel_size[1] + point_cloud_range[1]
+
+        box_part_list = [xs, ys, batch_center_z, batch_dim, angle]
+        if batch_vel is not None:
+            box_part_list.append(batch_vel)
+        box_preds = torch.cat((box_part_list),
+                              dim=-1).view(batch_size, H, W, -1)
+
+        return box_preds
+
+    def _transpose_and_gather_feat(self, feat, ind):
+        feat = feat.permute(0, 2, 3, 1).contiguous()
+        feat = feat.view(feat.size(0), -1, feat.size(3))
+        feat = self._gather_feat(feat, ind)
+        return feat
+    def calc_iou_loss(self, iou_preds, batch_box_preds, mask, ind, gt_boxes):
+        """
+        Args:
+            iou_preds (Tensor): IoU predictions with shape (B, 1, H, W).
+            batch_box_preds (Tensor): Decoded boxes with shape
+                (B, 7 or 9, H, W).
+            mask (Tensor): Valid-object mask with shape (B, max_objects).
+            ind (Tensor): Flattened center indices with shape
+                (B, max_objects).
+            gt_boxes (list[Tensor]): Ground-truth boxes per batch sample.
+
+        Returns:
+            Tensor: IoU loss.
+        """
+        if mask.sum() == 0:
+            return iou_preds.new_zeros((1))
+
+        mask = mask.bool()
+        selected_iou_preds = self._transpose_and_gather_feat(iou_preds,
+                                                             ind)[mask]
+
+        selected_box_preds = self._transpose_and_gather_feat(
+            batch_box_preds, ind)[mask]
+        gt_boxes = torch.cat(gt_boxes)
+        assert gt_boxes.size(0) == selected_box_preds.size(0)
+        iou_target = boxes_iou3d(selected_box_preds[:, 0:7], gt_boxes[:, 0:7])
+        iou_target = torch.diag(iou_target).view(-1)
+        iou_target = iou_target * 2 - 1  # [0, 1] ==> [-1, 1]
+
+        loss = self.loss_iou(selected_iou_preds.view(-1), iou_target)
+        loss = loss / torch.clamp(mask.sum(), min=1e-4)
+        return loss
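The `iou_target * 2 - 1` line rescales IoU from [0, 1] to [-1, 1] before the L1 loss; at inference the prediction is presumably mapped back with `(iou + 1) / 2`, as the truncated `batch_iou = (preds_dict[0]['iou'] + ...` line in `predict` suggests. A one-line sanity check (illustrative only, not part of the commit):

```python
import torch

iou = torch.tensor([0.0, 0.5, 1.0])
target = iou * 2 - 1             # [0, 1] -> [-1, 1], the L1 regression target
recovered = (target + 1) * 0.5   # inverse mapping, back to [0, 1]
print(target.tolist(), recovered.tolist())  # [-1.0, 0.0, 1.0] [0.0, 0.5, 1.0]
```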
+    def calc_iou_reg_loss(self, batch_box_preds, mask, ind, gt_boxes):
+        if mask.sum() == 0:
+            return batch_box_preds.new_zeros((1))
+
+        mask = mask.bool()
+
+        selected_box_preds = self._transpose_and_gather_feat(
+            batch_box_preds, ind)[mask]
+        gt_boxes = torch.cat(gt_boxes)
+        assert gt_boxes.size(0) == selected_box_preds.size(0)
+        loss = self.loss_iou_reg(selected_box_preds[:, 0:7], gt_boxes[:, 0:7])
+
+        return loss
+    def get_targets(
+        self,
+        batch_gt_instances_3d: List[InstanceData],
+    ) -> Tuple[List[Tensor]]:
+        """Generate targets.
+
+        How each output is transformed:
+
+            Each nested list is transposed so that all same-index elements in
+            each sub-list (1, ..., N) become the new sub-lists.
+
+                [ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ]
+                ==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ]
+
+            The new transposed nested list is converted into a list of N
+            tensors generated by concatenating tensors in the new sub-lists.
+
+                [ tensor0, tensor1, tensor2, ... ]
+
+        Args:
+            batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
+                gt_instances_3d. It usually includes ``bboxes_3d`` and
+                ``labels_3d`` attributes.
+
+        Returns:
+            tuple[list[torch.Tensor]]: Tuple of target including
+                the following results in order.
+
+            - list[torch.Tensor]: Heatmap scores.
+            - list[torch.Tensor]: Ground truth boxes.
+            - list[torch.Tensor]: Indexes indicating the
+              position of the valid boxes.
+            - list[torch.Tensor]: Masks indicating which
+              boxes are valid.
+        """
+        heatmaps, anno_boxes, inds, masks, task_gt_bboxes = multi_apply(
+            self.get_targets_single, batch_gt_instances_3d)
+        # Transpose heatmaps
+        heatmaps = list(map(list, zip(*heatmaps)))
+        heatmaps = [torch.stack(hms_) for hms_ in heatmaps]
+        # Transpose anno_boxes
+        anno_boxes = list(map(list, zip(*anno_boxes)))
+        anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes]
+        # Transpose inds
+        inds = list(map(list, zip(*inds)))
+        inds = [torch.stack(inds_) for inds_ in inds]
+        # Transpose masks
+        masks = list(map(list, zip(*masks)))
+        masks = [torch.stack(masks_) for masks_ in masks]
+        # Transpose task_gt_bboxes
+        task_gt_bboxes = list(map(list, zip(*task_gt_bboxes)))
+        return heatmaps, anno_boxes, inds, masks, task_gt_bboxes
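To make the transposition described in the docstring concrete, here is a tiny standalone sketch (illustrative only, not part of the commit):

```python
import torch

# Per-sample outputs: two samples, each with heatmaps for two tasks.
per_sample = [[torch.zeros(1), torch.ones(1)],   # sample 0: [task0, task1]
              [torch.zeros(1), torch.ones(1)]]   # sample 1: [task0, task1]

# Transpose so each inner list holds one task across all samples ...
per_task = list(map(list, zip(*per_sample)))
# ... then stack each group into one tensor per task, batch dim first.
stacked = [torch.stack(t) for t in per_task]
print(len(stacked), stacked[0].shape)  # 2 torch.Size([2, 1])
```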
+    def get_targets_single(self,
+                           gt_instances_3d: InstanceData) -> Tuple[Tensor]:
+        """Generate training targets for a single sample.
+
+        Args:
+            gt_instances_3d (:obj:`InstanceData`): Gt_instances_3d of
+                single data sample. It usually includes
+                ``bboxes_3d`` and ``labels_3d`` attributes.
+
+        Returns:
+            tuple[list[torch.Tensor]]: Tuple of target including
+                the following results in order.
+
+            - list[torch.Tensor]: Heatmap scores.
+            - list[torch.Tensor]: Ground truth boxes.
+            - list[torch.Tensor]: Indexes indicating the position
+              of the valid boxes.
+            - list[torch.Tensor]: Masks indicating which boxes
+              are valid.
+        """
+        gt_labels_3d = gt_instances_3d.labels_3d
+        gt_bboxes_3d = gt_instances_3d.bboxes_3d
+        device = gt_labels_3d.device
+        gt_bboxes_3d = torch.cat(
+            (gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]),
+            dim=1).to(device)
+        max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg']
+        grid_size = torch.tensor(self.train_cfg['grid_size']).to(device)
+        pc_range = torch.tensor(self.train_cfg['point_cloud_range'])
+        voxel_size = torch.tensor(self.train_cfg['voxel_size'])
+
+        feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor']
+
+        # reorganize the gt_dict by tasks
+        task_masks = []
+        flag = 0
+        for class_name in self.class_names:
+            task_masks.append([
+                torch.where(gt_labels_3d == class_name.index(i) + flag)
+                for i in class_name
+            ])
+            flag += len(class_name)
+
+        task_boxes = []
+        task_classes = []
+        flag2 = 0
+        for idx, mask in enumerate(task_masks):
+            task_box = []
+            task_class = []
+            for m in mask:
+                task_box.append(gt_bboxes_3d[m])
+                # 0 is background for each task, so we need to add 1 here.
+                task_class.append(gt_labels_3d[m] + 1 - flag2)
+            task_boxes.append(torch.cat(task_box, axis=0).to(device))
+            task_classes.append(torch.cat(task_class).long().to(device))
+            flag2 += len(mask)
+        draw_gaussian = draw_heatmap_gaussian
+        heatmaps, anno_boxes, inds, masks = [], [], [], []
+
+        for idx, task_head in enumerate(self.task_heads):
+            heatmap = gt_bboxes_3d.new_zeros(
+                (len(self.class_names[idx]), feature_map_size[1],
+                 feature_map_size[0]))
+
+            anno_box = gt_bboxes_3d.new_zeros((max_objs, 8),
+                                              dtype=torch.float32)
+
+            ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64)
+            mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8)
+
+            num_objs = min(task_boxes[idx].shape[0], max_objs)
+
+            for k in range(num_objs):
+                cls_id = task_classes[idx][k] - 1
+
+                length = task_boxes[idx][k][3]
+                width = task_boxes[idx][k][4]
+                length = length / voxel_size[0] / self.train_cfg[
+                    'out_size_factor']
+                width = width / voxel_size[1] / self.train_cfg[
+                    'out_size_factor']
+
+                if width > 0 and length > 0:
+                    radius = gaussian_radius(
+                        (width, length),
+                        min_overlap=self.train_cfg['gaussian_overlap'])
+                    radius = max(self.train_cfg['min_radius'], int(radius))
+
+                    # be really careful for the coordinate system of
+                    # your box annotation.
+                    x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][
+                        1], task_boxes[idx][k][2]
+
+                    coor_x = (
+                        x - pc_range[0]
+                    ) / voxel_size[0] / self.train_cfg['out_size_factor']
+                    coor_y = (
+                        y - pc_range[1]
+                    ) / voxel_size[1] / self.train_cfg['out_size_factor']
+
+                    center = torch.tensor([coor_x, coor_y],
+                                          dtype=torch.float32,
+                                          device=device)
+                    center_int = center.to(torch.int32)
+
+                    # throw out not in range objects to avoid out of array
+                    # area when creating the heatmap
+                    if not (0 <= center_int[0] < feature_map_size[0]
+                            and 0 <= center_int[1] < feature_map_size[1]):
+                        continue
+
+                    draw_gaussian(heatmap[cls_id], center_int, radius)
+
+                    new_idx = k
+                    x, y = center_int[0], center_int[1]
+
+                    assert (y * feature_map_size[0] + x <
+                            feature_map_size[0] * feature_map_size[1])
+
+                    ind[new_idx] = y * feature_map_size[0] + x
+                    mask[new_idx] = 1
+
+                    # TODO: support other outdoor dataset
+                    rot = task_boxes[idx][k][6]
+                    box_dim = task_boxes[idx][k][3:6]
+                    if self.norm_bbox:
+                        box_dim = box_dim.log()
+                    anno_box[new_idx] = torch.cat([
+                        center - torch.tensor([x, y], device=device),
+                        z.unsqueeze(0), box_dim,
+                        torch.cos(rot).unsqueeze(0),
+                        torch.sin(rot).unsqueeze(0)
+                    ])
+
+            heatmaps.append(heatmap)
+            anno_boxes.append(anno_box)
+            masks.append(mask)
+            inds.append(ind)
+        return heatmaps, anno_boxes, inds, masks, task_boxes
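The `ind[new_idx] = y * feature_map_size[0] + x` line above flattens a 2D center onto the feature map so `_gather_feat` can later pick per-object predictions. A tiny sketch (illustrative only, not part of the commit):

```python
import torch

W, H = 4, 3          # feature map width and height
x, y = 2, 1          # integer center coordinates
ind = y * W + x      # flattened index: 6

feat = torch.arange(H * W).view(1, H * W, 1)  # (B, H*W, C) after permute/view
print(feat[0, ind, 0])  # tensor(6): the value at row y=1, column x=2
```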
     def loss_by_feat(self, preds_dicts: Tuple[List[dict]],
                      batch_gt_instances_3d: List[InstanceData], *args,
@@ -79,13 +402,72 @@ class DSVTCenterHead(CenterHead):
                 tasks head, and the internal list indicate different
                 FPN level.
             batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of
-                gt_instances. It usually includes ``bboxes_3d`` and\
+                gt_instances_3d. It usually includes ``bboxes_3d`` and
                 ``labels_3d`` attributes.

         Returns:
             dict[str,torch.Tensor]: Loss of heatmap and bbox of each task.
         """
-        pass
+        heatmaps, anno_boxes, inds, masks, task_gt_bboxes = self.get_targets(
+            batch_gt_instances_3d)
+        loss_dict = dict()
+        for task_id, preds_dict in enumerate(preds_dicts):
+            # heatmap focal loss
+            preds_dict[0]['heatmap'] = clip_sigmoid(preds_dict[0]['heatmap'])
+            num_pos = heatmaps[task_id].eq(1).float().sum().item()
+            loss_heatmap = self.loss_cls(
+                preds_dict[0]['heatmap'],
+                heatmaps[task_id],
+                avg_factor=max(num_pos, 1))
+            target_box = anno_boxes[task_id]
+            # reconstruct the anno_box from multiple reg heads
+            preds_dict[0]['anno_box'] = torch.cat(
+                (preds_dict[0]['reg'], preds_dict[0]['height'],
+                 preds_dict[0]['dim'], preds_dict[0]['rot']),
+                dim=1)
+
+            # Regression loss for dimension, offset, height, rotation
+            ind = inds[task_id]
+            num = masks[task_id].float().sum()
+            pred = preds_dict[0]['anno_box'].permute(0, 2, 3, 1).contiguous()
+            pred = pred.view(pred.size(0), -1, pred.size(3))
+            pred = self._gather_feat(pred, ind)
+            mask = masks[task_id].unsqueeze(2).expand_as(target_box).float()
+            isnotnan = (~torch.isnan(target_box)).float()
+            mask *= isnotnan
+
+            code_weights = self.train_cfg.get('code_weights', None)
+            bbox_weights = mask * mask.new_tensor(code_weights)
+            loss_bbox = self.loss_bbox(
+                pred, target_box, bbox_weights, avg_factor=(num + 1e-4))
+            loss_dict[f'task{task_id}.loss_heatmap'] = loss_heatmap
+            loss_dict[f'task{task_id}.loss_bbox'] = loss_bbox
+
+            if 'iou' in preds_dict[0]:
+                batch_box_preds = self._decode_all_preds(
+                    pred_dict=preds_dict[0],
+                    point_cloud_range=self.train_cfg['point_cloud_range'],
+                    voxel_size=self.train_cfg['voxel_size']
+                )  # (B, H, W, 7 or 9)
+                batch_box_preds_for_iou = batch_box_preds.permute(
+                    0, 3, 1, 2)  # (B, 7 or 9, H, W)
+                loss_dict[f'task{task_id}.loss_iou'] = self.calc_iou_loss(
+                    iou_preds=preds_dict[0]['iou'],
+                    batch_box_preds=batch_box_preds_for_iou.clone().detach(),
+                    mask=masks[task_id],
+                    ind=ind,
+                    gt_boxes=task_gt_bboxes[task_id])
+
+                if self.loss_iou_reg is not None:
+                    loss_dict[f'task{task_id}.loss_reg_iou'] = \
+                        self.calc_iou_reg_loss(
+                            batch_box_preds=batch_box_preds_for_iou,
+                            mask=masks[task_id],
+                            ind=ind,
+                            gt_boxes=task_gt_bboxes[task_id])
+        return loss_dict

     def predict(self,
                 pts_feats: Tuple[torch.Tensor],
@@ -158,6 +540,7 @@ class DSVTCenterHead(CenterHead):
         else:
             batch_dim = preds_dict[0]['dim']

+        # It's different from CenterHead
         batch_rotc = preds_dict[0]['rot'][:, 0].unsqueeze(1)
         batch_rots = preds_dict[0]['rot'][:, 1].unsqueeze(1)
         batch_iou = (preds_dict[0]['iou'] +
......
 # modified from https://github.com/Haiyang-W/DSVT
+import numpy as np
 import torch
 import torch.nn as nn
 import torch_scatter
@@ -76,6 +77,7 @@ class DynamicPillarVFE3D(nn.Module):
         self.voxel_x = voxel_size[0]
         self.voxel_y = voxel_size[1]
         self.voxel_z = voxel_size[2]
+        point_cloud_range = np.array(point_cloud_range).astype(np.float32)
         self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
         self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
         self.z_offset = self.voxel_z / 2 + point_cloud_range[2]
......
 # modified from https://github.com/Haiyang-W/DSVT
-import warnings
-from typing import Optional, Sequence, Tuple
+from typing import Sequence, Tuple

 from mmengine.model import BaseModule
 from torch import Tensor
@@ -78,8 +76,8 @@ class ResSECOND(BaseModule):
         out_channels (list[int]): Output channels for multi-scale feature maps.
         blocks_nums (list[int]): Number of blocks in each stage.
         layer_strides (list[int]): Strides of each stage.
-        norm_cfg (dict): Config dict of normalization layers.
-        conv_cfg (dict): Config dict of convolutional layers.
+        init_cfg (dict, optional): Config for weight initialization.
+            Defaults to None.
     """

     def __init__(self,
@@ -87,8 +85,7 @@ class ResSECOND(BaseModule):
                  out_channels: Sequence[int] = [128, 128, 256],
                  blocks_nums: Sequence[int] = [1, 2, 2],
                  layer_strides: Sequence[int] = [2, 2, 2],
-                 init_cfg: OptMultiConfig = None,
-                 pretrained: Optional[str] = None) -> None:
+                 init_cfg: OptMultiConfig = None) -> None:
         super(ResSECOND, self).__init__(init_cfg=init_cfg)
         assert len(layer_strides) == len(blocks_nums)
         assert len(out_channels) == len(blocks_nums)
@@ -108,14 +105,6 @@ class ResSECOND(BaseModule):
                     BasicResBlock(out_channels[i], out_channels[i]))
             blocks.append(nn.Sequential(*cur_layers))
         self.blocks = nn.Sequential(*blocks)
-        assert not (init_cfg and pretrained), \
-            'init_cfg and pretrained cannot be setting at the same time'
-        if isinstance(pretrained, str):
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
-                          'please use "init_cfg" instead')
-            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
-        else:
-            self.init_cfg = dict(type='Kaiming', layer='Conv2d')

     def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
         """Forward function.
......
from typing import List

import numpy as np
from mmcv.transforms import BaseTransform

from mmdet3d.registry import TRANSFORMS


@TRANSFORMS.register_module()
class ObjectRangeFilter3D(BaseTransform):
    """Filter objects by the range. It differs from `ObjectRangeFilter` by
    using `in_range_3d` instead of `in_range_bev`.

    Required Keys:

    - gt_bboxes_3d

    Modified Keys:

    - gt_bboxes_3d

    Args:
        point_cloud_range (list[float]): Point cloud range.
    """

    def __init__(self, point_cloud_range: List[float]) -> None:
        self.pcd_range = np.array(point_cloud_range, dtype=np.float32)

    def transform(self, input_dict: dict) -> dict:
        """Transform function to filter objects by the range.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
            keys are updated in the result dict.
        """
        gt_bboxes_3d = input_dict['gt_bboxes_3d']
        gt_labels_3d = input_dict['gt_labels_3d']
        mask = gt_bboxes_3d.in_range_3d(self.pcd_range)
        gt_bboxes_3d = gt_bboxes_3d[mask]
        # mask is a torch tensor but gt_labels_3d is still numpy array
        # using mask to index gt_labels_3d will cause bug when
        # len(gt_labels_3d) == 1, where mask=1 will be interpreted
        # as gt_labels_3d[1] and cause out of index error
        gt_labels_3d = gt_labels_3d[mask.numpy().astype(bool)]

        # limit rad to [-pi, pi]
        gt_bboxes_3d.limit_yaw(offset=0.5, period=2 * np.pi)

        input_dict['gt_bboxes_3d'] = gt_bboxes_3d
        input_dict['gt_labels_3d'] = gt_labels_3d

        return input_dict

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
        return repr_str


@TRANSFORMS.register_module()
class PointsRangeFilter3D(BaseTransform):
    """Filter points by the range. It differs from `PointsRangeFilter` by
    using `in_range_bev` instead of `in_range_3d`.

    Required Keys:

    - points
    - pts_instance_mask (optional)

    Modified Keys:

    - points
    - pts_instance_mask (optional)

    Args:
        point_cloud_range (list[float]): Point cloud range.
    """

    def __init__(self, point_cloud_range: List[float]) -> None:
        self.pcd_range = np.array(point_cloud_range, dtype=np.float32)

    def transform(self, input_dict: dict) -> dict:
        """Transform function to filter points by the range.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after filtering, 'points', 'pts_instance_mask'
            and 'pts_semantic_mask' keys are updated in the result dict.
        """
        points = input_dict['points']
        points_mask = points.in_range_bev(self.pcd_range[[0, 1, 3, 4]])
        clean_points = points[points_mask]
        input_dict['points'] = clean_points
        points_mask = points_mask.numpy()

        pts_instance_mask = input_dict.get('pts_instance_mask', None)
        pts_semantic_mask = input_dict.get('pts_semantic_mask', None)

        if pts_instance_mask is not None:
            input_dict['pts_instance_mask'] = pts_instance_mask[points_mask]

        if pts_semantic_mask is not None:
            input_dict['pts_semantic_mask'] = pts_semantic_mask[points_mask]

        return input_dict

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(point_cloud_range={self.pcd_range.tolist()})'
        return repr_str
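As a standalone illustration (not part of the commit) of why `PointsRangeFilter3D` uses the BEV check: only the x and y bounds are enforced, so points far above or below the z range are still kept.

```python
import numpy as np

pcd_range = np.array([-80., -80., -2., 80., 80., 4.], dtype=np.float32)
# The BEV check uses only (x_min, y_min, x_max, y_max); z is ignored.
bev_range = pcd_range[[0, 1, 3, 4]]

points = np.array([[0., 0., 10.],     # far above the z range, kept in BEV
                   [100., 0., 0.]])   # outside the x range, dropped
mask = ((points[:, 0] > bev_range[0]) & (points[:, 1] > bev_range[1])
        & (points[:, 0] < bev_range[2]) & (points[:, 1] < bev_range[3]))
print(mask)  # [ True False]
```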
@@ -3,10 +3,11 @@ from typing import Dict, List, Optional
 import numpy as np
 import torch
 import torch.nn as nn
+from mmdet.models.losses.utils import weighted_loss
 from torch import Tensor

 from mmdet3d.models.task_modules import CenterPointBBoxCoder
-from mmdet3d.registry import TASK_UTILS
+from mmdet3d.registry import MODELS, TASK_UTILS
 from .ops.ingroup_inds.ingroup_inds_op import ingroup_inds

 get_inner_win_inds_cuda = ingroup_inds
@@ -266,7 +267,7 @@ class DSVTBBoxCoder(CenterPointBBoxCoder):
             thresh_mask = final_scores > self.score_threshold
         if self.post_center_range is not None:
-            self.post_center_range = torch.tensor(
+            self.post_center_range = torch.as_tensor(
                 self.post_center_range, device=heat.device)
             mask = (final_box_preds[..., :3] >=
                     self.post_center_range[:3]).all(2)
@@ -298,3 +299,142 @@ class DSVTBBoxCoder(CenterPointBBoxCoder):
                 'support post_center_range is not None for now!')
         return predictions_dicts
+
+def center_to_corner2d(center, dim):
+    corners_norm = torch.tensor(
+        [[-0.5, -0.5], [-0.5, 0.5], [0.5, 0.5], [0.5, -0.5]],
+        device=dim.device).type_as(center)  # (4, 2)
+    corners = dim.view([-1, 1, 2]) * corners_norm.view([1, 4, 2])  # (N, 4, 2)
+    corners = corners + center.view(-1, 1, 2)
+    return corners
+
+
+@weighted_loss
+def diou3d_loss(pred_boxes, gt_boxes, eps: float = 1e-7):
+    """Calculate the Distance-IoU loss for 3D boxes.
+
+    Modified from
+    https://github.com/agent-sgs/PillarNet/blob/master/det3d/core/utils/center_utils.py  # noqa
+
+    Args:
+        pred_boxes (Tensor): Predicted boxes with shape (N, 7).
+        gt_boxes (Tensor): Ground-truth boxes with shape (N, 7).
+
+    Returns:
+        Tensor: Distance-IoU Loss.
+    """
+    assert pred_boxes.shape[0] == gt_boxes.shape[0]
+
+    qcorners = center_to_corner2d(pred_boxes[:, :2],
+                                  pred_boxes[:, 3:5])  # (N, 4, 2)
+    gcorners = center_to_corner2d(gt_boxes[:, :2],
+                                  gt_boxes[:, 3:5])  # (N, 4, 2)
+
+    inter_max_xy = torch.minimum(qcorners[:, 2], gcorners[:, 2])
+    inter_min_xy = torch.maximum(qcorners[:, 0], gcorners[:, 0])
+    out_max_xy = torch.maximum(qcorners[:, 2], gcorners[:, 2])
+    out_min_xy = torch.minimum(qcorners[:, 0], gcorners[:, 0])
+
+    # calculate area
+    volume_pred_boxes = pred_boxes[:, 3] * pred_boxes[:, 4] * pred_boxes[:, 5]
+    volume_gt_boxes = gt_boxes[:, 3] * gt_boxes[:, 4] * gt_boxes[:, 5]
+
+    inter_h = torch.minimum(
+        pred_boxes[:, 2] + 0.5 * pred_boxes[:, 5],
+        gt_boxes[:, 2] + 0.5 * gt_boxes[:, 5]) - torch.maximum(
+            pred_boxes[:, 2] - 0.5 * pred_boxes[:, 5],
+            gt_boxes[:, 2] - 0.5 * gt_boxes[:, 5])
+    inter_h = torch.clamp(inter_h, min=0)
+
+    inter = torch.clamp((inter_max_xy - inter_min_xy), min=0)
+    volume_inter = inter[:, 0] * inter[:, 1] * inter_h
+    volume_union = volume_gt_boxes + volume_pred_boxes - volume_inter + eps
+
+    # boxes_iou3d_gpu(pred_boxes, gt_boxes)
+    inter_diag = torch.pow(gt_boxes[:, 0:3] - pred_boxes[:, 0:3], 2).sum(-1)
+
+    outer_h = torch.maximum(
+        gt_boxes[:, 2] + 0.5 * gt_boxes[:, 5],
+        pred_boxes[:, 2] + 0.5 * pred_boxes[:, 5]) - torch.minimum(
+            gt_boxes[:, 2] - 0.5 * gt_boxes[:, 5],
+            pred_boxes[:, 2] - 0.5 * pred_boxes[:, 5])
+    outer_h = torch.clamp(outer_h, min=0)
+    outer = torch.clamp((out_max_xy - out_min_xy), min=0)
+    outer_diag = outer[:, 0]**2 + outer[:, 1]**2 + outer_h**2 + eps
+
+    dious = volume_inter / volume_union - inter_diag / outer_diag
+    dious = torch.clamp(dious, min=-1.0, max=1.0)
+
+    loss = 1 - dious
+
+    return loss
+
+@MODELS.register_module()
+class DIoU3DLoss(nn.Module):
+    r"""3D bboxes Implementation of `Distance-IoU Loss: Faster and Better
+    Learning for Bounding Box Regression <https://arxiv.org/abs/1911.08287>`_.
+
+    Code is modified from https://github.com/Zzh-tju/DIoU.
+
+    Args:
+        eps (float): Epsilon to avoid log(0). Defaults to 1e-6.
+        reduction (str): Options are "none", "mean" and "sum".
+            Defaults to "mean".
+        loss_weight (float): Weight of loss. Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 eps: float = 1e-6,
+                 reduction: str = 'mean',
+                 loss_weight: float = 1.0) -> None:
+        super().__init__()
+        self.eps = eps
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+
+    def forward(self,
+                pred: Tensor,
+                target: Tensor,
+                weight: Optional[Tensor] = None,
+                avg_factor: Optional[int] = None,
+                reduction_override: Optional[str] = None,
+                **kwargs) -> Tensor:
+        """Forward function.
+
+        Args:
+            pred (Tensor): Predicted 3D boxes with shape (n, 7).
+            target (Tensor): The learning target of the prediction,
+                with shape (n, 7).
+            weight (Optional[Tensor], optional): The weight of loss for each
+                prediction. Defaults to None.
+            avg_factor (Optional[int], optional): Average factor that is used
+                to average the loss. Defaults to None.
+            reduction_override (Optional[str], optional): The reduction method
+                used to override the original reduction method of the loss.
+                Defaults to None. Options are "none", "mean" and "sum".
+
+        Returns:
+            Tensor: Loss tensor.
+        """
+        if weight is not None and not torch.any(weight > 0):
+            if pred.dim() == weight.dim() + 1:
+                weight = weight.unsqueeze(1)
+            return (pred * weight).sum()  # 0
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+
+        if weight is not None and weight.dim() > 1:
+            # TODO: remove this in the future
+            # reduce the weight of shape (n, 4) to (n,) to match the
+            # giou_loss of shape (n,)
+            assert weight.shape == pred.shape
+            weight = weight.mean(-1)
+
+        loss = self.loss_weight * diou3d_loss(
+            pred,
+            target,
+            weight,
+            eps=self.eps,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            **kwargs)
+
+        return loss
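A hypothetical usage sketch (not from the commit; it assumes the `DIoU3DLoss` class defined above is importable, mirroring how the config builds it as `loss_reg_iou` via `MODELS.build`). Boxes are (x, y, z, dx, dy, dz, yaw); note that the 2D corner computation in `diou3d_loss` ignores yaw.

```python
import torch

loss_fn = DIoU3DLoss(reduction='mean', loss_weight=2.0)

pred = torch.tensor([[0., 0., 0., 4., 2., 1.5, 0.]])
target = torch.tensor([[0.5, 0., 0., 4., 2., 1.5, 0.]])  # shifted 0.5 m in x

print(loss_fn(pred, pred.clone()))  # ~0 for a perfect match
print(loss_fn(pred, target))        # small positive value (~0.46)
```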
@@ -21,6 +21,12 @@ def parse_args():
         action='store_true',
         default=False,
         help='enable automatic-mixed-precision training')
+    parser.add_argument(
+        '--sync_bn',
+        choices=['none', 'torch', 'mmcv'],
+        default='none',
+        help='convert all BatchNorm layers in the model to SyncBatchNorm '
+        '(SyncBN) or mmcv.ops.sync_bn.SyncBatchNorm (MMSyncBN) layers.')
     parser.add_argument(
         '--auto-scale-lr',
         action='store_true',
@@ -98,6 +104,10 @@ def main():
         cfg.optim_wrapper.type = 'AmpOptimWrapper'
         cfg.optim_wrapper.loss_scale = 'dynamic'

+    # convert BatchNorm layers
+    if args.sync_bn != 'none':
+        cfg.sync_bn = args.sync_bn
+
     # enable automatically scaling LR
     if args.auto_scale_lr:
         if 'auto_scale_lr' in cfg and \
......
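For context (an assumption about mmengine's behavior, not shown in this diff): with `cfg.sync_bn = 'torch'`, the runner's model wrapping is expected to amount to the standard PyTorch conversion.

```python
import torch.nn as nn

def convert_sync_bn(model: nn.Module) -> nn.Module:
    # Replaces every BatchNorm*d module with SyncBatchNorm so that
    # batch statistics are synchronized across distributed workers.
    return nn.SyncBatchNorm.convert_sync_batchnorm(model)
```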