Unverified Commit f7356f4b authored by twang's avatar twang Committed by GitHub
Browse files

[Feature] Support FCOS3D head (#442)

* Support base mono3d dense head and anchor free mono3d head

* Support FCOS3D head

* Support FCOS3D baseline on nuScenes

* Fix an import error caused by update of mmcv/mmdet

* Change img_scale to scale_factor in the MultiScaleFlipAug in the config

* Add pred_bbox2d in the params of anchor_free_mono3d_head

* Add unit test for fcos3d head

* Fix a minor bug when setting img_metas in the unit test

* Add unit test for fcos3d detector

* Simplify the logic of weights initialization

* Add comments to specify the reason of cloning features

* Update head config
parent a0090aa1
dataset_type = 'NuScenesMonoDataset'
data_root = 'data/nuscenes/'
class_names = [
'car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'
]
# Input modality for nuScenes dataset, this is consistent with the submission
# format which requires the information in input_modality.
input_modality = dict(
use_lidar=False,
use_camera=True,
use_radar=False,
use_map=False,
use_external=False)
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFileMono3D'),
dict(
type='LoadAnnotations3D',
with_bbox=True,
with_label=True,
with_attr_label=True,
with_bbox_3d=True,
with_label_3d=True,
with_bbox_depth=True),
dict(type='Resize', img_scale=(1600, 900), keep_ratio=True),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
keys=[
'img', 'gt_bboxes', 'gt_labels', 'attr_labels', 'gt_bboxes_3d',
'gt_labels_3d', 'centers2d', 'depths'
]),
]
test_pipeline = [
dict(type='LoadImageFromFileMono3D'),
dict(
type='MultiScaleFlipAug',
scale_factor=1.0,
flip=False,
transforms=[
dict(type='RandomFlip3D'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_train_mono3d.coco.json',
img_prefix=data_root,
classes=class_names,
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
box_type_3d='Camera'),
val=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_val_mono3d.coco.json',
img_prefix=data_root,
classes=class_names,
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
box_type_3d='Camera'),
test=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_val_mono3d.coco.json',
img_prefix=data_root,
classes=class_names,
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
box_type_3d='Camera'))
evaluation = dict(interval=2)
model = dict(
type='FCOSMono3D',
pretrained='open-mmlab://detectron2/resnet101_caffe',
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs=True,
extra_convs_on_inputs=False, # use P5
num_outs=5,
relu_before_extra_convs=True),
bbox_head=dict(
type='FCOSMono3DHead',
num_classes=10,
in_channels=256,
stacked_convs=2,
feat_channels=256,
use_direction_classifier=True,
diff_rad_by_sin=True,
pred_attrs=True,
pred_velo=True,
dir_offset=0.7854, # pi/4
strides=[8, 16, 32, 64, 128],
group_reg_dims=(2, 1, 3, 1, 2), # offset, depth, size, rot, velo
cls_branch=(256, ),
reg_branch=(
(256, ), # offset
(256, ), # depth
(256, ), # size
(256, ), # rot
() # velo
),
dir_branch=(256, ),
attr_branch=(256, ),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_attr=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
norm_on_bbox=True,
centerness_on_reg=True,
center_sampling=True,
conv_bias=True,
dcn_on_last_conv=True),
train_cfg=dict(
allowed_border=0,
code_weight=[1.0, 1.0, 0.2, 1.0, 1.0, 1.0, 1.0, 0.05, 0.05],
pos_weight=-1,
debug=False),
test_cfg=dict(
use_rotate_nms=True,
nms_across_levels=False,
nms_pre=1000,
nms_thr=0.8,
score_thr=0.05,
min_bbox_size=0,
max_per_img=200))
_base_ = [
'../_base_/datasets/nus-mono3d.py', '../_base_/models/fcos3d.py',
'../_base_/schedules/mmdet_schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
backbone=dict(
dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, False, True, True)))
class_names = [
'car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle',
'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'
]
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFileMono3D'),
dict(
type='LoadAnnotations3D',
with_bbox=True,
with_label=True,
with_attr_label=True,
with_bbox_3d=True,
with_label_3d=True,
with_bbox_depth=True),
dict(type='Resize', img_scale=(1600, 900), keep_ratio=True),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
keys=[
'img', 'gt_bboxes', 'gt_labels', 'attr_labels', 'gt_bboxes_3d',
'gt_labels_3d', 'centers2d', 'depths'
]),
]
test_pipeline = [
dict(type='LoadImageFromFileMono3D'),
dict(
type='MultiScaleFlipAug',
scale_factor=1.0,
flip=False,
transforms=[
dict(type='RandomFlip3D'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
# optimizer
optimizer = dict(
lr=0.002, paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.))
optimizer_config = dict(
_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[8, 11])
total_epochs = 12
evaluation = dict(interval=2)
from .anchor3d_head import Anchor3DHead
from .anchor_free_mono3d_head import AnchorFreeMono3DHead
from .base_conv_bbox_head import BaseConvBboxHead
from .base_mono3d_dense_head import BaseMono3DDenseHead
from .centerpoint_head import CenterHead
from .fcos_mono3d_head import FCOSMono3DHead
from .free_anchor3d_head import FreeAnchor3DHead
from .parta2_rpn_head import PartA2RPNHead
from .shape_aware_head import ShapeAwareHead
......@@ -9,5 +12,6 @@ from .vote_head import VoteHead
__all__ = [
'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead',
'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead'
'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead',
'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead'
]
This diff is collapsed.
from abc import ABCMeta, abstractmethod
from torch import nn as nn
class BaseMono3DDenseHead(nn.Module, metaclass=ABCMeta):
"""Base class for Monocular 3D DenseHeads."""
def __init__(self):
super(BaseMono3DDenseHead, self).__init__()
@abstractmethod
def loss(self, **kwargs):
"""Compute losses of the head."""
pass
@abstractmethod
def get_bboxes(self, **kwargs):
"""Transform network output for a batch into bbox predictions."""
pass
def forward_train(self,
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
centers2d=None,
depths=None,
attr_labels=None,
gt_bboxes_ignore=None,
proposal_cfg=None,
**kwargs):
"""
Args:
x (list[Tensor]): Features from FPN.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
shape (num_gts, 4).
gt_labels (list[Tensor]): Ground truth labels of each box,
shape (num_gts,).
gt_bboxes_3d (list[Tensor]): 3D ground truth bboxes of the image,
shape (num_gts, self.bbox_code_size).
gt_labels_3d (list[Tensor]): 3D ground truth labels of each box,
shape (num_gts,).
centers2d (list[Tensor]): Projected 3D center of each box,
shape (num_gts, 2).
depths (list[Tensor]): Depth of projected 3D center of each box,
shape (num_gts,).
attr_labels (list[Tensor]): Attribute labels of each box,
shape (num_gts,).
gt_bboxes_ignore (list[Tensor]): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
Returns:
tuple:
losses: (dict[str, Tensor]): A dictionary of loss components.
proposal_list (list[Tensor]): Proposals of each image.
"""
outs = self(x)
if gt_labels is None:
loss_inputs = outs + (gt_bboxes, gt_bboxes_3d, centers2d, depths,
attr_labels, img_metas)
else:
loss_inputs = outs + (gt_bboxes, gt_labels, gt_bboxes_3d,
gt_labels_3d, centers2d, depths, attr_labels,
img_metas)
losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
if proposal_cfg is None:
return losses
else:
proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
return losses, proposal_list
This diff is collapsed.
......@@ -5,7 +5,8 @@ import random
import torch
from os.path import dirname, exists, join
from mmdet3d.core.bbox import DepthInstance3DBoxes, LiDARInstance3DBoxes
from mmdet3d.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
LiDARInstance3DBoxes)
from mmdet3d.models.builder import build_detector
......@@ -316,3 +317,56 @@ def test_centerpoint():
assert boxes_3d_0.tensor.shape[1] == 9
assert scores_3d_0.shape[0] >= 0
assert labels_3d_0.shape[0] >= 0
def test_fcos3d():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
_setup_seed(0)
fcos3d_cfg = _get_detector_cfg(
'fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py')
self = build_detector(fcos3d_cfg).cuda()
imgs = torch.rand([1, 3, 928, 1600], dtype=torch.float32).cuda()
gt_bboxes = [torch.rand([3, 4], dtype=torch.float32).cuda()]
gt_bboxes_3d = CameraInstance3DBoxes(
torch.rand([3, 9], device='cuda'), box_dim=9)
gt_labels = [torch.randint(0, 10, [3], device='cuda')]
gt_labels_3d = gt_labels
centers2d = [torch.rand([3, 2], dtype=torch.float32).cuda()]
depths = [torch.rand([3], dtype=torch.float32).cuda()]
attr_labels = [torch.randint(0, 9, [3], device='cuda')]
img_metas = [
dict(
cam_intrinsic=[[1260.8474446004698, 0.0, 807.968244525554],
[0.0, 1260.8474446004698, 495.3344268742088],
[0.0, 0.0, 1.0]],
scale_factor=np.array([1., 1., 1., 1.], dtype=np.float32),
box_type_3d=CameraInstance3DBoxes)
]
# test forward_train
losses = self.forward_train(imgs, img_metas, gt_bboxes, gt_labels,
gt_bboxes_3d, gt_labels_3d, centers2d, depths,
attr_labels)
assert losses['loss_cls'] >= 0
assert losses['loss_offset'] >= 0
assert losses['loss_depth'] >= 0
assert losses['loss_size'] >= 0
assert losses['loss_rotsin'] >= 0
assert losses['loss_centerness'] >= 0
assert losses['loss_velo'] >= 0
assert losses['loss_dir'] >= 0
assert losses['loss_attr'] >= 0
# test simple_test
results = self.simple_test(imgs, img_metas)
boxes_3d = results[0]['img_bbox']['boxes_3d']
scores_3d = results[0]['img_bbox']['scores_3d']
labels_3d = results[0]['img_bbox']['labels_3d']
attrs_3d = results[0]['img_bbox']['attrs_3d']
assert boxes_3d.tensor.shape[0] >= 0
assert boxes_3d.tensor.shape[1] == 9
assert scores_3d.shape[0] >= 0
assert labels_3d.shape[0] >= 0
assert attrs_3d.shape[0] >= 0
......@@ -5,8 +5,8 @@ import random
import torch
from os.path import dirname, exists, join
from mmdet3d.core.bbox import (Box3DMode, DepthInstance3DBoxes,
LiDARInstance3DBoxes)
from mmdet3d.core.bbox import (Box3DMode, CameraInstance3DBoxes,
DepthInstance3DBoxes, LiDARInstance3DBoxes)
from mmdet3d.models.builder import build_head
from mmdet.apis import set_random_seed
......@@ -1044,3 +1044,73 @@ def test_shape_aware_head_getboxes():
input_metas)
assert len(result_list[0][1]) > 0 # ensure not all boxes are filtered
assert (result_list[0][1] > 0.3).all()
def test_fcos_mono3d_head():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
_setup_seed(0)
fcos3d_head_cfg = _get_head_cfg(
'fcos3d/fcos3d_r101_caffe_fpn_gn-head_dcn_2x8_1x_nus-mono3d.py')
self = build_head(fcos3d_head_cfg).cuda()
feats = [
torch.rand([2, 256, 116, 200], dtype=torch.float32).cuda(),
torch.rand([2, 256, 58, 100], dtype=torch.float32).cuda(),
torch.rand([2, 256, 29, 50], dtype=torch.float32).cuda(),
torch.rand([2, 256, 15, 25], dtype=torch.float32).cuda(),
torch.rand([2, 256, 8, 13], dtype=torch.float32).cuda()
]
# test forward
ret_dict = self(feats)
assert len(ret_dict) == 5
assert len(ret_dict[0]) == 5
assert ret_dict[0][0].shape == torch.Size([2, 10, 116, 200])
# test loss
gt_bboxes = [
torch.rand([3, 4], dtype=torch.float32).cuda(),
torch.rand([3, 4], dtype=torch.float32).cuda()
]
gt_bboxes_3d = CameraInstance3DBoxes(
torch.rand([3, 9], device='cuda'), box_dim=9)
gt_labels = [torch.randint(0, 10, [3], device='cuda') for i in range(2)]
gt_labels_3d = gt_labels
centers2d = [
torch.rand([3, 2], dtype=torch.float32).cuda(),
torch.rand([3, 2], dtype=torch.float32).cuda()
]
depths = [
torch.rand([3], dtype=torch.float32).cuda(),
torch.rand([3], dtype=torch.float32).cuda()
]
attr_labels = [torch.randint(0, 9, [3], device='cuda') for i in range(2)]
img_metas = [
dict(
cam_intrinsic=[[1260.8474446004698, 0.0, 807.968244525554],
[0.0, 1260.8474446004698, 495.3344268742088],
[0.0, 0.0, 1.0]],
scale_factor=np.array([1., 1., 1., 1.], dtype=np.float32),
box_type_3d=CameraInstance3DBoxes) for i in range(2)
]
losses = self.loss(*ret_dict, gt_bboxes, gt_labels, gt_bboxes_3d,
gt_labels_3d, centers2d, depths, attr_labels, img_metas)
assert losses['loss_cls'] >= 0
assert losses['loss_offset'] >= 0
assert losses['loss_depth'] >= 0
assert losses['loss_size'] >= 0
assert losses['loss_rotsin'] >= 0
assert losses['loss_centerness'] >= 0
assert losses['loss_velo'] >= 0
assert losses['loss_dir'] >= 0
assert losses['loss_attr'] >= 0
# test get_boxes
results = self.get_bboxes(*ret_dict, img_metas)
assert len(results) == 2
assert len(results[0]) == 4
assert results[0][0].tensor.shape == torch.Size([200, 9])
assert results[0][1].shape == torch.Size([200])
assert results[0][2].shape == torch.Size([200])
assert results[0][3].shape == torch.Size([200])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment