Unverified commit 6607f2a7, authored by Jingwei Zhang and committed by GitHub

[Fix] Fix ImVoxelNet and KittiMetric & refactor configs of ImVoxelNet (#1843)

* ImVoxelNet inherits from Base3DDetector

* fix typo sample_id

* replace convert_to_datasample with add_pred_to_datasample

* refactor imvoxelnet config
parent cc9eedad
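For context, add_pred_to_datasample (formerly convert_to_datasample) only attaches each predicted InstanceData to its Det3DDataSample. A minimal sketch of that behaviour, assuming plain same-length lists (the shared Base3DDetector version additionally handles 2D instances):

# Sketch only, not the library implementation; argument names are illustrative.
def add_pred_to_datasample(data_samples, data_instances_3d):
    for data_sample, pred_instances_3d in zip(data_samples, data_instances_3d):
        # pred_instances_3d typically carries bboxes_3d, scores_3d and labels_3d.
        data_sample.pred_instances_3d = pred_instances_3d
    return data_samples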
+_base_ = [
+    '../_base_/schedules/mmdet-schedule-1x.py', '../_base_/default_runtime.py'
+]
 model = dict(
     type='ImVoxelNet',
     data_preprocessor=dict(
@@ -151,7 +155,8 @@ test_evaluator = val_evaluator
 # optimizer
 optim_wrapper = dict(
     type='OptimWrapper',
-    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),
+    optimizer=dict(
+        _delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001),
     paramwise_cfg=dict(
         custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}),
     clip_grad=dict(max_norm=35., norm_type=2))
@@ -166,30 +171,7 @@ param_scheduler = [
 ]
 # hooks
-default_hooks = dict(
-    timer=dict(type='IterTimerHook'),
-    logger=dict(type='LoggerHook', interval=50),
-    param_scheduler=dict(type='ParamSchedulerHook'),
-    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
-    sampler_seed=dict(type='DistSamplerSeedHook'),
-)
-# training schedule for 2x
-train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
-val_cfg = dict(type='ValLoop')
-test_cfg = dict(type='TestLoop')
+default_hooks = dict(checkpoint=dict(type='CheckpointHook', max_keep_ckpts=1))
 # runtime
-default_scope = 'mmdet3d'
-env_cfg = dict(
-    cudnn_benchmark=False,
-    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
-    dist_cfg=dict(backend='nccl'),
-)
-log_level = 'INFO'
-load_from = None
-resume = False
-dist_params = dict(backend='nccl')
 find_unused_parameters = True  # only 1 of 4 FPN outputs is used
@@ -575,7 +575,7 @@ class KittiMetric(BaseMetric):
         box_preds = box_dict['bboxes_3d']
         scores = box_dict['scores_3d']
         labels = box_dict['labels_3d']
-        sample_idx = info['sample_id']
+        sample_idx = info['sample_idx']
         box_preds.limit_yaw(offset=0.5, period=np.pi * 2)
         if len(box_preds) == 0:
@@ -89,7 +89,7 @@ class Base3DDetector(BaseDetector):
         raise RuntimeError(f'Invalid mode "{mode}". '
                            'Only supports loss, predict and tensor mode')

-    def convert_to_datasample(
+    def add_pred_to_datasample(
         self,
         data_samples: SampleList,
         data_instances_3d: OptInstanceList = None,
@@ -95,6 +95,7 @@ class FCOSMono3D(SingleStageMono3DDetector):
         x = self.extract_feat(batch_inputs_dict)
         results_list, results_list_2d = self.bbox_head.predict(
             x, batch_data_samples, rescale=rescale)
-        predictions = self.convert_to_datasample(batch_data_samples,
-                                                 results_list, results_list_2d)
+        predictions = self.add_pred_to_datasample(batch_data_samples,
+                                                  results_list,
+                                                  results_list_2d)
         return predictions
@@ -82,6 +82,6 @@ class GroupFree3DNet(SingleStage3DDetector):
         points = batch_inputs_dict['points']
         results_list = self.bbox_head.predict(points, x, batch_data_samples,
                                               **kwargs)
-        predictions = self.convert_to_datasample(batch_data_samples,
-                                                 results_list)
+        predictions = self.add_pred_to_datasample(batch_data_samples,
+                                                  results_list)
         return predictions
@@ -154,4 +154,4 @@ class H3DNet(TwoStage3DDetector):
             feats_dict,
             batch_data_samples,
             suffix='_optimized')
-        return self.convert_to_datasample(batch_data_samples, results_list)
+        return self.add_pred_to_datasample(batch_data_samples, results_list)
@@ -433,7 +433,7 @@ class ImVoteNet(Base3DDetector):
         if points is None:
             assert imgs is not None
             results_2d = self.predict_img_only(imgs, batch_data_samples)
-            return self.convert_to_datasample(
+            return self.add_pred_to_datasample(
                 batch_data_samples, data_instances_2d=results_2d)
         else:
@@ -488,7 +488,7 @@ class ImVoteNet(Base3DDetector):
                 batch_data_samples,
                 rescale=True)
-            return self.convert_to_datasample(batch_data_samples, results_3d)
+            return self.add_pred_to_datasample(batch_data_samples, results_3d)

     def predict_img_only(self,
                          imgs: Tensor,
@@ -3,15 +3,15 @@ from typing import List, Tuple, Union

 import torch

+from mmdet3d.models.detectors import Base3DDetector
 from mmdet3d.models.layers.fusion_layers.point_fusion import point_sample
 from mmdet3d.registry import MODELS, TASK_UTILS
 from mmdet3d.structures.det3d_data_sample import SampleList
-from mmdet3d.utils import ConfigType, InstanceList, OptConfigType
-from mmdet.models.detectors import BaseDetector
+from mmdet3d.utils import ConfigType, OptConfigType

 @MODELS.register_module()
-class ImVoxelNet(BaseDetector):
+class ImVoxelNet(Base3DDetector):
     r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.

     Args:
@@ -57,31 +57,6 @@ class ImVoxelNet(BaseDetector):
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg

-    def convert_to_datasample(self, data_samples: SampleList,
-                              data_instances: InstanceList) -> SampleList:
-        """Convert results list to `Det3DDataSample`.
-
-        Args:
-            inputs (list[:obj:`Det3DDataSample`]): The input data.
-            data_instances (list[:obj:`InstanceData`]): 3D Detection
-                results of each image.
-
-        Returns:
-            list[:obj:`Det3DDataSample`]: 3D Detection results of the
-            input images. Each Det3DDataSample usually contain
-            'pred_instances_3d'. And the ``pred_instances_3d`` usually
-            contains following keys.
-
-            - scores_3d (Tensor): Classification scores, has a shape
-              (num_instance, )
-            - labels_3d (Tensor): Labels of bboxes, has a shape
-              (num_instances, ).
-            - bboxes_3d (Tensor): Contains a tensor with shape
-              (num_instances, C) where C >=7.
-        """
-        for data_sample, pred_instances_3d in zip(data_samples,
-                                                  data_instances):
-            data_sample.pred_instances_3d = pred_instances_3d
-        return data_samples
-
     def extract_feat(self, batch_inputs_dict: dict,
                      batch_data_samples: SampleList):
         """Extract 3d features from the backbone -> fpn -> 3d projection.
@@ -185,8 +160,8 @@ class ImVoxelNet(BaseDetector):
         """
         x = self.extract_feat(batch_inputs_dict, batch_data_samples)
         results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
-        predictions = self.convert_to_datasample(batch_data_samples,
-                                                 results_list)
+        predictions = self.add_pred_to_datasample(batch_data_samples,
+                                                  results_list)
         return predictions

     def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
@@ -401,7 +401,7 @@ class MVXTwoStageDetector(Base3DDetector):
         else:
             results_list_2d = None

-        detsamples = self.convert_to_datasample(batch_data_samples,
-                                                results_list_3d,
-                                                results_list_2d)
+        detsamples = self.add_pred_to_datasample(batch_data_samples,
+                                                 results_list_3d,
+                                                 results_list_2d)
         return detsamples
@@ -108,8 +108,8 @@ class SingleStage3DDetector(Base3DDetector):
         """
         x = self.extract_feat(batch_inputs_dict)
         results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
-        predictions = self.convert_to_datasample(batch_data_samples,
-                                                 results_list)
+        predictions = self.add_pred_to_datasample(batch_data_samples,
+                                                  results_list)
         return predictions

     def _forward(self,
@@ -18,7 +18,7 @@ class SingleStageMono3DDetector(SingleStageDetector):
     boxes on the output features of the backbone+neck.
     """

-    def convert_to_datasample(
+    def add_pred_to_datasample(
         self,
         data_samples: SampleList,
         data_instances_3d: OptInstanceList = None,
@@ -161,8 +161,8 @@ class TwoStage3DDetector(Base3DDetector):
                                              batch_data_samples)
         # connvert to Det3DDataSample
-        results_list = self.convert_to_datasample(batch_data_samples,
-                                                  results_list)
+        results_list = self.add_pred_to_datasample(batch_data_samples,
+                                                   results_list)
         return results_list
@@ -99,8 +99,8 @@ class VoteNet(SingleStage3DDetector):
         points = batch_inputs_dict['points']
         results_list = self.bbox_head.predict(points, feats_dict,
                                               batch_data_samples, **kwargs)
-        data_3d_samples = self.convert_to_datasample(batch_data_samples,
-                                                     results_list)
+        data_3d_samples = self.add_pred_to_datasample(batch_data_samples,
+                                                      results_list)
         return data_3d_samples

     def aug_test(self, aug_inputs_list: List[dict],
@@ -143,6 +143,6 @@ class VoteNet(SingleStage3DDetector):
                                              self.bbox_head.test_cfg)
         merged_results = InstanceData(**merged_results_dict)
-        data_3d_samples = self.convert_to_datasample(batch_data_samples,
-                                                     [merged_results])
+        data_3d_samples = self.add_pred_to_datasample(batch_data_samples,
+                                                      [merged_results])
         return data_3d_samples