Unverified Commit 6607f2a7 authored by Jingwei Zhang's avatar Jingwei Zhang Committed by GitHub
Browse files

[Fix] Fix ImVoxelNet and KittiMetric & refactor configs of ImVoxelNet (#1843)

* ImVoxelNet inherits from Base3DDetector

* fix typo sample_id

* replace convert_to_datasample with add_pred_to_datasample

* refactor imvoxelnet config
parent cc9eedad
_base_ = [
'../_base_/schedules/mmdet-schedule-1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='ImVoxelNet',
data_preprocessor=dict(
......@@ -151,7 +155,8 @@ test_evaluator = val_evaluator
# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),
optimizer=dict(
_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}),
clip_grad=dict(max_norm=35., norm_type=2))
......@@ -166,30 +171,7 @@ param_scheduler = [
]
# hooks
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
sampler_seed=dict(type='DistSamplerSeedHook'),
)
# training schedule for 2x
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(checkpoint=dict(type='CheckpointHook', max_keep_ckpts=1))
# runtime
default_scope = 'mmdet3d'
env_cfg = dict(
cudnn_benchmark=False,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
log_level = 'INFO'
load_from = None
resume = False
dist_params = dict(backend='nccl')
find_unused_parameters = True # only 1 of 4 FPN outputs is used
......@@ -575,7 +575,7 @@ class KittiMetric(BaseMetric):
box_preds = box_dict['bboxes_3d']
scores = box_dict['scores_3d']
labels = box_dict['labels_3d']
sample_idx = info['sample_id']
sample_idx = info['sample_idx']
box_preds.limit_yaw(offset=0.5, period=np.pi * 2)
if len(box_preds) == 0:
......
......@@ -89,7 +89,7 @@ class Base3DDetector(BaseDetector):
raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode')
def convert_to_datasample(
def add_pred_to_datasample(
self,
data_samples: SampleList,
data_instances_3d: OptInstanceList = None,
......
......@@ -95,6 +95,7 @@ class FCOSMono3D(SingleStageMono3DDetector):
x = self.extract_feat(batch_inputs_dict)
results_list, results_list_2d = self.bbox_head.predict(
x, batch_data_samples, rescale=rescale)
predictions = self.convert_to_datasample(batch_data_samples,
results_list, results_list_2d)
predictions = self.add_pred_to_datasample(batch_data_samples,
results_list,
results_list_2d)
return predictions
......@@ -82,6 +82,6 @@ class GroupFree3DNet(SingleStage3DDetector):
points = batch_inputs_dict['points']
results_list = self.bbox_head.predict(points, x, batch_data_samples,
**kwargs)
predictions = self.convert_to_datasample(batch_data_samples,
predictions = self.add_pred_to_datasample(batch_data_samples,
results_list)
return predictions
......@@ -154,4 +154,4 @@ class H3DNet(TwoStage3DDetector):
feats_dict,
batch_data_samples,
suffix='_optimized')
return self.convert_to_datasample(batch_data_samples, results_list)
return self.add_pred_to_datasample(batch_data_samples, results_list)
......@@ -433,7 +433,7 @@ class ImVoteNet(Base3DDetector):
if points is None:
assert imgs is not None
results_2d = self.predict_img_only(imgs, batch_data_samples)
return self.convert_to_datasample(
return self.add_pred_to_datasample(
batch_data_samples, data_instances_2d=results_2d)
else:
......@@ -488,7 +488,7 @@ class ImVoteNet(Base3DDetector):
batch_data_samples,
rescale=True)
return self.convert_to_datasample(batch_data_samples, results_3d)
return self.add_pred_to_datasample(batch_data_samples, results_3d)
def predict_img_only(self,
imgs: Tensor,
......
......@@ -3,15 +3,15 @@ from typing import List, Tuple, Union
import torch
from mmdet3d.models.detectors import Base3DDetector
from mmdet3d.models.layers.fusion_layers.point_fusion import point_sample
from mmdet3d.registry import MODELS, TASK_UTILS
from mmdet3d.structures.det3d_data_sample import SampleList
from mmdet3d.utils import ConfigType, InstanceList, OptConfigType
from mmdet.models.detectors import BaseDetector
from mmdet3d.utils import ConfigType, OptConfigType
@MODELS.register_module()
class ImVoxelNet(BaseDetector):
class ImVoxelNet(Base3DDetector):
r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.
Args:
......@@ -57,31 +57,6 @@ class ImVoxelNet(BaseDetector):
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def convert_to_datasample(self, data_samples: SampleList,
data_instances: InstanceList) -> SampleList:
""" Convert results list to `Det3DDataSample`.
Args:
inputs (list[:obj:`Det3DDataSample`]): The input data.
data_instances (list[:obj:`InstanceData`]): 3D Detection
results of each image.
Returns:
list[:obj:`Det3DDataSample`]: 3D Detection results of the
input images. Each Det3DDataSample usually contain
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instance, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes_3d (Tensor): Contains a tensor with shape
(num_instances, C) where C >=7.
"""
for data_sample, pred_instances_3d in zip(data_samples,
data_instances):
data_sample.pred_instances_3d = pred_instances_3d
return data_samples
def extract_feat(self, batch_inputs_dict: dict,
batch_data_samples: SampleList):
"""Extract 3d features from the backbone -> fpn -> 3d projection.
......@@ -185,7 +160,7 @@ class ImVoxelNet(BaseDetector):
"""
x = self.extract_feat(batch_inputs_dict, batch_data_samples)
results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
predictions = self.convert_to_datasample(batch_data_samples,
predictions = self.add_pred_to_datasample(batch_data_samples,
results_list)
return predictions
......
......@@ -401,7 +401,7 @@ class MVXTwoStageDetector(Base3DDetector):
else:
results_list_2d = None
detsamples = self.convert_to_datasample(batch_data_samples,
detsamples = self.add_pred_to_datasample(batch_data_samples,
results_list_3d,
results_list_2d)
return detsamples
......@@ -108,7 +108,7 @@ class SingleStage3DDetector(Base3DDetector):
"""
x = self.extract_feat(batch_inputs_dict)
results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
predictions = self.convert_to_datasample(batch_data_samples,
predictions = self.add_pred_to_datasample(batch_data_samples,
results_list)
return predictions
......
......@@ -18,7 +18,7 @@ class SingleStageMono3DDetector(SingleStageDetector):
boxes on the output features of the backbone+neck.
"""
def convert_to_datasample(
def add_pred_to_datasample(
self,
data_samples: SampleList,
data_instances_3d: OptInstanceList = None,
......
......@@ -161,7 +161,7 @@ class TwoStage3DDetector(Base3DDetector):
batch_data_samples)
# convert to Det3DDataSample
results_list = self.convert_to_datasample(batch_data_samples,
results_list = self.add_pred_to_datasample(batch_data_samples,
results_list)
return results_list
......
......@@ -99,7 +99,7 @@ class VoteNet(SingleStage3DDetector):
points = batch_inputs_dict['points']
results_list = self.bbox_head.predict(points, feats_dict,
batch_data_samples, **kwargs)
data_3d_samples = self.convert_to_datasample(batch_data_samples,
data_3d_samples = self.add_pred_to_datasample(batch_data_samples,
results_list)
return data_3d_samples
......@@ -143,6 +143,6 @@ class VoteNet(SingleStage3DDetector):
self.bbox_head.test_cfg)
merged_results = InstanceData(**merged_results_dict)
data_3d_samples = self.convert_to_datasample(batch_data_samples,
data_3d_samples = self.add_pred_to_datasample(batch_data_samples,
[merged_results])
return data_3d_samples
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment