Unverified Commit c2d958aa authored by ChaimZhu, committed by GitHub

[Refactor] update data flow and ut (#1776)

* update data flow and ut

* update ut

* update code

* fix mapping bug

* fix comments
parent c2c5abd6
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
+from torch import Tensor
 from mmdet3d.registry import MODELS
 from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig
+from ...structures.det3d_data_sample import SampleList
 from .single_stage_mono3d import SingleStageMono3DDetector

@@ -43,3 +48,53 @@ class FCOSMono3D(SingleStageMono3DDetector):
             test_cfg=test_cfg,
             data_preprocessor=data_preprocessor,
             init_cfg=init_cfg)
+
+    def predict(self,
+                batch_inputs_dict: Dict[str, Tensor],
+                batch_data_samples: SampleList,
+                rescale: bool = True) -> SampleList:
+        """Predict results from a batch of inputs and data samples with
+        post-processing.
+
+        Args:
+            batch_inputs_dict (dict): The model input dict which includes
+                the 'imgs' key.
+
+                - imgs (torch.Tensor): Image of each sample.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
+                Samples. It usually includes information such as
+                `gt_instances_3d`.
+            rescale (bool): Whether to rescale the results.
+                Defaults to True.
+
+        Returns:
+            list[:obj:`Det3DDataSample`]: Detection results of the
+            input. Each Det3DDataSample usually contains
+            'pred_instances_3d'. And the ``pred_instances_3d`` normally
+            contains the following keys.
+
+            - scores_3d (Tensor): Classification scores, has a shape
+              (num_instances, ).
+            - labels_3d (Tensor): Labels of 3D bboxes, has a shape
+              (num_instances, ).
+            - bboxes_3d (Tensor): Contains a tensor with shape
+              (num_instances, C) where C >= 7.
+
+            When there are 2D predictions in the model, it should also
+            contain ``pred_instances``, and the ``pred_instances`` normally
+            contains the following keys.
+
+            - scores (Tensor): Classification scores of the image, has a
+              shape (num_instances, ).
+            - labels (Tensor): Predicted labels of 2D bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Contains a tensor with shape
+              (num_instances, 4).
+        """
+        x = self.extract_feat(batch_inputs_dict)
+        results_list, results_list_2d = self.bbox_head.predict(
+            x, batch_data_samples, rescale=rescale)
+        predictions = self.convert_to_datasample(batch_data_samples,
+                                                 results_list,
+                                                 results_list_2d)
+        return predictions
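
Note: the refactored flow writes predictions onto the incoming data samples instead of constructing fresh Det3DDataSample objects. A minimal, self-contained sketch of that attach pattern (the dummy scores/labels are illustrative, not from this diff):

import torch
from mmengine.structures import InstanceData

from mmdet3d.structures import Det3DDataSample

# Two input samples, as the data preprocessor would hand them to predict().
batch_data_samples = [Det3DDataSample(), Det3DDataSample()]
# Stand-in head outputs: one InstanceData of 3D predictions per sample.
results_list = [
    InstanceData(
        scores_3d=torch.rand(4), labels_3d=torch.randint(0, 3, (4, )))
    for _ in batch_data_samples
]
# convert_to_datasample now writes onto the inputs and returns them.
for data_sample, pred_instances_3d in zip(batch_data_samples, results_list):
    data_sample.pred_instances_3d = pred_instances_3d
assert batch_data_samples[0].pred_instances_3d.scores_3d.shape == (4, )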
@@ -82,5 +82,6 @@ class GroupFree3DNet(SingleStage3DDetector):
         points = batch_inputs_dict['points']
         results_list = self.bbox_head.predict(points, x, batch_data_samples,
                                               **kwargs)
-        predictions = self.convert_to_datasample(results_list)
+        predictions = self.convert_to_datasample(batch_data_samples,
+                                                 results_list)
         return predictions
@@ -154,4 +154,4 @@ class H3DNet(TwoStage3DDetector):
             feats_dict,
             batch_data_samples,
             suffix='_optimized')
-        return self.convert_to_datasample(results_list)
+        return self.convert_to_datasample(batch_data_samples, results_list)
@@ -433,7 +433,8 @@ class ImVoteNet(Base3DDetector):
         if points is None:
             assert imgs is not None
             results_2d = self.predict_img_only(imgs, batch_data_samples)
-            return self.convert_to_datasample(results_list_2d=results_2d)
+            return self.convert_to_datasample(
+                batch_data_samples, data_instances_2d=results_2d)
         else:
             results_2d = self.predict_img_only(
@@ -487,7 +488,7 @@ class ImVoteNet(Base3DDetector):
                 batch_data_samples,
                 rescale=True)
-        return self.convert_to_datasample(results_3d)
+        return self.convert_to_datasample(batch_data_samples, results_3d)

     def predict_img_only(self,
                          imgs: Tensor,
@@ -5,7 +5,6 @@ import torch
 from mmdet3d.models.layers.fusion_layers.point_fusion import point_sample
 from mmdet3d.registry import MODELS, TASK_UTILS
-from mmdet3d.structures import Det3DDataSample
 from mmdet3d.structures.det3d_data_sample import SampleList
 from mmdet3d.utils import ConfigType, InstanceList, OptConfigType
 from mmdet.models.detectors import BaseDetector
@@ -58,13 +57,13 @@ class ImVoxelNet(BaseDetector):
         self.train_cfg = train_cfg
         self.test_cfg = test_cfg
-    def convert_to_datasample(self, results_list: InstanceList) -> SampleList:
+    def convert_to_datasample(self, data_samples: SampleList,
+                              data_instances: InstanceList) -> SampleList:
         """Convert results list to `Det3DDataSample`.

         Args:
-            results_list (list[:obj:`InstanceData`]): 3D Detection results of
-                each image.
+            data_samples (list[:obj:`Det3DDataSample`]): The input data.
+            data_instances (list[:obj:`InstanceData`]): 3D Detection
+                results of each image.

         Returns:
             list[:obj:`Det3DDataSample`]: 3D Detection results of the
             input images. Each Det3DDataSample usually contains
@@ -77,13 +76,11 @@ class ImVoxelNet(BaseDetector):
               (num_instances, ).
             - bboxes_3d (Tensor): Contains a tensor with shape
               (num_instances, C) where C >= 7.
         """
-        out_results_list = []
-        for i in range(len(results_list)):
-            result = Det3DDataSample()
-            result.pred_instances_3d = results_list[i]
-            out_results_list.append(result)
-        return out_results_list
+        for data_sample, pred_instances_3d in zip(data_samples,
+                                                  data_instances):
+            data_sample.pred_instances_3d = pred_instances_3d
+        return data_samples
     def extract_feat(self, batch_inputs_dict: dict,
                      batch_data_samples: SampleList):
@@ -188,7 +185,8 @@ class ImVoxelNet(BaseDetector):
         """
         x = self.extract_feat(batch_inputs_dict, batch_data_samples)
         results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
-        predictions = self.convert_to_datasample(results_list)
+        predictions = self.convert_to_datasample(batch_data_samples,
+                                                 results_list)
         return predictions

     def _forward(self, batch_inputs_dict: dict, batch_data_samples: SampleList,
@@ -218,7 +218,7 @@ class MVXTwoStageDetector(Base3DDetector):
             x = self.pts_neck(x)
         return x

-    def extract_feat(self, batch_inputs_dict: List[Tensor],
+    def extract_feat(self, batch_inputs_dict: dict,
                      batch_input_metas: List[dict]) -> tuple:
         """Extract features from images and points.
@@ -235,9 +235,9 @@ class MVXTwoStageDetector(Base3DDetector):
             tuple: Two elements in tuple, arranged as
             image features and point cloud features.
         """
-        voxel_dict = batch_inputs_dict['voxels']
-        imgs = batch_inputs_dict['imgs']
-        points = batch_inputs_dict['points']
+        voxel_dict = batch_inputs_dict.get('voxels', None)
+        imgs = batch_inputs_dict.get('imgs', None)
+        points = batch_inputs_dict.get('points', None)
         img_feats = self.extract_img_feat(imgs, batch_input_metas)
         pts_feats = self.extract_pts_feat(
             voxel_dict,
@@ -401,6 +401,7 @@ class MVXTwoStageDetector(Base3DDetector):
         else:
             results_list_2d = None
-        detsamples = self.convert_to_datasample(results_list_3d,
+        detsamples = self.convert_to_datasample(batch_data_samples,
+                                                results_list_3d,
                                                 results_list_2d)
         return detsamples
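
The switch from direct indexing to dict.get in extract_feat makes each modality optional: a lidar-only or camera-only batch simply yields None for the missing key instead of raising KeyError. A small illustrative sketch (the batch contents are made up):

import torch

batch_inputs_dict = {'points': [torch.rand(100, 4)]}  # no 'imgs' or 'voxels'
voxel_dict = batch_inputs_dict.get('voxels', None)  # None, not a KeyError
imgs = batch_inputs_dict.get('imgs', None)          # None -> image branch skipped
points = batch_inputs_dict.get('points', None)
assert imgs is None and points is not None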
@@ -108,7 +108,8 @@ class SingleStage3DDetector(Base3DDetector):
         """
         x = self.extract_feat(batch_inputs_dict)
         results_list = self.bbox_head.predict(x, batch_data_samples, **kwargs)
-        predictions = self.convert_to_datasample(results_list)
+        predictions = self.convert_to_datasample(batch_data_samples,
+                                                 results_list)
         return predictions

     def _forward(self,
@@ -137,7 +138,7 @@ class SingleStage3DDetector(Base3DDetector):
         return results

     def extract_feat(
-        self, batch_inputs_dict: torch.Tensor
+        self, batch_inputs_dict: Dict[str, Tensor]
     ) -> Union[Tuple[torch.Tensor], Dict[str, Tensor]]:
         """Directly extract features from the backbone+neck.
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Tuple

+from mmengine.structures import InstanceData
 from torch import Tensor

 from mmdet3d.registry import MODELS
-from mmdet3d.structures import Det3DDataSample
 from mmdet3d.structures.det3d_data_sample import SampleList
-from mmdet3d.utils import InstanceList
+from mmdet3d.utils import OptInstanceList
 from mmdet.models.detectors.single_stage import SingleStageDetector
@@ -18,38 +18,63 @@ class SingleStageMono3DDetector(SingleStageDetector):
     boxes on the output features of the backbone+neck.
     """
-    def convert_to_datasample(self, results_list: InstanceList) -> SampleList:
-        """Convert results list to `Det3DDataSample`.
+    def convert_to_datasample(
+        self,
+        data_samples: SampleList,
+        data_instances_3d: OptInstanceList = None,
+        data_instances_2d: OptInstanceList = None,
+    ) -> SampleList:
+        """Convert results list to `Det3DDataSample`.

         Args:
-            results_list (list[:obj:`InstanceData`]): Detection results
-                of each image. For each image, it could contain two result
-                formats:
-
-                1. pred_instances_3d
-                2. (pred_instances_3d, pred_instances)
+            data_samples (list[:obj:`Det3DDataSample`]): The input data.
+            data_instances_3d (list[:obj:`InstanceData`], optional): 3D
+                Detection results of each image. Defaults to None.
+            data_instances_2d (list[:obj:`InstanceData`], optional): 2D
+                Detection results of each image. Defaults to None.

         Returns:
-            list[:obj:`Det3DDataSample`]: 3D Detection results of the
-                input images. Each Det3DDataSample usually contains
-                'pred_instances_3d'. And the ``pred_instances_3d`` usually
-                contains following keys.
+            list[:obj:`Det3DDataSample`]: Detection results of the
+            input. Each Det3DDataSample usually contains
+            'pred_instances_3d'. And the ``pred_instances_3d`` normally
+            contains the following keys.

             - scores_3d (Tensor): Classification scores, has a shape
               (num_instances, ).
-            - labels_3d (Tensor): Labels of bboxes, has a shape
+            - labels_3d (Tensor): Labels of 3D bboxes, has a shape
               (num_instances, ).
             - bboxes_3d (Tensor): Contains a tensor with shape
               (num_instances, C) where C >= 7.
+
+            When there are 2D predictions in some models, it should also
+            contain ``pred_instances``, and the ``pred_instances`` normally
+            contains the following keys.
+
+            - scores (Tensor): Classification scores of the image, has a
+              shape (num_instances, ).
+            - labels (Tensor): Predicted labels of 2D bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Contains a tensor with shape
+              (num_instances, 4).
         """
-        out_results_list = []
-        for i in range(len(results_list)):
-            result = Det3DDataSample()
-            if len(results_list[i]) == 2:
-                result.pred_instances_3d = results_list[i][0]
-                result.pred_instances = results_list[i][1]
-            else:
-                result.pred_instances_3d = results_list[i]
-            out_results_list.append(result)
-        return out_results_list
+        assert (data_instances_2d is not None) or \
+               (data_instances_3d is not None), \
+               'please pass at least one type of data instances'
+
+        if data_instances_2d is None:
+            data_instances_2d = [
+                InstanceData() for _ in range(len(data_instances_3d))
+            ]
+        if data_instances_3d is None:
+            data_instances_3d = [
+                InstanceData() for _ in range(len(data_instances_2d))
+            ]
+
+        for i, data_sample in enumerate(data_samples):
+            data_sample.pred_instances_3d = data_instances_3d[i]
+            data_sample.pred_instances = data_instances_2d[i]
+        return data_samples
     def extract_feat(self, batch_inputs_dict: dict) -> Tuple[Tensor]:
         """Extract features.
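
When only one branch produces results, the missing side is padded with empty InstanceData so every sample still exposes both prediction fields. A standalone sketch of that padding behavior (dummy tensors, not from the diff):

import torch
from mmengine.structures import InstanceData

from mmdet3d.structures import Det3DDataSample

data_samples = [Det3DDataSample()]
data_instances_3d = [
    InstanceData(scores_3d=torch.rand(2), labels_3d=torch.tensor([0, 1]))
]
data_instances_2d = None  # e.g. a purely 3D head

# Pad the missing 2D branch with empty InstanceData, as the new
# convert_to_datasample does.
if data_instances_2d is None:
    data_instances_2d = [InstanceData() for _ in data_instances_3d]

for i, data_sample in enumerate(data_samples):
    data_sample.pred_instances_3d = data_instances_3d[i]
    data_sample.pred_instances = data_instances_2d[i]
assert len(data_samples[0].pred_instances) == 0  # empty but present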
@@ -161,7 +161,8 @@ class TwoStage3DDetector(Base3DDetector):
                                              batch_data_samples)
         # convert to Det3DDataSample
-        results_list = self.convert_to_datasample(results_list)
+        results_list = self.convert_to_datasample(batch_data_samples,
+                                                  results_list)
         return results_list
@@ -99,7 +99,8 @@ class VoteNet(SingleStage3DDetector):
         points = batch_inputs_dict['points']
         results_list = self.bbox_head.predict(points, feats_dict,
                                               batch_data_samples, **kwargs)
-        data_3d_samples = self.convert_to_datasample(results_list)
+        data_3d_samples = self.convert_to_datasample(batch_data_samples,
+                                                     results_list)
         return data_3d_samples

     def aug_test(self, aug_inputs_list: List[dict],
@@ -142,5 +143,6 @@ class VoteNet(SingleStage3DDetector):
                                               self.bbox_head.test_cfg)
         merged_results = InstanceData(**merged_results_dict)
-        data_3d_samples = self.convert_to_datasample([merged_results])
+        data_3d_samples = self.convert_to_datasample(batch_data_samples,
+                                                     [merged_results])
         return data_3d_samples
@@ -470,8 +470,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
     def add_datasample(self,
                        name: str,
                        data_input: dict,
-                       gt_sample: Optional['Det3DDataSample'] = None,
-                       pred_sample: Optional['Det3DDataSample'] = None,
+                       data_sample: Optional['Det3DDataSample'] = None,
                        draw_gt: bool = True,
                        draw_pred: bool = True,
                        show: bool = False,
@@ -495,9 +494,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
             name (str): The image identifier.
             data_input (dict): It should include the point clouds or image
                 to draw.
-            gt_sample (:obj:`Det3DDataSample`, optional): GT Det3DDataSample.
-                Defaults to None.
-            pred_sample (:obj:`Det3DDataSample`, optional): Prediction
-                Det3DDataSample. Defaults to None.
+            data_sample (:obj:`Det3DDataSample`, optional): Det3DDataSample
+                that contains annotations and/or predictions.
+                Defaults to None.
             draw_gt (bool): Whether to draw GT Det3DDataSample.
                 Defaults to True.
@@ -524,20 +521,20 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
         gt_img_data = None
         pred_img_data = None
-        if draw_gt and gt_sample is not None:
-            if 'gt_instances_3d' in gt_sample:
-                gt_data_3d = self._draw_instances_3d(data_input,
-                                                     gt_sample.gt_instances_3d,
-                                                     gt_sample.metainfo,
-                                                     vis_task, palette)
-            if 'gt_instances' in gt_sample:
+        if draw_gt and data_sample is not None:
+            if 'gt_instances_3d' in data_sample:
+                gt_data_3d = self._draw_instances_3d(
+                    data_input, data_sample.gt_instances_3d,
+                    data_sample.metainfo, vis_task, palette)
+            if 'gt_instances' in data_sample:
                 assert 'img' in data_input
                 if isinstance(data_input['img'], Tensor):
                     img = data_input['img'].permute(1, 2, 0).numpy()
                     img = img[..., [2, 1, 0]]  # bgr to rgb
-                gt_img_data = self._draw_instances(img, gt_sample.gt_instances,
+                gt_img_data = self._draw_instances(img,
+                                                   data_sample.gt_instances,
                                                    classes, palette)
-            if 'gt_pts_seg' in gt_sample:
+            if 'gt_pts_seg' in data_sample:
                 assert classes is not None, 'class information is ' \
                                             'not provided when ' \
                                             'visualizing panoptic ' \
@@ -545,23 +542,23 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                 assert 'points' in data_input
                 gt_seg_data_3d = \
                     self._draw_pts_sem_seg(data_input['points'],
-                                           pred_sample.pred_pts_seg,
+                                           data_sample.pred_pts_seg,
                                            palette, ignore_index)
-        if draw_pred and pred_sample is not None:
-            if 'pred_instances_3d' in pred_sample:
-                pred_instances_3d = pred_sample.pred_instances_3d
+        if draw_pred and data_sample is not None:
+            if 'pred_instances_3d' in data_sample:
+                pred_instances_3d = data_sample.pred_instances_3d
                 # .cpu() cannot be used for BaseInstance3DBoxes,
                 # so we need to use .to('cpu')
                 pred_instances_3d = pred_instances_3d[
                     pred_instances_3d.scores_3d > pred_score_thr].to('cpu')
                 pred_data_3d = self._draw_instances_3d(data_input,
                                                        pred_instances_3d,
-                                                       pred_sample.metainfo,
+                                                       data_sample.metainfo,
                                                        vis_task, palette)
-            if 'pred_instances' in pred_sample:
-                if 'img' in data_input and len(pred_sample.pred_instances) > 0:
-                    pred_instances = pred_sample.pred_instances
+            if 'pred_instances' in data_sample:
+                if 'img' in data_input and len(data_sample.pred_instances) > 0:
+                    pred_instances = data_sample.pred_instances
                     pred_instances = pred_instances_3d[
                         pred_instances.scores > pred_score_thr].cpu()
                     if isinstance(data_input['img'], Tensor):
@@ -569,7 +566,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                         img = img[..., [2, 1, 0]]  # bgr to rgb
                         pred_img_data = self._draw_instances(
                             img, pred_instances, classes, palette)
-            if 'pred_pts_seg' in pred_sample:
+            if 'pred_pts_seg' in data_sample:
                 assert classes is not None, 'class information is ' \
                                             'not provided when ' \
                                             'visualizing panoptic ' \
@@ -577,7 +574,7 @@ class Det3DLocalVisualizer(DetLocalVisualizer):
                 assert 'points' in data_input
                 pred_seg_data_3d = \
                     self._draw_pts_sem_seg(data_input['points'],
-                                           pred_sample.pred_pts_seg,
+                                           data_sample.pred_pts_seg,
                                            palette, ignore_index)

         # monocular 3d object detection image
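
With the merged signature, ground truth and predictions travel on one Det3DDataSample, and draw_gt/draw_pred select what to render. A rough sketch of the new calling convention (the visualizer construction and data_input contents are illustrative, so the call itself is left commented):

import torch
from mmengine.structures import InstanceData

from mmdet3d.structures import Det3DDataSample

data_sample = Det3DDataSample()
data_sample.gt_instances_3d = InstanceData()    # annotations live here
data_sample.pred_instances_3d = InstanceData(   # predictions live here too
    scores_3d=torch.rand(3), labels_3d=torch.tensor([0, 1, 2]))
# visualizer.add_datasample('demo', data_input, data_sample=data_sample,
#                           draw_gt=True, draw_pred=True)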
@@ -34,7 +34,7 @@ def _generate_nus_dataset_config():
         dict(type='Identity'),
     ]
     modality = dict(use_lidar=True, use_camera=False)
-    data_prefix = dict(pts='lidar', img='')
+    data_prefix = dict(pts='lidar', img='', sweeps='sweeps/LIDAR_TOP')
     return data_root, ann_file, classes, data_prefix, pipeline, modality
@@ -34,7 +34,10 @@ def _generate_nus_dataset_config():
         dict(type='Identity'),
     ]
     modality = dict(use_lidar=True, use_camera=True)
-    data_prefix = dict(pts='samples/LIDAR_TOP', img='samples/CAM_BACK_LEFT')
+    data_prefix = dict(
+        pts='samples/LIDAR_TOP',
+        img='samples/CAM_BACK_LEFT',
+        sweeps='sweeps/LIDAR_TOP')
     return data_root, ann_file, classes, data_prefix, pipeline, modality
@@ -76,7 +76,7 @@ class TestS3DISDataset(unittest.TestCase):
         input_dict = s3dis_seg_dataset.prepare_data(0)
         points = input_dict['inputs']['points']
-        data_sample = input_dict['data_sample']
+        data_sample = input_dict['data_samples']
         pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
         expected_points = torch.tensor([[
@@ -178,7 +178,7 @@ class TestScanNetDataset(unittest.TestCase):
         input_dict = scannet_seg_dataset.prepare_data(0)
         points = input_dict['inputs']['points']
-        data_sample = input_dict['data_sample']
+        data_sample = input_dict['data_samples']
         pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
         expected_points = torch.tensor([[
@@ -80,6 +80,6 @@ class TestSemanticKITTIDataset(unittest.TestCase):
         input_dict = semantickitti_dataset.prepare_data(0)
         points = input_dict['inputs']['points']
-        data_sample = input_dict['data_sample']
+        data_sample = input_dict['data_samples']
         pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask
         self.assertEqual(points.shape[0], pts_semantic_mask.shape[0])
@@ -19,7 +19,7 @@ class TestPack3DDetInputs(unittest.TestCase):
         inputs = packed_results['inputs']
         # annotations
-        gt_instances = packed_results['data_sample'].gt_instances_3d
+        gt_instances = packed_results['data_samples'].gt_instances_3d
         self.assertIn('points', inputs)
         self.assertIsInstance(inputs['points'], torch.Tensor)
         assert_allclose(inputs['points'].sum(), torch.tensor(13062.6436))
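
These test updates track a key rename in the packing step: each packed result now carries its Det3DDataSample under 'data_samples' instead of 'data_sample'. A minimal sketch of the expected dict shape (the point tensor is a dummy):

import torch

from mmdet3d.structures import Det3DDataSample

packed_results = dict(
    inputs=dict(points=torch.rand(100, 4)),  # model inputs
    data_samples=Det3DDataSample())          # annotations/meta, new key name
assert 'data_samples' in packed_results and 'data_sample' not in packed_results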
@@ -14,9 +14,6 @@ class TestIndoorMetric(unittest.TestCase):
     @patch('sys.stdout', new_callable=StringIO)
     def test_process(self, stdout):
         indoor_metric = IndoorMetric()
-        dummy_batch = dict(data_sample=dict())
-
         eval_ann_info = {
             'gt_bboxes_3d':
             DepthInstance3DBoxes(
@@ -40,7 +37,6 @@ class TestIndoorMetric(unittest.TestCase):
             'gt_labels_3d':
             np.array([2, 2, 2, 3, 4, 17, 4, 7, 2, 8, 17, 11])
         }
-        dummy_batch['data_sample']['eval_ann_info'] = eval_ann_info

         pred_instances_3d = dict()
         pred_instances_3d['scores_3d'] = torch.ones(
@@ -50,6 +46,8 @@ class TestIndoorMetric(unittest.TestCase):
             eval_ann_info['gt_labels_3d'])
         pred_dict = dict()
         pred_dict['pred_instances_3d'] = pred_instances_3d
+        pred_dict['eval_ann_info'] = eval_ann_info
+
         indoor_metric.dataset_meta = {
             'CLASSES': ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door',
                         'window', 'bookshelf', 'picture', 'counter', 'desk',
@@ -58,7 +56,8 @@ class TestIndoorMetric(unittest.TestCase):
             'box_type_3d':
             'Depth',
         }
-        indoor_metric.process([dummy_batch], [pred_dict])
+
+        indoor_metric.process({}, [pred_dict])
         eval_results = indoor_metric.evaluate(1)
         for v in eval_results.values():
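
The metric tests reflect the new evaluation contract: ground-truth annotations ride inside each prediction dict as 'eval_ann_info', so process() no longer needs a populated data_batch. A sketch of the new shape (values are dummies; the commented call mirrors the test above):

import numpy as np
import torch

pred_dict = dict(
    pred_instances_3d=dict(
        scores_3d=torch.ones(2),
        labels_3d=torch.tensor([0, 1])),
    eval_ann_info=dict(gt_labels_3d=np.array([0, 1])))  # GT now lives here
# indoor_metric.process({}, [pred_dict])  # data_batch is an empty dict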
@@ -3,7 +3,7 @@ import unittest

 import numpy as np
 import torch
-from mmengine.data import BaseDataElement
+from mmengine.structures import BaseDataElement

 from mmdet3d.evaluation.metrics import InstanceSegMetric
 from mmdet3d.structures import Det3DDataSample, PointData
@@ -11,10 +11,9 @@ from mmdet3d.structures import Det3DDataSample, PointData
 class TestInstanceSegMetric(unittest.TestCase):

-    def _demo_mm_inputs(self):
+    def _demo_mm_model_output(self):
         """Create a superset of inputs needed to run test or train batches."""
-        packed_inputs = []
-        mm_inputs = dict()
         n_points = 3300
         gt_labels = [0, 0, 0, 0, 0, 0, 14, 14, 2, 1]
         gt_instance_mask = np.ones(n_points, dtype=np.int) * -1
@@ -29,13 +28,6 @@ class TestInstanceSegMetric(unittest.TestCase):
         ann_info_data['pts_instance_mask'] = torch.tensor(gt_instance_mask)
         ann_info_data['pts_semantic_mask'] = torch.tensor(gt_semantic_mask)

-        mm_inputs['data_sample'] = dict(eval_ann_info=ann_info_data)
-        packed_inputs.append(mm_inputs)
-        return packed_inputs
-
-    def _demo_mm_model_output(self):
-        """Create a superset of inputs needed to run test or train batches."""
         results_dict = dict()
         n_points = 3300
         gt_labels = [0, 0, 0, 0, 0, 0, 14, 14, 2, 1]
@@ -54,6 +46,7 @@ class TestInstanceSegMetric(unittest.TestCase):
         results_dict['instance_scores'] = torch.tensor(scores)
         data_sample = Det3DDataSample()
         data_sample.pred_pts_seg = PointData(**results_dict)
+        data_sample.eval_ann_info = ann_info_data
         batch_data_samples = [data_sample]

         predictions = []
@@ -65,7 +58,7 @@ class TestInstanceSegMetric(unittest.TestCase):
         return predictions

     def test_evaluate(self):
-        data_batch = self._demo_mm_inputs()
+        data_batch = {}
         predictions = self._demo_mm_model_output()
         seg_valid_class_ids = (3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
                                33, 34, 36, 39)
@@ -10,7 +10,7 @@ data_root = 'tests/data/kitti'

 def _init_evaluate_input():
-    data_batch = [dict(data_sample=dict(sample_idx=0))]
+    metainfo = dict(sample_idx=0)
     predictions = Det3DDataSample()
     pred_instances_3d = InstanceData()
     pred_instances_3d.bboxes_3d = LiDARInstance3DBoxes(
@@ -21,12 +21,13 @@ def _init_evaluate_input():
     predictions.pred_instances_3d = pred_instances_3d
     predictions.pred_instances = InstanceData()
+    predictions.set_metainfo(metainfo)
     predictions = predictions.to_dict()
-    return data_batch, [predictions]
+    return {}, [predictions]


 def _init_multi_modal_evaluate_input():
-    data_batch = [dict(data_sample=dict(sample_idx=0))]
+    metainfo = dict(sample_idx=0)
     predictions = Det3DDataSample()
     pred_instances_3d = InstanceData()
     pred_instances = InstanceData()
@@ -42,8 +43,9 @@ def _init_multi_modal_evaluate_input():
     predictions.pred_instances_3d = pred_instances_3d
     predictions.pred_instances = pred_instances
+    predictions.set_metainfo(metainfo)
     predictions = predictions.to_dict()
-    return data_batch, [predictions]
+    return {}, [predictions]


 def test_multi_modal_kitti_metric():
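
Likewise, the KITTI metric helpers now stash sample_idx in the prediction's metainfo via set_metainfo() rather than in a separate data_batch entry. A self-contained sketch (dummy instance data):

import torch
from mmengine.structures import InstanceData

from mmdet3d.structures import Det3DDataSample

predictions = Det3DDataSample()
predictions.pred_instances_3d = InstanceData(scores_3d=torch.ones(1))
predictions.set_metainfo(dict(sample_idx=0))  # identifies the sample
pred_dict = predictions.to_dict()  # 'sample_idx' travels with the prediction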