Unverified Commit 32a4328b authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Bump version to V1.0.0rc0

Bump version to V1.0.0rc0
parents 86cc487c a8817998
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
from os import path as osp
import mmcv import mmcv
import numpy as np import numpy as np
import os
import pandas as pd import pandas as pd
import tempfile
from lyft_dataset_sdk.lyftdataset import LyftDataset as Lyft from lyft_dataset_sdk.lyftdataset import LyftDataset as Lyft
from lyft_dataset_sdk.utils.data_classes import Box as LyftBox from lyft_dataset_sdk.utils.data_classes import Box as LyftBox
from os import path as osp
from pyquaternion import Quaternion from pyquaternion import Quaternion
from mmdet3d.core.evaluation.lyft_eval import lyft_eval from mmdet3d.core.evaluation.lyft_eval import lyft_eval
...@@ -129,7 +130,7 @@ class LyftDataset(Custom3DDataset): ...@@ -129,7 +130,7 @@ class LyftDataset(Custom3DDataset):
index (int): Index of the sample data to get. index (int): Index of the sample data to get.
Returns: Returns:
dict: Data information that will be passed to the data \ dict: Data information that will be passed to the data
preprocessing pipelines. It includes the following keys: preprocessing pipelines. It includes the following keys:
- sample_idx (str): sample index - sample_idx (str): sample index
...@@ -137,13 +138,13 @@ class LyftDataset(Custom3DDataset): ...@@ -137,13 +138,13 @@ class LyftDataset(Custom3DDataset):
- sweeps (list[dict]): infos of sweeps - sweeps (list[dict]): infos of sweeps
- timestamp (float): sample timestamp - timestamp (float): sample timestamp
- img_filename (str, optional): image filename - img_filename (str, optional): image filename
- lidar2img (list[np.ndarray], optional): transformations \ - lidar2img (list[np.ndarray], optional): transformations
from lidar to different cameras from lidar to different cameras
- ann_info (dict): annotation info - ann_info (dict): annotation info
""" """
info = self.data_infos[index] info = self.data_infos[index]
# standard protocal modified from SECOND.Pytorch # standard protocol modified from SECOND.Pytorch
input_dict = dict( input_dict = dict(
sample_idx=info['token'], sample_idx=info['token'],
pts_filename=info['lidar_path'], pts_filename=info['lidar_path'],
...@@ -190,7 +191,7 @@ class LyftDataset(Custom3DDataset): ...@@ -190,7 +191,7 @@ class LyftDataset(Custom3DDataset):
Returns: Returns:
dict: Annotation information consists of the following keys: dict: Annotation information consists of the following keys:
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): \ - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
3D ground truth bboxes. 3D ground truth bboxes.
- gt_labels_3d (np.ndarray): Labels of ground truths. - gt_labels_3d (np.ndarray): Labels of ground truths.
- gt_names (list[str]): Class names of ground truths. - gt_names (list[str]): Class names of ground truths.
...@@ -275,10 +276,11 @@ class LyftDataset(Custom3DDataset): ...@@ -275,10 +276,11 @@ class LyftDataset(Custom3DDataset):
Args: Args:
result_path (str): Path of the result file. result_path (str): Path of the result file.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
metric (str): Metric name used for evaluation. Default: 'bbox'. metric (str, optional): Metric name used for evaluation.
result_name (str): Result name in the metric prefix. Default: 'bbox'.
result_name (str, optional): Result name in the metric prefix.
Default: 'pts_bbox'. Default: 'pts_bbox'.
Returns: Returns:
...@@ -312,18 +314,18 @@ class LyftDataset(Custom3DDataset): ...@@ -312,18 +314,18 @@ class LyftDataset(Custom3DDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
csv_savepath (str | None): The path for saving csv files. csv_savepath (str): The path for saving csv files.
It includes the file path and the csv filename, It includes the file path and the csv filename,
e.g., "a/b/filename.csv". If not specified, e.g., "a/b/filename.csv". If not specified,
the result will not be converted to csv file. the result will not be converted to csv file.
Returns: Returns:
tuple: Returns (result_files, tmp_dir), where `result_files` is a \ tuple: Returns (result_files, tmp_dir), where `result_files` is a
dict containing the json filepaths, `tmp_dir` is the temporal \ dict containing the json filepaths, `tmp_dir` is the temporal
directory created for saving json files when \ directory created for saving json files when
`jsonfile_prefix` is not specified. `jsonfile_prefix` is not specified.
""" """
assert isinstance(results, list), 'results must be a list' assert isinstance(results, list), 'results must be a list'
...@@ -372,19 +374,22 @@ class LyftDataset(Custom3DDataset): ...@@ -372,19 +374,22 @@ class LyftDataset(Custom3DDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str], optional): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing Default: 'bbox'.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str, optional): The prefix of json files including
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
csv_savepath (str | None): The path for saving csv files. csv_savepath (str, optional): The path for saving csv files.
It includes the file path and the csv filename, It includes the file path and the csv filename,
e.g., "a/b/filename.csv". If not specified, e.g., "a/b/filename.csv". If not specified,
the result will not be converted to csv file. the result will not be converted to csv file.
show (bool): Whether to visualize. result_names (list[str], optional): Result names in the
metric prefix. Default: ['pts_bbox'].
show (bool, optional): Whether to visualize.
Default: False. Default: False.
out_dir (str): Path to save the visualization results. out_dir (str, optional): Path to save the visualization results.
Default: None. Default: None.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
...@@ -407,8 +412,8 @@ class LyftDataset(Custom3DDataset): ...@@ -407,8 +412,8 @@ class LyftDataset(Custom3DDataset):
if tmp_dir is not None: if tmp_dir is not None:
tmp_dir.cleanup() tmp_dir.cleanup()
if show: if show or out_dir:
self.show(results, out_dir, pipeline=pipeline) self.show(results, out_dir, show=show, pipeline=pipeline)
return results_dict return results_dict
def _build_default_pipeline(self): def _build_default_pipeline(self):
...@@ -432,13 +437,14 @@ class LyftDataset(Custom3DDataset): ...@@ -432,13 +437,14 @@ class LyftDataset(Custom3DDataset):
] ]
return Compose(pipeline) return Compose(pipeline)
def show(self, results, out_dir, show=True, pipeline=None): def show(self, results, out_dir, show=False, pipeline=None):
"""Results visualization. """Results visualization.
Args: Args:
results (list[dict]): List of bounding boxes results. results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result. out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online. show (bool): Whether to visualize the results online.
Default: False.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
""" """
...@@ -517,16 +523,16 @@ def output_to_lyft_box(detection): ...@@ -517,16 +523,16 @@ def output_to_lyft_box(detection):
box_gravity_center = box3d.gravity_center.numpy() box_gravity_center = box3d.gravity_center.numpy()
box_dims = box3d.dims.numpy() box_dims = box3d.dims.numpy()
box_yaw = box3d.yaw.numpy() box_yaw = box3d.yaw.numpy()
# TODO: check whether this is necessary
# with dir_offset & dir_limit in the head # our LiDAR coordinate system -> Lyft box coordinate system
box_yaw = -box_yaw - np.pi / 2 lyft_box_dims = box_dims[:, [1, 0, 2]]
box_list = [] box_list = []
for i in range(len(box3d)): for i in range(len(box3d)):
quat = Quaternion(axis=[0, 0, 1], radians=box_yaw[i]) quat = Quaternion(axis=[0, 0, 1], radians=box_yaw[i])
box = LyftBox( box = LyftBox(
box_gravity_center[i], box_gravity_center[i],
box_dims[i], lyft_box_dims[i],
quat, quat,
label=labels[i], label=labels[i],
score=scores[i]) score=scores[i])
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from os import path as osp
import mmcv import mmcv
import numpy as np import numpy as np
import pyquaternion import pyquaternion
import tempfile
from nuscenes.utils.data_classes import Box as NuScenesBox from nuscenes.utils.data_classes import Box as NuScenesBox
from os import path as osp
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
from ..core import show_result from ..core import show_result
...@@ -48,8 +49,9 @@ class NuScenesDataset(Custom3DDataset): ...@@ -48,8 +49,9 @@ class NuScenesDataset(Custom3DDataset):
Defaults to False. Defaults to False.
eval_version (bool, optional): Configuration version of evaluation. eval_version (bool, optional): Configuration version of evaluation.
Defaults to 'detection_cvpr_2019'. Defaults to 'detection_cvpr_2019'.
use_valid_flag (bool): Whether to use `use_valid_flag` key in the info use_valid_flag (bool, optional): Whether to use `use_valid_flag` key
file as mask to filter gt_boxes and gt_names. Defaults to False. in the info file as mask to filter gt_boxes and gt_names.
Defaults to False.
""" """
NameMapping = { NameMapping = {
'movable_object.barrier': 'barrier', 'movable_object.barrier': 'barrier',
...@@ -196,7 +198,7 @@ class NuScenesDataset(Custom3DDataset): ...@@ -196,7 +198,7 @@ class NuScenesDataset(Custom3DDataset):
index (int): Index of the sample data to get. index (int): Index of the sample data to get.
Returns: Returns:
dict: Data information that will be passed to the data \ dict: Data information that will be passed to the data
preprocessing pipelines. It includes the following keys: preprocessing pipelines. It includes the following keys:
- sample_idx (str): Sample index. - sample_idx (str): Sample index.
...@@ -204,12 +206,12 @@ class NuScenesDataset(Custom3DDataset): ...@@ -204,12 +206,12 @@ class NuScenesDataset(Custom3DDataset):
- sweeps (list[dict]): Infos of sweeps. - sweeps (list[dict]): Infos of sweeps.
- timestamp (float): Sample timestamp. - timestamp (float): Sample timestamp.
- img_filename (str, optional): Image filename. - img_filename (str, optional): Image filename.
- lidar2img (list[np.ndarray], optional): Transformations \ - lidar2img (list[np.ndarray], optional): Transformations
from lidar to different cameras. from lidar to different cameras.
- ann_info (dict): Annotation info. - ann_info (dict): Annotation info.
""" """
info = self.data_infos[index] info = self.data_infos[index]
# standard protocal modified from SECOND.Pytorch # standard protocol modified from SECOND.Pytorch
input_dict = dict( input_dict = dict(
sample_idx=info['token'], sample_idx=info['token'],
pts_filename=info['lidar_path'], pts_filename=info['lidar_path'],
...@@ -256,7 +258,7 @@ class NuScenesDataset(Custom3DDataset): ...@@ -256,7 +258,7 @@ class NuScenesDataset(Custom3DDataset):
Returns: Returns:
dict: Annotation information consists of the following keys: dict: Annotation information consists of the following keys:
- gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`): \ - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes`):
3D ground truth bboxes 3D ground truth bboxes
- gt_labels_3d (np.ndarray): Labels of ground truths. - gt_labels_3d (np.ndarray): Labels of ground truths.
- gt_names (list[str]): Class names of ground truths. - gt_names (list[str]): Class names of ground truths.
...@@ -374,10 +376,11 @@ class NuScenesDataset(Custom3DDataset): ...@@ -374,10 +376,11 @@ class NuScenesDataset(Custom3DDataset):
Args: Args:
result_path (str): Path of the result file. result_path (str): Path of the result file.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
metric (str): Metric name used for evaluation. Default: 'bbox'. metric (str, optional): Metric name used for evaluation.
result_name (str): Result name in the metric prefix. Default: 'bbox'.
result_name (str, optional): Result name in the metric prefix.
Default: 'pts_bbox'. Default: 'pts_bbox'.
Returns: Returns:
...@@ -427,14 +430,14 @@ class NuScenesDataset(Custom3DDataset): ...@@ -427,14 +430,14 @@ class NuScenesDataset(Custom3DDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
Returns: Returns:
tuple: Returns (result_files, tmp_dir), where `result_files` is a \ tuple: Returns (result_files, tmp_dir), where `result_files` is a
dict containing the json filepaths, `tmp_dir` is the temporal \ dict containing the json filepaths, `tmp_dir` is the temporal
directory created for saving json files when \ directory created for saving json files when
`jsonfile_prefix` is not specified. `jsonfile_prefix` is not specified.
""" """
assert isinstance(results, list), 'results must be a list' assert isinstance(results, list), 'results must be a list'
...@@ -480,15 +483,16 @@ class NuScenesDataset(Custom3DDataset): ...@@ -480,15 +483,16 @@ class NuScenesDataset(Custom3DDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str], optional): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing Default: 'bbox'.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str, optional): The prefix of json files including
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
show (bool): Whether to visualize. show (bool, optional): Whether to visualize.
Default: False. Default: False.
out_dir (str): Path to save the visualization results. out_dir (str, optional): Path to save the visualization results.
Default: None. Default: None.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
...@@ -510,8 +514,8 @@ class NuScenesDataset(Custom3DDataset): ...@@ -510,8 +514,8 @@ class NuScenesDataset(Custom3DDataset):
if tmp_dir is not None: if tmp_dir is not None:
tmp_dir.cleanup() tmp_dir.cleanup()
if show: if show or out_dir:
self.show(results, out_dir, pipeline=pipeline) self.show(results, out_dir, show=show, pipeline=pipeline)
return results_dict return results_dict
def _build_default_pipeline(self): def _build_default_pipeline(self):
...@@ -535,13 +539,14 @@ class NuScenesDataset(Custom3DDataset): ...@@ -535,13 +539,14 @@ class NuScenesDataset(Custom3DDataset):
] ]
return Compose(pipeline) return Compose(pipeline)
def show(self, results, out_dir, show=True, pipeline=None): def show(self, results, out_dir, show=False, pipeline=None):
"""Results visualization. """Results visualization.
Args: Args:
results (list[dict]): List of bounding boxes results. results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result. out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online. show (bool): Whether to visualize the results online.
Default: False.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
""" """
...@@ -588,9 +593,9 @@ def output_to_nusc_box(detection): ...@@ -588,9 +593,9 @@ def output_to_nusc_box(detection):
box_gravity_center = box3d.gravity_center.numpy() box_gravity_center = box3d.gravity_center.numpy()
box_dims = box3d.dims.numpy() box_dims = box3d.dims.numpy()
box_yaw = box3d.yaw.numpy() box_yaw = box3d.yaw.numpy()
# TODO: check whether this is necessary
# with dir_offset & dir_limit in the head # our LiDAR coordinate system -> nuScenes box coordinate system
box_yaw = -box_yaw - np.pi / 2 nus_box_dims = box_dims[:, [1, 0, 2]]
box_list = [] box_list = []
for i in range(len(box3d)): for i in range(len(box3d)):
...@@ -602,7 +607,7 @@ def output_to_nusc_box(detection): ...@@ -602,7 +607,7 @@ def output_to_nusc_box(detection):
# velo_val * np.cos(velo_ori), velo_val * np.sin(velo_ori), 0.0) # velo_val * np.cos(velo_ori), velo_val * np.sin(velo_ori), 0.0)
box = NuScenesBox( box = NuScenesBox(
box_gravity_center[i], box_gravity_center[i],
box_dims[i], nus_box_dims[i],
quat, quat,
label=labels[i], label=labels[i],
score=scores[i], score=scores[i],
...@@ -624,7 +629,7 @@ def lidar_nusc_box_to_global(info, ...@@ -624,7 +629,7 @@ def lidar_nusc_box_to_global(info,
boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes. boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes.
classes (list[str]): Mapped classes in the evaluation. classes (list[str]): Mapped classes in the evaluation.
eval_configs (object): Evaluation configuration object. eval_configs (object): Evaluation configuration object.
eval_version (str): Evaluation version. eval_version (str, optional): Evaluation version.
Default: 'detection_cvpr_2019' Default: 'detection_cvpr_2019'
Returns: Returns:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import tempfile
import warnings
from os import path as osp
import mmcv import mmcv
import numpy as np import numpy as np
import pyquaternion import pyquaternion
import tempfile
import torch import torch
import warnings
from nuscenes.utils.data_classes import Box as NuScenesBox from nuscenes.utils.data_classes import Box as NuScenesBox
from os import path as osp
from mmdet3d.core import bbox3d2result, box3d_multiclass_nms, xywhr2xyxyr from mmdet3d.core import bbox3d2result, box3d_multiclass_nms, xywhr2xyxyr
from mmdet.datasets import DATASETS, CocoDataset from mmdet.datasets import DATASETS, CocoDataset
...@@ -44,8 +45,9 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -44,8 +45,9 @@ class NuScenesMonoDataset(CocoDataset):
- 'Camera': Box in camera coordinates. - 'Camera': Box in camera coordinates.
eval_version (str, optional): Configuration version of evaluation. eval_version (str, optional): Configuration version of evaluation.
Defaults to 'detection_cvpr_2019'. Defaults to 'detection_cvpr_2019'.
use_valid_flag (bool): Whether to use `use_valid_flag` key in the info use_valid_flag (bool, optional): Whether to use `use_valid_flag` key
file as mask to filter gt_boxes and gt_names. Defaults to False. in the info file as mask to filter gt_boxes and gt_names.
Defaults to False.
version (str, optional): Dataset version. Defaults to 'v1.0-trainval'. version (str, optional): Dataset version. Defaults to 'v1.0-trainval'.
""" """
CLASSES = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle', CLASSES = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle',
...@@ -140,8 +142,8 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -140,8 +142,8 @@ class NuScenesMonoDataset(CocoDataset):
ann_info (list[dict]): Annotation info of an image. ann_info (list[dict]): Annotation info of an image.
Returns: Returns:
dict: A dict containing the following keys: bboxes, labels, \ dict: A dict containing the following keys: bboxes, labels,
gt_bboxes_3d, gt_labels_3d, attr_labels, centers2d, \ gt_bboxes_3d, gt_labels_3d, attr_labels, centers2d,
depths, bboxes_ignore, masks, seg_map depths, bboxes_ignore, masks, seg_map
""" """
gt_bboxes = [] gt_bboxes = []
...@@ -394,10 +396,11 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -394,10 +396,11 @@ class NuScenesMonoDataset(CocoDataset):
Args: Args:
result_path (str): Path of the result file. result_path (str): Path of the result file.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
metric (str): Metric name used for evaluation. Default: 'bbox'. metric (str, optional): Metric name used for evaluation.
result_name (str): Result name in the metric prefix. Default: 'bbox'.
result_name (str, optional): Result name in the metric prefix.
Default: 'img_bbox'. Default: 'img_bbox'.
Returns: Returns:
...@@ -448,13 +451,13 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -448,13 +451,13 @@ class NuScenesMonoDataset(CocoDataset):
Args: Args:
results (list[tuple | numpy.ndarray]): Testing results of the results (list[tuple | numpy.ndarray]): Testing results of the
dataset. dataset.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
Returns: Returns:
tuple: (result_files, tmp_dir), result_files is a dict containing \ tuple: (result_files, tmp_dir), result_files is a dict containing
the json filepaths, tmp_dir is the temporal directory created \ the json filepaths, tmp_dir is the temporal directory created
for saving json files when jsonfile_prefix is not specified. for saving json files when jsonfile_prefix is not specified.
""" """
assert isinstance(results, list), 'results must be a list' assert isinstance(results, list), 'results must be a list'
...@@ -504,15 +507,18 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -504,15 +507,18 @@ class NuScenesMonoDataset(CocoDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str], optional): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing Default: 'bbox'.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
show (bool): Whether to visualize. result_names (list[str], optional): Result names in the
metric prefix. Default: ['img_bbox'].
show (bool, optional): Whether to visualize.
Default: False. Default: False.
out_dir (str): Path to save the visualization results. out_dir (str, optional): Path to save the visualization results.
Default: None. Default: None.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
...@@ -535,7 +541,7 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -535,7 +541,7 @@ class NuScenesMonoDataset(CocoDataset):
if tmp_dir is not None: if tmp_dir is not None:
tmp_dir.cleanup() tmp_dir.cleanup()
if show: if show or out_dir:
self.show(results, out_dir, pipeline=pipeline) self.show(results, out_dir, pipeline=pipeline)
return results_dict return results_dict
...@@ -576,7 +582,7 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -576,7 +582,7 @@ class NuScenesMonoDataset(CocoDataset):
"""Get data loading pipeline in self.show/evaluate function. """Get data loading pipeline in self.show/evaluate function.
Args: Args:
pipeline (list[dict] | None): Input pipeline. If None is given, \ pipeline (list[dict]): Input pipeline. If None is given,
get from self.pipeline. get from self.pipeline.
""" """
if pipeline is None: if pipeline is None:
...@@ -601,13 +607,14 @@ class NuScenesMonoDataset(CocoDataset): ...@@ -601,13 +607,14 @@ class NuScenesMonoDataset(CocoDataset):
] ]
return Compose(pipeline) return Compose(pipeline)
def show(self, results, out_dir, show=True, pipeline=None): def show(self, results, out_dir, show=False, pipeline=None):
"""Results visualization. """Results visualization.
Args: Args:
results (list[dict]): List of bounding boxes results. results (list[dict]): List of bounding boxes results.
out_dir (str): Output directory of visualization result. out_dir (str): Output directory of visualization result.
show (bool): Visualize the results online. show (bool): Whether to visualize the results online.
Default: False.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
""" """
...@@ -696,7 +703,7 @@ def cam_nusc_box_to_global(info, ...@@ -696,7 +703,7 @@ def cam_nusc_box_to_global(info,
boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes. boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes.
classes (list[str]): Mapped classes in the evaluation. classes (list[str]): Mapped classes in the evaluation.
eval_configs (object): Evaluation configuration object. eval_configs (object): Evaluation configuration object.
eval_version (str): Evaluation version. eval_version (str, optional): Evaluation version.
Default: 'detection_cvpr_2019' Default: 'detection_cvpr_2019'
Returns: Returns:
...@@ -736,7 +743,7 @@ def global_nusc_box_to_cam(info, ...@@ -736,7 +743,7 @@ def global_nusc_box_to_cam(info,
boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes. boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes.
classes (list[str]): Mapped classes in the evaluation. classes (list[str]): Mapped classes in the evaluation.
eval_configs (object): Evaluation configuration object. eval_configs (object): Evaluation configuration object.
eval_version (str): Evaluation version. eval_version (str, optional): Evaluation version.
Default: 'detection_cvpr_2019' Default: 'detection_cvpr_2019'
Returns: Returns:
...@@ -769,7 +776,7 @@ def nusc_box_to_cam_box3d(boxes): ...@@ -769,7 +776,7 @@ def nusc_box_to_cam_box3d(boxes):
boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes. boxes (list[:obj:`NuScenesBox`]): List of predicted NuScenesBoxes.
Returns: Returns:
tuple (:obj:`CameraInstance3DBoxes` | torch.Tensor | torch.Tensor): \ tuple (:obj:`CameraInstance3DBoxes` | torch.Tensor | torch.Tensor):
Converted 3D bounding boxes, scores and labels. Converted 3D bounding boxes, scores and labels.
""" """
locs = torch.Tensor([b.center for b in boxes]).view(-1, 3) locs = torch.Tensor([b.center for b in boxes]).view(-1, 3)
......
...@@ -3,17 +3,19 @@ from mmdet.datasets.pipelines import Compose ...@@ -3,17 +3,19 @@ from mmdet.datasets.pipelines import Compose
from .dbsampler import DataBaseSampler from .dbsampler import DataBaseSampler
from .formating import Collect3D, DefaultFormatBundle, DefaultFormatBundle3D from .formating import Collect3D, DefaultFormatBundle, DefaultFormatBundle3D
from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D, from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
LoadMultiViewImageFromFiles, LoadPointsFromFile, LoadMultiViewImageFromFiles, LoadPointsFromDict,
LoadPointsFromMultiSweeps, NormalizePointsColor, LoadPointsFromFile, LoadPointsFromMultiSweeps,
PointSegClassMapping) NormalizePointsColor, PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment, # yapf: disable
GlobalRotScaleTrans, IndoorPatchPointSample, from .transforms_3d import (AffineResize, BackgroundPointsFilter,
IndoorPointSample, ObjectNameFilter, ObjectNoise, GlobalAlignment, GlobalRotScaleTrans,
ObjectRangeFilter, ObjectSample, PointSample, IndoorPatchPointSample, IndoorPointSample,
PointShuffle, PointsRangeFilter, ObjectNameFilter, ObjectNoise, ObjectRangeFilter,
RandomDropPointsColor, RandomFlip3D, ObjectSample, PointSample, PointShuffle,
RandomJitterPoints, VoxelBasedPointSampler) PointsRangeFilter, RandomDropPointsColor,
RandomFlip3D, RandomJitterPoints, RandomShiftScale,
VoxelBasedPointSampler)
__all__ = [ __all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
...@@ -25,5 +27,6 @@ __all__ = [ ...@@ -25,5 +27,6 @@ __all__ = [
'LoadPointsFromMultiSweeps', 'BackgroundPointsFilter', 'LoadPointsFromMultiSweeps', 'BackgroundPointsFilter',
'VoxelBasedPointSampler', 'GlobalAlignment', 'IndoorPatchPointSample', 'VoxelBasedPointSampler', 'GlobalAlignment', 'IndoorPatchPointSample',
'LoadImageFromFileMono3D', 'ObjectNameFilter', 'RandomDropPointsColor', 'LoadImageFromFileMono3D', 'ObjectNameFilter', 'RandomDropPointsColor',
'RandomJitterPoints' 'RandomJitterPoints', 'AffineResize', 'RandomShiftScale',
'LoadPointsFromDict'
] ]
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numba import numba
import numpy as np import numpy as np
import warnings from numba.core.errors import NumbaPerformanceWarning
from numba.errors import NumbaPerformanceWarning
from mmdet3d.core.bbox import box_np_ops from mmdet3d.core.bbox import box_np_ops
...@@ -21,8 +22,8 @@ def _rotation_box2d_jit_(corners, angle, rot_mat_T): ...@@ -21,8 +22,8 @@ def _rotation_box2d_jit_(corners, angle, rot_mat_T):
rot_sin = np.sin(angle) rot_sin = np.sin(angle)
rot_cos = np.cos(angle) rot_cos = np.cos(angle)
rot_mat_T[0, 0] = rot_cos rot_mat_T[0, 0] = rot_cos
rot_mat_T[0, 1] = -rot_sin rot_mat_T[0, 1] = rot_sin
rot_mat_T[1, 0] = rot_sin rot_mat_T[1, 0] = -rot_sin
rot_mat_T[1, 1] = rot_cos rot_mat_T[1, 1] = rot_cos
corners[:] = corners @ rot_mat_T corners[:] = corners @ rot_mat_T
...@@ -34,8 +35,8 @@ def box_collision_test(boxes, qboxes, clockwise=True): ...@@ -34,8 +35,8 @@ def box_collision_test(boxes, qboxes, clockwise=True):
Args: Args:
boxes (np.ndarray): Corners of current boxes. boxes (np.ndarray): Corners of current boxes.
qboxes (np.ndarray): Boxes to be avoid colliding. qboxes (np.ndarray): Boxes to be avoid colliding.
clockwise (bool): Whether the corners are in clockwise order. clockwise (bool, optional): Whether the corners are in
Default: True. clockwise order. Default: True.
""" """
N = boxes.shape[0] N = boxes.shape[0]
K = qboxes.shape[0] K = qboxes.shape[0]
...@@ -211,8 +212,8 @@ def noise_per_box_v2_(boxes, valid_mask, loc_noises, rot_noises, ...@@ -211,8 +212,8 @@ def noise_per_box_v2_(boxes, valid_mask, loc_noises, rot_noises,
rot_sin = np.sin(current_box[0, -1]) rot_sin = np.sin(current_box[0, -1])
rot_cos = np.cos(current_box[0, -1]) rot_cos = np.cos(current_box[0, -1])
rot_mat_T[0, 0] = rot_cos rot_mat_T[0, 0] = rot_cos
rot_mat_T[0, 1] = -rot_sin rot_mat_T[0, 1] = rot_sin
rot_mat_T[1, 0] = rot_sin rot_mat_T[1, 0] = -rot_sin
rot_mat_T[1, 1] = rot_cos rot_mat_T[1, 1] = rot_cos
current_corners[:] = current_box[ current_corners[:] = current_box[
0, 2:4] * corners_norm @ rot_mat_T + current_box[0, :2] 0, 2:4] * corners_norm @ rot_mat_T + current_box[0, :2]
...@@ -264,18 +265,18 @@ def _rotation_matrix_3d_(rot_mat_T, angle, axis): ...@@ -264,18 +265,18 @@ def _rotation_matrix_3d_(rot_mat_T, angle, axis):
rot_mat_T[:] = np.eye(3) rot_mat_T[:] = np.eye(3)
if axis == 1: if axis == 1:
rot_mat_T[0, 0] = rot_cos rot_mat_T[0, 0] = rot_cos
rot_mat_T[0, 2] = -rot_sin rot_mat_T[0, 2] = rot_sin
rot_mat_T[2, 0] = rot_sin rot_mat_T[2, 0] = -rot_sin
rot_mat_T[2, 2] = rot_cos rot_mat_T[2, 2] = rot_cos
elif axis == 2 or axis == -1: elif axis == 2 or axis == -1:
rot_mat_T[0, 0] = rot_cos rot_mat_T[0, 0] = rot_cos
rot_mat_T[0, 1] = -rot_sin rot_mat_T[0, 1] = rot_sin
rot_mat_T[1, 0] = rot_sin rot_mat_T[1, 0] = -rot_sin
rot_mat_T[1, 1] = rot_cos rot_mat_T[1, 1] = rot_cos
elif axis == 0: elif axis == 0:
rot_mat_T[1, 1] = rot_cos rot_mat_T[1, 1] = rot_cos
rot_mat_T[1, 2] = -rot_sin rot_mat_T[1, 2] = rot_sin
rot_mat_T[2, 1] = rot_sin rot_mat_T[2, 1] = -rot_sin
rot_mat_T[2, 2] = rot_cos rot_mat_T[2, 2] = rot_cos
...@@ -317,7 +318,7 @@ def box3d_transform_(boxes, loc_transform, rot_transform, valid_mask): ...@@ -317,7 +318,7 @@ def box3d_transform_(boxes, loc_transform, rot_transform, valid_mask):
boxes (np.ndarray): 3D boxes to be transformed. boxes (np.ndarray): 3D boxes to be transformed.
loc_transform (np.ndarray): Location transform to be applied. loc_transform (np.ndarray): Location transform to be applied.
rot_transform (np.ndarray): Rotation transform to be applied. rot_transform (np.ndarray): Rotation transform to be applied.
valid_mask (np.ndarray | None): Mask to indicate which boxes are valid. valid_mask (np.ndarray): Mask to indicate which boxes are valid.
""" """
num_box = boxes.shape[0] num_box = boxes.shape[0]
for i in range(num_box): for i in range(num_box):
...@@ -338,16 +339,17 @@ def noise_per_object_v3_(gt_boxes, ...@@ -338,16 +339,17 @@ def noise_per_object_v3_(gt_boxes,
Args: Args:
gt_boxes (np.ndarray): Ground truth boxes with shape (N, 7). gt_boxes (np.ndarray): Ground truth boxes with shape (N, 7).
points (np.ndarray | None): Input point cloud with shape (M, 4). points (np.ndarray, optional): Input point cloud with
Default: None. shape (M, 4). Default: None.
valid_mask (np.ndarray | None): Mask to indicate which boxes are valid. valid_mask (np.ndarray, optional): Mask to indicate which
Default: None. boxes are valid. Default: None.
rotation_perturb (float): Rotation perturbation. Default: pi / 4. rotation_perturb (float, optional): Rotation perturbation.
center_noise_std (float): Center noise standard deviation. Default: pi / 4.
center_noise_std (float, optional): Center noise standard deviation.
Default: 1.0. Default: 1.0.
global_random_rot_range (float): Global random rotation range. global_random_rot_range (float, optional): Global random rotation
Default: pi/4. range. Default: pi/4.
num_try (int): Number of try. Default: 100. num_try (int, optional): Number of try. Default: 100.
""" """
num_boxes = gt_boxes.shape[0] num_boxes = gt_boxes.shape[0]
if not isinstance(rotation_perturb, (list, tuple, np.ndarray)): if not isinstance(rotation_perturb, (list, tuple, np.ndarray)):
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import os
import mmcv import mmcv
import numpy as np import numpy as np
import os
from mmdet3d.core.bbox import box_np_ops from mmdet3d.core.bbox import box_np_ops
from mmdet3d.datasets.pipelines import data_augment_utils from mmdet3d.datasets.pipelines import data_augment_utils
...@@ -15,10 +16,10 @@ class BatchSampler: ...@@ -15,10 +16,10 @@ class BatchSampler:
Args: Args:
sample_list (list[dict]): List of samples. sample_list (list[dict]): List of samples.
name (str | None): The category of samples. Default: None. name (str, optional): The category of samples. Default: None.
epoch (int | None): Sampling epoch. Default: None. epoch (int, optional): Sampling epoch. Default: None.
shuffle (bool): Whether to shuffle indices. Default: False. shuffle (bool, optional): Whether to shuffle indices. Default: False.
drop_reminder (bool): Drop reminder. Default: False. drop_reminder (bool, optional): Drop reminder. Default: False.
""" """
def __init__(self, def __init__(self,
...@@ -87,9 +88,9 @@ class DataBaseSampler(object): ...@@ -87,9 +88,9 @@ class DataBaseSampler(object):
rate (float): Rate of actual sampled over maximum sampled number. rate (float): Rate of actual sampled over maximum sampled number.
prepare (dict): Name of preparation functions and the input value. prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers. sample_groups (dict): Sampled classes and numbers.
classes (list[str]): List of classes. Default: None. classes (list[str], optional): List of classes. Default: None.
points_loader(dict): Config of points loader. Default: dict( points_loader(dict, optional): Config of points loader. Default:
type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3]) dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3])
""" """
def __init__(self, def __init__(self,
...@@ -188,7 +189,7 @@ class DataBaseSampler(object): ...@@ -188,7 +189,7 @@ class DataBaseSampler(object):
db_infos[name] = filtered_infos db_infos[name] = filtered_infos
return db_infos return db_infos
def sample_all(self, gt_bboxes, gt_labels, img=None): def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None):
"""Sampling all categories of bboxes. """Sampling all categories of bboxes.
Args: Args:
...@@ -198,9 +199,9 @@ class DataBaseSampler(object): ...@@ -198,9 +199,9 @@ class DataBaseSampler(object):
Returns: Returns:
dict: Dict of sampled 'pseudo ground truths'. dict: Dict of sampled 'pseudo ground truths'.
- gt_labels_3d (np.ndarray): ground truths labels \ - gt_labels_3d (np.ndarray): ground truths labels
of sampled objects. of sampled objects.
- gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): \ - gt_bboxes_3d (:obj:`BaseInstance3DBoxes`):
sampled ground truth 3D bounding boxes sampled ground truth 3D bounding boxes
- points (np.ndarray): sampled points - points (np.ndarray): sampled points
- group_ids (np.ndarray): ids of sampled ground truths - group_ids (np.ndarray): ids of sampled ground truths
...@@ -263,6 +264,15 @@ class DataBaseSampler(object): ...@@ -263,6 +264,15 @@ class DataBaseSampler(object):
gt_labels = np.array([self.cat2label[s['name']] for s in sampled], gt_labels = np.array([self.cat2label[s['name']] for s in sampled],
dtype=np.long) dtype=np.long)
if ground_plane is not None:
xyz = sampled_gt_bboxes[:, :3]
dz = (ground_plane[:3][None, :] *
xyz).sum(-1) + ground_plane[3]
sampled_gt_bboxes[:, 2] -= dz
for i, s_points in enumerate(s_points_list):
s_points.tensor[:, 2].sub_(dz[i])
ret = { ret = {
'gt_labels_3d': 'gt_labels_3d':
gt_labels, gt_labels,
......
...@@ -24,7 +24,7 @@ class DefaultFormatBundle(object): ...@@ -24,7 +24,7 @@ class DefaultFormatBundle(object):
- gt_bboxes_ignore: (1)to tensor, (2)to DataContainer - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
- gt_labels: (1)to tensor, (2)to DataContainer - gt_labels: (1)to tensor, (2)to DataContainer
- gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True) - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
- gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \ - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,
(3)to DataContainer (stack=True) (3)to DataContainer (stack=True)
""" """
...@@ -92,8 +92,8 @@ class Collect3D(object): ...@@ -92,8 +92,8 @@ class Collect3D(object):
The "img_meta" item is always populated. The contents of the "img_meta" The "img_meta" item is always populated. The contents of the "img_meta"
dictionary depends on "meta_keys". By default this includes: dictionary depends on "meta_keys". By default this includes:
- 'img_shape': shape of the image input to the network as a tuple \ - 'img_shape': shape of the image input to the network as a tuple
(h, w, c). Note that images may be zero padded on the \ (h, w, c). Note that images may be zero padded on the
bottom/right if the batch tensor is larger than this shape. bottom/right if the batch tensor is larger than this shape.
- 'scale_factor': a float indicating the preprocessing scale - 'scale_factor': a float indicating the preprocessing scale
- 'flip': a boolean indicating if image flip transform was used - 'flip': a boolean indicating if image flip transform was used
...@@ -103,9 +103,9 @@ class Collect3D(object): ...@@ -103,9 +103,9 @@ class Collect3D(object):
- 'lidar2img': transform from lidar to image - 'lidar2img': transform from lidar to image
- 'depth2img': transform from depth to image - 'depth2img': transform from depth to image
- 'cam2img': transform from camera to image - 'cam2img': transform from camera to image
- 'pcd_horizontal_flip': a boolean indicating if point cloud is \ - 'pcd_horizontal_flip': a boolean indicating if point cloud is
flipped horizontally flipped horizontally
- 'pcd_vertical_flip': a boolean indicating if point cloud is \ - 'pcd_vertical_flip': a boolean indicating if point cloud is
flipped vertically flipped vertically
- 'box_mode_3d': 3D box mode - 'box_mode_3d': 3D box mode
- 'box_type_3d': 3D box type - 'box_type_3d': 3D box type
...@@ -130,15 +130,16 @@ class Collect3D(object): ...@@ -130,15 +130,16 @@ class Collect3D(object):
'sample_idx', 'pcd_scale_factor', 'pcd_rotation', 'pts_filename') 'sample_idx', 'pcd_scale_factor', 'pcd_rotation', 'pts_filename')
""" """
def __init__(self, def __init__(
keys, self,
meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img', keys,
'depth2img', 'cam2img', 'pad_shape', meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
'scale_factor', 'flip', 'pcd_horizontal_flip', 'depth2img', 'cam2img', 'pad_shape', 'scale_factor', 'flip',
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d', 'pcd_horizontal_flip', 'pcd_vertical_flip', 'box_mode_3d',
'img_norm_cfg', 'pcd_trans', 'sample_idx', 'box_type_3d', 'img_norm_cfg', 'pcd_trans', 'sample_idx',
'pcd_scale_factor', 'pcd_rotation', 'pts_filename', 'pcd_scale_factor', 'pcd_rotation', 'pcd_rotation_angle',
'transformation_3d_flow')): 'pts_filename', 'transformation_3d_flow', 'trans_mat',
'affine_aug')):
self.keys = keys self.keys = keys
self.meta_keys = meta_keys self.meta_keys = meta_keys
......
...@@ -14,9 +14,10 @@ class LoadMultiViewImageFromFiles(object): ...@@ -14,9 +14,10 @@ class LoadMultiViewImageFromFiles(object):
Expects results['img_filename'] to be a list of filenames. Expects results['img_filename'] to be a list of filenames.
Args: Args:
to_float32 (bool): Whether to convert the img to float32. to_float32 (bool, optional): Whether to convert the img to float32.
Defaults to False. Defaults to False.
color_type (str): Color type of the file. Defaults to 'unchanged'. color_type (str, optional): Color type of the file.
Defaults to 'unchanged'.
""" """
def __init__(self, to_float32=False, color_type='unchanged'): def __init__(self, to_float32=False, color_type='unchanged'):
...@@ -30,7 +31,7 @@ class LoadMultiViewImageFromFiles(object): ...@@ -30,7 +31,7 @@ class LoadMultiViewImageFromFiles(object):
results (dict): Result dict containing multi-view image filenames. results (dict): Result dict containing multi-view image filenames.
Returns: Returns:
dict: The result dict containing the multi-view image data. \ dict: The result dict containing the multi-view image data.
Added keys and values are described below. Added keys and values are described below.
- filename (str): Multi-view image filenames. - filename (str): Multi-view image filenames.
...@@ -48,7 +49,7 @@ class LoadMultiViewImageFromFiles(object): ...@@ -48,7 +49,7 @@ class LoadMultiViewImageFromFiles(object):
if self.to_float32: if self.to_float32:
img = img.astype(np.float32) img = img.astype(np.float32)
results['filename'] = filename results['filename'] = filename
# unravel to list, see `DefaultFormatBundle` in formating.py # unravel to list, see `DefaultFormatBundle` in formatting.py
# which will transpose each image separately and then stack into array # which will transpose each image separately and then stack into array
results['img'] = [img[..., i] for i in range(img.shape[-1])] results['img'] = [img[..., i] for i in range(img.shape[-1])]
results['img_shape'] = img.shape results['img_shape'] = img.shape
...@@ -77,7 +78,7 @@ class LoadImageFromFileMono3D(LoadImageFromFile): ...@@ -77,7 +78,7 @@ class LoadImageFromFileMono3D(LoadImageFromFile):
detection, additional camera parameters need to be loaded. detection, additional camera parameters need to be loaded.
Args: Args:
kwargs (dict): Arguments are the same as those in \ kwargs (dict): Arguments are the same as those in
:class:`LoadImageFromFile`. :class:`LoadImageFromFile`.
""" """
...@@ -102,17 +103,20 @@ class LoadPointsFromMultiSweeps(object): ...@@ -102,17 +103,20 @@ class LoadPointsFromMultiSweeps(object):
This is usually used for nuScenes dataset to utilize previous sweeps. This is usually used for nuScenes dataset to utilize previous sweeps.
Args: Args:
sweeps_num (int): Number of sweeps. Defaults to 10. sweeps_num (int, optional): Number of sweeps. Defaults to 10.
load_dim (int): Dimension number of the loaded points. Defaults to 5. load_dim (int, optional): Dimension number of the loaded points.
use_dim (list[int]): Which dimension to use. Defaults to [0, 1, 2, 4]. Defaults to 5.
file_client_args (dict): Config dict of file clients, refer to use_dim (list[int], optional): Which dimension to use.
Defaults to [0, 1, 2, 4].
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
for more details. Defaults to dict(backend='disk'). for more details. Defaults to dict(backend='disk').
pad_empty_sweeps (bool): Whether to repeat keyframe when pad_empty_sweeps (bool, optional): Whether to repeat keyframe when
sweeps is empty. Defaults to False. sweeps is empty. Defaults to False.
remove_close (bool): Whether to remove close points. remove_close (bool, optional): Whether to remove close points.
Defaults to False. Defaults to False.
test_mode (bool): If test_model=True used for testing, it will not test_mode (bool, optional): If `test_mode=True`, it will not
randomly sample sweeps but select the nearest N frames. randomly sample sweeps but select the nearest N frames.
Defaults to False. Defaults to False.
""" """
...@@ -161,7 +165,7 @@ class LoadPointsFromMultiSweeps(object): ...@@ -161,7 +165,7 @@ class LoadPointsFromMultiSweeps(object):
Args: Args:
points (np.ndarray | :obj:`BasePoints`): Sweep points. points (np.ndarray | :obj:`BasePoints`): Sweep points.
radius (float): Radius below which points are removed. radius (float, optional): Radius below which points are removed.
Defaults to 1.0. Defaults to 1.0.
Returns: Returns:
...@@ -182,14 +186,14 @@ class LoadPointsFromMultiSweeps(object): ...@@ -182,14 +186,14 @@ class LoadPointsFromMultiSweeps(object):
"""Call function to load multi-sweep point clouds from files. """Call function to load multi-sweep point clouds from files.
Args: Args:
results (dict): Result dict containing multi-sweep point cloud \ results (dict): Result dict containing multi-sweep point cloud
filenames. filenames.
Returns: Returns:
dict: The result dict containing the multi-sweep points data. \ dict: The result dict containing the multi-sweep points data.
Added key and value are described below. Added key and value are described below.
- points (np.ndarray | :obj:`BasePoints`): Multi-sweep point \ - points (np.ndarray | :obj:`BasePoints`): Multi-sweep point
cloud arrays. cloud arrays.
""" """
points = results['points'] points = results['points']
...@@ -243,8 +247,8 @@ class PointSegClassMapping(object): ...@@ -243,8 +247,8 @@ class PointSegClassMapping(object):
Args: Args:
valid_cat_ids (tuple[int]): A tuple of valid category. valid_cat_ids (tuple[int]): A tuple of valid category.
max_cat_id (int): The max possible cat_id in input segmentation mask. max_cat_id (int, optional): The max possible cat_id in input
Defaults to 40. segmentation mask. Defaults to 40.
""" """
def __init__(self, valid_cat_ids, max_cat_id=40): def __init__(self, valid_cat_ids, max_cat_id=40):
...@@ -268,7 +272,7 @@ class PointSegClassMapping(object): ...@@ -268,7 +272,7 @@ class PointSegClassMapping(object):
results (dict): Result dict containing point semantic masks. results (dict): Result dict containing point semantic masks.
Returns: Returns:
dict: The result dict containing the mapped category ids. \ dict: The result dict containing the mapped category ids.
Updated key and value are described below. Updated key and value are described below.
- pts_semantic_mask (np.ndarray): Mapped semantic masks. - pts_semantic_mask (np.ndarray): Mapped semantic masks.
...@@ -307,7 +311,7 @@ class NormalizePointsColor(object): ...@@ -307,7 +311,7 @@ class NormalizePointsColor(object):
results (dict): Result dict containing point clouds data. results (dict): Result dict containing point clouds data.
Returns: Returns:
dict: The result dict containing the normalized points. \ dict: The result dict containing the normalized points.
Updated key and value are described below. Updated key and value are described below.
- points (:obj:`BasePoints`): Points after color normalization. - points (:obj:`BasePoints`): Points after color normalization.
...@@ -334,7 +338,7 @@ class NormalizePointsColor(object): ...@@ -334,7 +338,7 @@ class NormalizePointsColor(object):
class LoadPointsFromFile(object): class LoadPointsFromFile(object):
"""Load Points From File. """Load Points From File.
Load sunrgbd and scannet points from file. Load points from file.
Args: Args:
coord_type (str): The type of coordinates of points cloud. coord_type (str): The type of coordinates of points cloud.
...@@ -342,14 +346,17 @@ class LoadPointsFromFile(object): ...@@ -342,14 +346,17 @@ class LoadPointsFromFile(object):
- 'LIDAR': Points in LiDAR coordinates. - 'LIDAR': Points in LiDAR coordinates.
- 'DEPTH': Points in depth coordinates, usually for indoor dataset. - 'DEPTH': Points in depth coordinates, usually for indoor dataset.
- 'CAMERA': Points in camera coordinates. - 'CAMERA': Points in camera coordinates.
load_dim (int): The dimension of the loaded points. load_dim (int, optional): The dimension of the loaded points.
Defaults to 6. Defaults to 6.
use_dim (list[int]): Which dimensions of the points to be used. use_dim (list[int], optional): Which dimensions of the points to use.
Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4 Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4
or use_dim=[0, 1, 2, 3] to use the intensity dimension. or use_dim=[0, 1, 2, 3] to use the intensity dimension.
shift_height (bool): Whether to use shifted height. Defaults to False. shift_height (bool, optional): Whether to use shifted height.
use_color (bool): Whether to use color features. Defaults to False. Defaults to False.
file_client_args (dict): Config dict of file clients, refer to use_color (bool, optional): Whether to use color features.
Defaults to False.
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
for more details. Defaults to dict(backend='disk'). for more details. Defaults to dict(backend='disk').
""" """
...@@ -405,7 +412,7 @@ class LoadPointsFromFile(object): ...@@ -405,7 +412,7 @@ class LoadPointsFromFile(object):
results (dict): Result dict containing point clouds data. results (dict): Result dict containing point clouds data.
Returns: Returns:
dict: The result dict containing the point clouds data. \ dict: The result dict containing the point clouds data.
Added key and value are described below. Added key and value are described below.
- points (:obj:`BasePoints`): Point clouds data. - points (:obj:`BasePoints`): Point clouds data.
...@@ -453,6 +460,15 @@ class LoadPointsFromFile(object): ...@@ -453,6 +460,15 @@ class LoadPointsFromFile(object):
return repr_str return repr_str
@PIPELINES.register_module()
class LoadPointsFromDict(LoadPointsFromFile):
"""Load Points From Dict."""
def __call__(self, results):
assert 'points' in results
return results
@PIPELINES.register_module() @PIPELINES.register_module()
class LoadAnnotations3D(LoadAnnotations): class LoadAnnotations3D(LoadAnnotations):
"""Load Annotations3D. """Load Annotations3D.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import warnings import warnings
from copy import deepcopy from copy import deepcopy
import mmcv
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import Compose from mmdet.datasets.pipelines import Compose
...@@ -16,18 +17,19 @@ class MultiScaleFlipAug3D(object): ...@@ -16,18 +17,19 @@ class MultiScaleFlipAug3D(object):
img_scale (tuple | list[tuple]: Images scales for resizing. img_scale (tuple | list[tuple]: Images scales for resizing.
pts_scale_ratio (float | list[float]): Points scale ratios for pts_scale_ratio (float | list[float]): Points scale ratios for
resizing. resizing.
flip (bool): Whether apply flip augmentation. Defaults to False. flip (bool, optional): Whether apply flip augmentation.
flip_direction (str | list[str]): Flip augmentation directions Defaults to False.
for images, options are "horizontal" and "vertical". flip_direction (str | list[str], optional): Flip augmentation
directions for images, options are "horizontal" and "vertical".
If flip_direction is list, multiple flip augmentations will If flip_direction is list, multiple flip augmentations will
be applied. It has no effect when ``flip == False``. be applied. It has no effect when ``flip == False``.
Defaults to "horizontal". Defaults to "horizontal".
pcd_horizontal_flip (bool): Whether apply horizontal flip augmentation pcd_horizontal_flip (bool, optional): Whether apply horizontal
to point cloud. Defaults to True. Note that it works only when flip augmentation to point cloud. Defaults to True.
'flip' is turned on. Note that it works only when 'flip' is turned on.
pcd_vertical_flip (bool): Whether apply vertical flip augmentation pcd_vertical_flip (bool, optional): Whether apply vertical flip
to point cloud. Defaults to True. Note that it works only when augmentation to point cloud. Defaults to True.
'flip' is turned on. Note that it works only when 'flip' is turned on.
""" """
def __init__(self, def __init__(self,
...@@ -70,7 +72,7 @@ class MultiScaleFlipAug3D(object): ...@@ -70,7 +72,7 @@ class MultiScaleFlipAug3D(object):
results (dict): Result dict contains the data to augment. results (dict): Result dict contains the data to augment.
Returns: Returns:
dict: The result dict contains the data that is augmented with \ dict: The result dict contains the data that is augmented with
different scales and flips. different scales and flips.
""" """
aug_data = [] aug_data = []
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np import random
import warnings import warnings
import cv2
import numpy as np
from mmcv import is_tuple_of from mmcv import is_tuple_of
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
...@@ -22,7 +25,7 @@ class RandomDropPointsColor(object): ...@@ -22,7 +25,7 @@ class RandomDropPointsColor(object):
util/transform.py#L223>`_ for more details. util/transform.py#L223>`_ for more details.
Args: Args:
drop_ratio (float): The probability of dropping point colors. drop_ratio (float, optional): The probability of dropping point colors.
Defaults to 0.2. Defaults to 0.2.
""" """
...@@ -38,7 +41,7 @@ class RandomDropPointsColor(object): ...@@ -38,7 +41,7 @@ class RandomDropPointsColor(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after color dropping, \ dict: Results after color dropping,
'points' key is updated in the result dict. 'points' key is updated in the result dict.
""" """
points = input_dict['points'] points = input_dict['points']
...@@ -105,10 +108,11 @@ class RandomFlip3D(RandomFlip): ...@@ -105,10 +108,11 @@ class RandomFlip3D(RandomFlip):
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
direction (str): Flip direction. Default: horizontal. direction (str, optional): Flip direction.
Default: 'horizontal'.
Returns: Returns:
dict: Flipped results, 'points', 'bbox3d_fields' keys are \ dict: Flipped results, 'points', 'bbox3d_fields' keys are
updated in the result dict. updated in the result dict.
""" """
assert direction in ['horizontal', 'vertical'] assert direction in ['horizontal', 'vertical']
...@@ -141,15 +145,15 @@ class RandomFlip3D(RandomFlip): ...@@ -141,15 +145,15 @@ class RandomFlip3D(RandomFlip):
input_dict['cam2img'][0][2] = w - input_dict['cam2img'][0][2] input_dict['cam2img'][0][2] = w - input_dict['cam2img'][0][2]
def __call__(self, input_dict): def __call__(self, input_dict):
"""Call function to flip points, values in the ``bbox3d_fields`` and \ """Call function to flip points, values in the ``bbox3d_fields`` and
also flip 2D image and its annotations. also flip 2D image and its annotations.
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Flipped results, 'flip', 'flip_direction', \ dict: Flipped results, 'flip', 'flip_direction',
'pcd_horizontal_flip' and 'pcd_vertical_flip' keys are added \ 'pcd_horizontal_flip' and 'pcd_vertical_flip' keys are added
into result dict. into result dict.
""" """
# flip 2D image and its annotations # flip 2D image and its annotations
...@@ -191,20 +195,20 @@ class RandomFlip3D(RandomFlip): ...@@ -191,20 +195,20 @@ class RandomFlip3D(RandomFlip):
class RandomJitterPoints(object): class RandomJitterPoints(object):
"""Randomly jitter point coordinates. """Randomly jitter point coordinates.
Different from the global translation in ``GlobalRotScaleTrans``, here we \ Different from the global translation in ``GlobalRotScaleTrans``, here we
apply different noises to each point in a scene. apply different noises to each point in a scene.
Args: Args:
jitter_std (list[float]): The standard deviation of jittering noise. jitter_std (list[float]): The standard deviation of jittering noise.
This applies random noise to all points in a 3D scene, which is \ This applies random noise to all points in a 3D scene, which is
sampled from a gaussian distribution whose standard deviation is \ sampled from a gaussian distribution whose standard deviation is
set by ``jitter_std``. Defaults to [0.01, 0.01, 0.01] set by ``jitter_std``. Defaults to [0.01, 0.01, 0.01]
clip_range (list[float] | None): Clip the randomly generated jitter \ clip_range (list[float]): Clip the randomly generated jitter
noise into this range. If None is given, don't perform clipping. noise into this range. If None is given, don't perform clipping.
Defaults to [-0.05, 0.05] Defaults to [-0.05, 0.05]
Note: Note:
This transform should only be used in point cloud segmentation tasks \ This transform should only be used in point cloud segmentation tasks
because we don't transform ground-truth bboxes accordingly. because we don't transform ground-truth bboxes accordingly.
For similar transform in detection task, please refer to `ObjectNoise`. For similar transform in detection task, please refer to `ObjectNoise`.
""" """
...@@ -233,7 +237,7 @@ class RandomJitterPoints(object): ...@@ -233,7 +237,7 @@ class RandomJitterPoints(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after adding noise to each point, \ dict: Results after adding noise to each point,
'points' key is updated in the result dict. 'points' key is updated in the result dict.
""" """
points = input_dict['points'] points = input_dict['points']
...@@ -264,14 +268,17 @@ class ObjectSample(object): ...@@ -264,14 +268,17 @@ class ObjectSample(object):
sample_2d (bool): Whether to also paste 2D image patch to the images sample_2d (bool): Whether to also paste 2D image patch to the images
This should be true when applying multi-modality cut-and-paste. This should be true when applying multi-modality cut-and-paste.
Defaults to False. Defaults to False.
use_ground_plane (bool): Whether to use gound plane to adjust the
3D labels.
""" """
def __init__(self, db_sampler, sample_2d=False): def __init__(self, db_sampler, sample_2d=False, use_ground_plane=False):
self.sampler_cfg = db_sampler self.sampler_cfg = db_sampler
self.sample_2d = sample_2d self.sample_2d = sample_2d
if 'type' not in db_sampler.keys(): if 'type' not in db_sampler.keys():
db_sampler['type'] = 'DataBaseSampler' db_sampler['type'] = 'DataBaseSampler'
self.db_sampler = build_from_cfg(db_sampler, OBJECTSAMPLERS) self.db_sampler = build_from_cfg(db_sampler, OBJECTSAMPLERS)
self.use_ground_plane = use_ground_plane
@staticmethod @staticmethod
def remove_points_in_boxes(points, boxes): def remove_points_in_boxes(points, boxes):
...@@ -295,13 +302,18 @@ class ObjectSample(object): ...@@ -295,13 +302,18 @@ class ObjectSample(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after object sampling augmentation, \ dict: Results after object sampling augmentation,
'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated \ 'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated
in the result dict. in the result dict.
""" """
gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_bboxes_3d = input_dict['gt_bboxes_3d']
gt_labels_3d = input_dict['gt_labels_3d'] gt_labels_3d = input_dict['gt_labels_3d']
if self.use_ground_plane and 'plane' in input_dict['ann_info']:
ground_plane = input_dict['ann_info']['plane']
input_dict['plane'] = ground_plane
else:
ground_plane = None
# change to float for blending operation # change to float for blending operation
points = input_dict['points'] points = input_dict['points']
if self.sample_2d: if self.sample_2d:
...@@ -315,7 +327,10 @@ class ObjectSample(object): ...@@ -315,7 +327,10 @@ class ObjectSample(object):
img=img) img=img)
else: else:
sampled_dict = self.db_sampler.sample_all( sampled_dict = self.db_sampler.sample_all(
gt_bboxes_3d.tensor.numpy(), gt_labels_3d, img=None) gt_bboxes_3d.tensor.numpy(),
gt_labels_3d,
img=None,
ground_plane=ground_plane)
if sampled_dict is not None: if sampled_dict is not None:
sampled_gt_bboxes_3d = sampled_dict['gt_bboxes_3d'] sampled_gt_bboxes_3d = sampled_dict['gt_bboxes_3d']
...@@ -392,13 +407,13 @@ class ObjectNoise(object): ...@@ -392,13 +407,13 @@ class ObjectNoise(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after adding noise to each object, \ dict: Results after adding noise to each object,
'points', 'gt_bboxes_3d' keys are updated in the result dict. 'points', 'gt_bboxes_3d' keys are updated in the result dict.
""" """
gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_bboxes_3d = input_dict['gt_bboxes_3d']
points = input_dict['points'] points = input_dict['points']
# TODO: check this inplace function # TODO: this is inplace operation
numpy_box = gt_bboxes_3d.tensor.numpy() numpy_box = gt_bboxes_3d.tensor.numpy()
numpy_points = points.tensor.numpy() numpy_points = points.tensor.numpy()
...@@ -432,10 +447,10 @@ class GlobalAlignment(object): ...@@ -432,10 +447,10 @@ class GlobalAlignment(object):
rotation_axis (int): Rotation axis for points and bboxes rotation. rotation_axis (int): Rotation axis for points and bboxes rotation.
Note: Note:
We do not record the applied rotation and translation as in \ We do not record the applied rotation and translation as in
GlobalRotScaleTrans. Because usually, we do not need to reverse \ GlobalRotScaleTrans. Because usually, we do not need to reverse
the alignment step. the alignment step.
For example, ScanNet 3D detection task uses aligned ground-truth \ For example, ScanNet 3D detection task uses aligned ground-truth
bounding boxes for evaluation. bounding boxes for evaluation.
""" """
...@@ -487,7 +502,7 @@ class GlobalAlignment(object): ...@@ -487,7 +502,7 @@ class GlobalAlignment(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after global alignment, 'points' and keys in \ dict: Results after global alignment, 'points' and keys in
input_dict['bbox3d_fields'] are updated in the result dict. input_dict['bbox3d_fields'] are updated in the result dict.
""" """
assert 'axis_align_matrix' in input_dict['ann_info'].keys(), \ assert 'axis_align_matrix' in input_dict['ann_info'].keys(), \
...@@ -516,15 +531,15 @@ class GlobalRotScaleTrans(object): ...@@ -516,15 +531,15 @@ class GlobalRotScaleTrans(object):
"""Apply global rotation, scaling and translation to a 3D scene. """Apply global rotation, scaling and translation to a 3D scene.
Args: Args:
rot_range (list[float]): Range of rotation angle. rot_range (list[float], optional): Range of rotation angle.
Defaults to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]). Defaults to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]).
scale_ratio_range (list[float]): Range of scale ratio. scale_ratio_range (list[float], optional): Range of scale ratio.
Defaults to [0.95, 1.05]. Defaults to [0.95, 1.05].
translation_std (list[float]): The standard deviation of translation translation_std (list[float], optional): The standard deviation of
noise. This applies random translation to a scene by a noise, which translation noise applied to a scene, which
is sampled from a gaussian distribution whose standard deviation is sampled from a gaussian distribution whose standard deviation
is set by ``translation_std``. Defaults to [0, 0, 0] is set by ``translation_std``. Defaults to [0, 0, 0]
shift_height (bool): Whether to shift height. shift_height (bool, optional): Whether to shift height.
(the fourth dimension of indoor points) when scaling. (the fourth dimension of indoor points) when scaling.
Defaults to False. Defaults to False.
""" """
...@@ -563,8 +578,8 @@ class GlobalRotScaleTrans(object): ...@@ -563,8 +578,8 @@ class GlobalRotScaleTrans(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after translation, 'points', 'pcd_trans' \ dict: Results after translation, 'points', 'pcd_trans'
and keys in input_dict['bbox3d_fields'] are updated \ and keys in input_dict['bbox3d_fields'] are updated
in the result dict. in the result dict.
""" """
translation_std = np.array(self.translation_std, dtype=np.float32) translation_std = np.array(self.translation_std, dtype=np.float32)
...@@ -582,8 +597,8 @@ class GlobalRotScaleTrans(object): ...@@ -582,8 +597,8 @@ class GlobalRotScaleTrans(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after rotation, 'points', 'pcd_rotation' \ dict: Results after rotation, 'points', 'pcd_rotation'
and keys in input_dict['bbox3d_fields'] are updated \ and keys in input_dict['bbox3d_fields'] are updated
in the result dict. in the result dict.
""" """
rotation = self.rot_range rotation = self.rot_range
...@@ -593,6 +608,7 @@ class GlobalRotScaleTrans(object): ...@@ -593,6 +608,7 @@ class GlobalRotScaleTrans(object):
if len(input_dict['bbox3d_fields']) == 0: if len(input_dict['bbox3d_fields']) == 0:
rot_mat_T = input_dict['points'].rotate(noise_rotation) rot_mat_T = input_dict['points'].rotate(noise_rotation)
input_dict['pcd_rotation'] = rot_mat_T input_dict['pcd_rotation'] = rot_mat_T
input_dict['pcd_rotation_angle'] = noise_rotation
return return
# rotate points with bboxes # rotate points with bboxes
...@@ -602,6 +618,7 @@ class GlobalRotScaleTrans(object): ...@@ -602,6 +618,7 @@ class GlobalRotScaleTrans(object):
noise_rotation, input_dict['points']) noise_rotation, input_dict['points'])
input_dict['points'] = points input_dict['points'] = points
input_dict['pcd_rotation'] = rot_mat_T input_dict['pcd_rotation'] = rot_mat_T
input_dict['pcd_rotation_angle'] = noise_rotation
def _scale_bbox_points(self, input_dict): def _scale_bbox_points(self, input_dict):
"""Private function to scale bounding boxes and points. """Private function to scale bounding boxes and points.
...@@ -610,7 +627,7 @@ class GlobalRotScaleTrans(object): ...@@ -610,7 +627,7 @@ class GlobalRotScaleTrans(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after scaling, 'points'and keys in \ dict: Results after scaling, 'points'and keys in
input_dict['bbox3d_fields'] are updated in the result dict. input_dict['bbox3d_fields'] are updated in the result dict.
""" """
scale = input_dict['pcd_scale_factor'] scale = input_dict['pcd_scale_factor']
...@@ -632,7 +649,7 @@ class GlobalRotScaleTrans(object): ...@@ -632,7 +649,7 @@ class GlobalRotScaleTrans(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after scaling, 'pcd_scale_factor' are updated \ dict: Results after scaling, 'pcd_scale_factor' are updated
in the result dict. in the result dict.
""" """
scale_factor = np.random.uniform(self.scale_ratio_range[0], scale_factor = np.random.uniform(self.scale_ratio_range[0],
...@@ -640,7 +657,7 @@ class GlobalRotScaleTrans(object): ...@@ -640,7 +657,7 @@ class GlobalRotScaleTrans(object):
input_dict['pcd_scale_factor'] = scale_factor input_dict['pcd_scale_factor'] = scale_factor
def __call__(self, input_dict): def __call__(self, input_dict):
"""Private function to rotate, scale and translate bounding boxes and \ """Private function to rotate, scale and translate bounding boxes and
points. points.
Args: Args:
...@@ -648,7 +665,7 @@ class GlobalRotScaleTrans(object): ...@@ -648,7 +665,7 @@ class GlobalRotScaleTrans(object):
Returns: Returns:
dict: Results after scaling, 'points', 'pcd_rotation', dict: Results after scaling, 'points', 'pcd_rotation',
'pcd_scale_factor', 'pcd_trans' and keys in \ 'pcd_scale_factor', 'pcd_trans' and keys in
input_dict['bbox3d_fields'] are updated in the result dict. input_dict['bbox3d_fields'] are updated in the result dict.
""" """
if 'transformation_3d_flow' not in input_dict: if 'transformation_3d_flow' not in input_dict:
...@@ -686,7 +703,7 @@ class PointShuffle(object): ...@@ -686,7 +703,7 @@ class PointShuffle(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after filtering, 'points', 'pts_instance_mask' \ dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
idx = input_dict['points'].shuffle() idx = input_dict['points'].shuffle()
...@@ -725,7 +742,7 @@ class ObjectRangeFilter(object): ...@@ -725,7 +742,7 @@ class ObjectRangeFilter(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \ dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
keys are updated in the result dict. keys are updated in the result dict.
""" """
# Check points instance type and initialise bev_range # Check points instance type and initialise bev_range
...@@ -777,7 +794,7 @@ class PointsRangeFilter(object): ...@@ -777,7 +794,7 @@ class PointsRangeFilter(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after filtering, 'points', 'pts_instance_mask' \ dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = input_dict['points'] points = input_dict['points']
...@@ -823,7 +840,7 @@ class ObjectNameFilter(object): ...@@ -823,7 +840,7 @@ class ObjectNameFilter(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \ dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d'
keys are updated in the result dict. keys are updated in the result dict.
""" """
gt_labels_3d = input_dict['gt_labels_3d'] gt_labels_3d = input_dict['gt_labels_3d']
...@@ -891,8 +908,8 @@ class PointSample(object): ...@@ -891,8 +908,8 @@ class PointSample(object):
if sample_range is not None and not replace: if sample_range is not None and not replace:
# Only sampling the near points when len(points) >= num_samples # Only sampling the near points when len(points) >= num_samples
depth = np.linalg.norm(points.tensor, axis=1) depth = np.linalg.norm(points.tensor, axis=1)
far_inds = np.where(depth > sample_range)[0] far_inds = np.where(depth >= sample_range)[0]
near_inds = np.where(depth <= sample_range)[0] near_inds = np.where(depth < sample_range)[0]
# in case there are too many far points # in case there are too many far points
if len(far_inds) > num_samples: if len(far_inds) > num_samples:
far_inds = np.random.choice( far_inds = np.random.choice(
...@@ -915,7 +932,7 @@ class PointSample(object): ...@@ -915,7 +932,7 @@ class PointSample(object):
Args: Args:
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \ dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = results['points'] points = results['points']
...@@ -996,10 +1013,10 @@ class IndoorPatchPointSample(object): ...@@ -996,10 +1013,10 @@ class IndoorPatchPointSample(object):
additional features. Defaults to False. additional features. Defaults to False.
num_try (int, optional): Number of times to try if the patch selected num_try (int, optional): Number of times to try if the patch selected
is invalid. Defaults to 10. is invalid. Defaults to 10.
enlarge_size (float | None, optional): Enlarge the sampled patch to enlarge_size (float, optional): Enlarge the sampled patch to
[-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as [-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as
an augmentation. If None, set it as 0. Defaults to 0.2. an augmentation. If None, set it as 0. Defaults to 0.2.
min_unique_num (int | None, optional): Minimum number of unique points min_unique_num (int, optional): Minimum number of unique points
the sampled patch should contain. If None, use PointNet++'s method the sampled patch should contain. If None, use PointNet++'s method
to judge uniqueness. Defaults to None. to judge uniqueness. Defaults to None.
eps (float, optional): A value added to patch boundary to guarantee eps (float, optional): A value added to patch boundary to guarantee
...@@ -1040,7 +1057,7 @@ class IndoorPatchPointSample(object): ...@@ -1040,7 +1057,7 @@ class IndoorPatchPointSample(object):
attribute_dims, point_type): attribute_dims, point_type):
"""Generating model input. """Generating model input.
Generate input by subtracting patch center and adding additional \ Generate input by subtracting patch center and adding additional
features. Currently support colors and normalized xyz as features. features. Currently support colors and normalized xyz as features.
Args: Args:
...@@ -1184,7 +1201,7 @@ class IndoorPatchPointSample(object): ...@@ -1184,7 +1201,7 @@ class IndoorPatchPointSample(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \ dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = results['points'] points = results['points']
...@@ -1244,7 +1261,7 @@ class BackgroundPointsFilter(object): ...@@ -1244,7 +1261,7 @@ class BackgroundPointsFilter(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after filtering, 'points', 'pts_instance_mask' \ dict: Results after filtering, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = input_dict['points'] points = input_dict['points']
...@@ -1342,7 +1359,7 @@ class VoxelBasedPointSampler(object): ...@@ -1342,7 +1359,7 @@ class VoxelBasedPointSampler(object):
input_dict (dict): Result dict from loading pipeline. input_dict (dict): Result dict from loading pipeline.
Returns: Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \ dict: Results after sampling, 'points', 'pts_instance_mask'
and 'pts_semantic_mask' keys are updated in the result dict. and 'pts_semantic_mask' keys are updated in the result dict.
""" """
points = results['points'] points = results['points']
...@@ -1423,3 +1440,258 @@ class VoxelBasedPointSampler(object): ...@@ -1423,3 +1440,258 @@ class VoxelBasedPointSampler(object):
repr_str += ' ' * indent + 'prev_voxel_generator=\n' repr_str += ' ' * indent + 'prev_voxel_generator=\n'
repr_str += f'{_auto_indent(repr(self.prev_voxel_generator), 8)})' repr_str += f'{_auto_indent(repr(self.prev_voxel_generator), 8)})'
return repr_str return repr_str
@PIPELINES.register_module()
class AffineResize(object):
    """Get the affine transform matrices to the target size.

    Different from :class:`RandomAffine` in MMDetection, this class can
    calculate the affine transform matrices while resizing the input image
    to a fixed size. The affine transform matrices include: 1) matrix
    transforming original image to the network input image size. 2) matrix
    transforming original image to the network output feature map size.

    Args:
        img_scale (tuple): Images scales for resizing.
        down_ratio (int): The down ratio of feature map.
            Actually the arg should be >= 1.
        bbox_clip_border (bool, optional): Whether clip the objects
            outside the border of the image. Defaults to True.
    """

    def __init__(self, img_scale, down_ratio, bbox_clip_border=True):
        self.img_scale = img_scale
        self.down_ratio = down_ratio
        self.bbox_clip_border = bbox_clip_border

    def __call__(self, results):
        """Call function to do affine transform to input image and labels.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after affine resize, 'affine_aug', 'trans_mat'
                keys are added in the result dict.
        """
        # The results did NOT go through RandomShiftScale before
        # AffineResize: derive center/size from the raw image.
        # (RandomShiftScale is what puts 'center'/'size' into results.)
        if 'center' not in results:
            img = results['img']
            height, width = img.shape[:2]
            center = np.array([width / 2, height / 2], dtype=np.float32)
            size = np.array([width, height], dtype=np.float32)
            results['affine_aug'] = False
        else:
            # The results have gone through RandomShiftScale before
            # AffineResize, so reuse the recorded (possibly jittered)
            # center and size.
            img = results['img']
            center = results['center']
            size = results['size']

        # Matrix mapping the original image to the network input size.
        trans_affine = self._get_transform_matrix(center, size, self.img_scale)

        img = cv2.warpAffine(img, trans_affine[:2, :], self.img_scale)

        # Matrix (or one matrix per ratio) mapping the original image to
        # the network output feature-map size.
        if isinstance(self.down_ratio, tuple):
            trans_mat = [
                self._get_transform_matrix(
                    center, size,
                    (self.img_scale[0] // ratio, self.img_scale[1] // ratio))
                for ratio in self.down_ratio
            ]  # (3, 3)
        else:
            trans_mat = self._get_transform_matrix(
                center, size, (self.img_scale[0] // self.down_ratio,
                               self.img_scale[1] // self.down_ratio))

        results['img'] = img
        results['img_shape'] = img.shape
        results['pad_shape'] = img.shape
        results['trans_mat'] = trans_mat

        # Bring 2D boxes into the resized-image coordinate frame.
        self._affine_bboxes(results, trans_affine)

        if 'centers2d' in results:
            centers2d = self._affine_transform(results['centers2d'],
                                               trans_affine)
            # Keep only objects whose projected 2D center stays strictly
            # inside the resized image; filter the paired annotations with
            # the same mask.
            valid_index = (centers2d[:, 0] >
                           0) & (centers2d[:, 0] <
                                 self.img_scale[0]) & (centers2d[:, 1] > 0) & (
                                     centers2d[:, 1] < self.img_scale[1])
            results['centers2d'] = centers2d[valid_index]

            for key in results.get('bbox_fields', []):
                if key in ['gt_bboxes']:
                    results[key] = results[key][valid_index]
                    if 'gt_labels' in results:
                        results['gt_labels'] = results['gt_labels'][
                            valid_index]
                    if 'gt_masks' in results:
                        raise NotImplementedError(
                            'AffineResize only supports bbox.')

            for key in results.get('bbox3d_fields', []):
                if key in ['gt_bboxes_3d']:
                    results[key].tensor = results[key].tensor[valid_index]
                    if 'gt_labels_3d' in results:
                        results['gt_labels_3d'] = results['gt_labels_3d'][
                            valid_index]

            results['depths'] = results['depths'][valid_index]

        return results

    def _affine_bboxes(self, results, matrix):
        """Affine transform bboxes to input image.

        Args:
            results (dict): Result dict from loading pipeline.
            matrix (np.ndarray): Matrix transforming original
                image to the network input image size.
                shape: (3, 3)
        """
        for key in results.get('bbox_fields', []):
            bboxes = results[key]
            # Transform the two corner points (x1, y1) and (x2, y2)
            # independently.
            bboxes[:, :2] = self._affine_transform(bboxes[:, :2], matrix)
            bboxes[:, 2:] = self._affine_transform(bboxes[:, 2:], matrix)
            if self.bbox_clip_border:
                # Clip to [0, W-1] x [0, H-1] of the target image scale.
                bboxes[:,
                       [0, 2]] = bboxes[:,
                                        [0, 2]].clip(0, self.img_scale[0] - 1)
                bboxes[:,
                       [1, 3]] = bboxes[:,
                                        [1, 3]].clip(0, self.img_scale[1] - 1)
            results[key] = bboxes

    def _affine_transform(self, points, matrix):
        """Affine transform bbox points to input image.

        Args:
            points (np.ndarray): Points to be transformed.
                shape: (N, 2)
            matrix (np.ndarray): Affine transform matrix.
                shape: (3, 3)

        Returns:
            np.ndarray: Transformed points, shape (N, 2).
        """
        num_points = points.shape[0]
        # Append a homogeneous coordinate so the 3x3 matrix applies.
        hom_points_2d = np.concatenate((points, np.ones((num_points, 1))),
                                       axis=1)
        hom_points_2d = hom_points_2d.T
        affined_points = np.matmul(matrix, hom_points_2d).T
        return affined_points[:, :2]

    def _get_transform_matrix(self, center, scale, output_scale):
        """Get affine transform matrix.

        Args:
            center (tuple): Center of current image.
            scale (tuple): Scale of current image.
            output_scale (tuple[float]): The transform target image scales.

        Returns:
            np.ndarray: Affine transform matrix, shape (3, 3), float32.
        """
        # TODO: further add rot and shift here.
        src_w = scale[0]
        dst_w = output_scale[0]
        dst_h = output_scale[1]

        # Direction vectors pointing "up" by half the width; together with
        # the center and the perpendicular reference point they pin down
        # the affine map.
        src_dir = np.array([0, src_w * -0.5])
        dst_dir = np.array([0, dst_w * -0.5])

        src = np.zeros((3, 2), dtype=np.float32)
        dst = np.zeros((3, 2), dtype=np.float32)
        src[0, :] = center
        src[1, :] = center + src_dir
        dst[0, :] = np.array([dst_w * 0.5, dst_h * 0.5])
        dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

        src[2, :] = self._get_ref_point(src[0, :], src[1, :])
        dst[2, :] = self._get_ref_point(dst[0, :], dst[1, :])

        get_matrix = cv2.getAffineTransform(src, dst)

        # Pad the 2x3 OpenCV matrix to a full 3x3 homogeneous matrix.
        matrix = np.concatenate((get_matrix, [[0., 0., 1.]]))

        return matrix.astype(np.float32)

    def _get_ref_point(self, ref_point1, ref_point2):
        """Get reference point to calculate affine transform matrix.

        While using opencv to calculate the affine matrix, we need at least
        three corresponding points separately on original image and target
        image. Here we use two points to get the third reference point by
        rotating the segment between them 90 degrees.
        """
        d = ref_point1 - ref_point2
        ref_point3 = ref_point2 + np.array([-d[1], d[0]])
        return ref_point3

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '
        repr_str += f'down_ratio={self.down_ratio}) '
        return repr_str
@PIPELINES.register_module()
class RandomShiftScale(object):
    """Randomly draw shift and scale parameters for a later affine warp.

    Different from a normal shift-and-scale transform, this class does not
    touch the image itself. It only records the sampled center/size into
    the results dict; :class:`AffineResize` consumes them afterwards to
    perform the actual warp.

    Args:
        shift_scale (tuple[float]): Shift and scale range.
        aug_prob (float): The shifting and scaling probability.
    """

    def __init__(self, shift_scale, aug_prob):
        self.shift_scale = shift_scale
        self.aug_prob = aug_prob

    def __call__(self, results):
        """Record random shift and scale infos into ``results``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after random shift and scale, 'center', 'size'
                and 'affine_aug' keys are added in the result dict.
        """
        img = results['img']
        img_h, img_w = img.shape[:2]
        center = np.array([img_w / 2, img_h / 2], dtype=np.float32)
        size = np.array([img_w, img_h], dtype=np.float32)

        apply_aug = random.random() < self.aug_prob
        if apply_aug:
            max_shift = self.shift_scale[0]
            max_scale = self.shift_scale[1]
            # Discrete candidate offsets in steps of 0.1 of the image size.
            shift_choices = np.arange(-max_shift, max_shift + 0.1, 0.1)
            center[0] += size[0] * random.choice(shift_choices)
            center[1] += size[1] * random.choice(shift_choices)
            scale_choices = np.arange(1 - max_scale, 1 + max_scale + 0.1, 0.1)
            size *= random.choice(scale_choices)
        results['affine_aug'] = apply_aug

        results['center'] = center
        results['size'] = size

        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(shift_scale={self.shift_scale}, '
                f'aug_prob={self.aug_prob}) ')
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from os import path as osp from os import path as osp
import numpy as np
from mmdet3d.core import show_seg_result from mmdet3d.core import show_seg_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import tempfile import tempfile
import warnings import warnings
from os import path as osp from os import path as osp
import numpy as np
from mmdet3d.core import show_result, show_seg_result from mmdet3d.core import show_result, show_seg_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
...@@ -78,13 +79,13 @@ class ScanNetDataset(Custom3DDataset): ...@@ -78,13 +79,13 @@ class ScanNetDataset(Custom3DDataset):
index (int): Index of the sample data to get. index (int): Index of the sample data to get.
Returns: Returns:
dict: Data information that will be passed to the data \ dict: Data information that will be passed to the data
preprocessing pipelines. It includes the following keys: preprocessing pipelines. It includes the following keys:
- sample_idx (str): Sample index. - sample_idx (str): Sample index.
- pts_filename (str): Filename of point clouds. - pts_filename (str): Filename of point clouds.
- file_name (str): Filename of point clouds. - file_name (str): Filename of point clouds.
- img_prefix (str | None, optional): Prefix of image files. - img_prefix (str, optional): Prefix of image files.
- img_info (dict, optional): Image info. - img_info (dict, optional): Image info.
- ann_info (dict): Annotation info. - ann_info (dict): Annotation info.
""" """
...@@ -129,12 +130,12 @@ class ScanNetDataset(Custom3DDataset): ...@@ -129,12 +130,12 @@ class ScanNetDataset(Custom3DDataset):
Returns: Returns:
dict: annotation information consists of the following keys: dict: annotation information consists of the following keys:
- gt_bboxes_3d (:obj:`DepthInstance3DBoxes`): \ - gt_bboxes_3d (:obj:`DepthInstance3DBoxes`):
3D ground truth bboxes 3D ground truth bboxes
- gt_labels_3d (np.ndarray): Labels of ground truths. - gt_labels_3d (np.ndarray): Labels of ground truths.
- pts_instance_mask_path (str): Path of instance masks. - pts_instance_mask_path (str): Path of instance masks.
- pts_semantic_mask_path (str): Path of semantic masks. - pts_semantic_mask_path (str): Path of semantic masks.
- axis_align_matrix (np.ndarray): Transformation matrix for \ - axis_align_matrix (np.ndarray): Transformation matrix for
global scene alignment. global scene alignment.
""" """
# Use index to get the annos, thus the evalhook could also use this api # Use index to get the annos, thus the evalhook could also use this api
...@@ -172,7 +173,7 @@ class ScanNetDataset(Custom3DDataset): ...@@ -172,7 +173,7 @@ class ScanNetDataset(Custom3DDataset):
def prepare_test_data(self, index): def prepare_test_data(self, index):
"""Prepare data for testing. """Prepare data for testing.
We should take axis_align_matrix from self.data_infos since we need \ We should take axis_align_matrix from self.data_infos since we need
to align point clouds. to align point clouds.
Args: Args:
...@@ -272,7 +273,7 @@ class ScanNetSegDataset(Custom3DSegDataset): ...@@ -272,7 +273,7 @@ class ScanNetSegDataset(Custom3DSegDataset):
as input. Defaults to None. as input. Defaults to None.
test_mode (bool, optional): Whether the dataset is in test mode. test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False. Defaults to False.
ignore_index (int, optional): The label index to be ignored, e.g. \ ignore_index (int, optional): The label index to be ignored, e.g.
unannotated points. If None is given, set to len(self.CLASSES). unannotated points. If None is given, set to len(self.CLASSES).
Defaults to None. Defaults to None.
scene_idxs (np.ndarray | str, optional): Precomputed index to load scene_idxs (np.ndarray | str, optional): Precomputed index to load
...@@ -424,7 +425,7 @@ class ScanNetSegDataset(Custom3DSegDataset): ...@@ -424,7 +425,7 @@ class ScanNetSegDataset(Custom3DSegDataset):
Args: Args:
outputs (list[dict]): Testing results of the dataset. outputs (list[dict]): Testing results of the dataset.
txtfile_prefix (str | None): The prefix of saved files. It includes txtfile_prefix (str): The prefix of saved files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from collections import OrderedDict from collections import OrderedDict
from os import path as osp from os import path as osp
import numpy as np
from mmdet3d.core import show_multi_modality_result, show_result from mmdet3d.core import show_multi_modality_result, show_result
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet.core import eval_map from mmdet.core import eval_map
...@@ -74,13 +75,13 @@ class SUNRGBDDataset(Custom3DDataset): ...@@ -74,13 +75,13 @@ class SUNRGBDDataset(Custom3DDataset):
index (int): Index of the sample data to get. index (int): Index of the sample data to get.
Returns: Returns:
dict: Data information that will be passed to the data \ dict: Data information that will be passed to the data
preprocessing pipelines. It includes the following keys: preprocessing pipelines. It includes the following keys:
- sample_idx (str): Sample index. - sample_idx (str): Sample index.
- pts_filename (str, optional): Filename of point clouds. - pts_filename (str, optional): Filename of point clouds.
- file_name (str, optional): Filename of point clouds. - file_name (str, optional): Filename of point clouds.
- img_prefix (str | None, optional): Prefix of image files. - img_prefix (str, optional): Prefix of image files.
- img_info (dict, optional): Image info. - img_info (dict, optional): Image info.
- calib (dict, optional): Camera calibration info. - calib (dict, optional): Camera calibration info.
- ann_info (dict): Annotation info. - ann_info (dict): Annotation info.
...@@ -125,7 +126,7 @@ class SUNRGBDDataset(Custom3DDataset): ...@@ -125,7 +126,7 @@ class SUNRGBDDataset(Custom3DDataset):
Returns: Returns:
dict: annotation information consists of the following keys: dict: annotation information consists of the following keys:
- gt_bboxes_3d (:obj:`DepthInstance3DBoxes`): \ - gt_bboxes_3d (:obj:`DepthInstance3DBoxes`):
3D ground truth bboxes 3D ground truth bboxes
- gt_labels_3d (np.ndarray): Labels of ground truths. - gt_labels_3d (np.ndarray): Labels of ground truths.
- pts_instance_mask_path (str): Path of instance masks. - pts_instance_mask_path (str): Path of instance masks.
...@@ -239,12 +240,15 @@ class SUNRGBDDataset(Custom3DDataset): ...@@ -239,12 +240,15 @@ class SUNRGBDDataset(Custom3DDataset):
Args: Args:
results (list[dict]): List of results. results (list[dict]): List of results.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str], optional): Metrics to be evaluated.
iou_thr (list[float]): AP IoU thresholds. Default: None.
iou_thr_2d (list[float]): AP IoU thresholds for 2d evaluation. iou_thr (list[float], optional): AP IoU thresholds for 3D
show (bool): Whether to visualize. evaluation. Default: (0.25, 0.5).
iou_thr_2d (list[float], optional): AP IoU thresholds for 2D
evaluation. Default: (0.5, ).
show (bool, optional): Whether to visualize.
Default: False. Default: False.
out_dir (str): Path to save the visualization results. out_dir (str, optional): Path to save the visualization results.
Default: None. Default: None.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
......
...@@ -12,7 +12,7 @@ from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D, ...@@ -12,7 +12,7 @@ from mmdet3d.datasets.pipelines import (Collect3D, DefaultFormatBundle3D,
PointSegClassMapping) PointSegClassMapping)
# yapf: enable # yapf: enable
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadImageFromFile from mmdet.datasets.pipelines import LoadImageFromFile, MultiScaleFlipAug
def is_loading_function(transform): def is_loading_function(transform):
...@@ -25,7 +25,7 @@ def is_loading_function(transform): ...@@ -25,7 +25,7 @@ def is_loading_function(transform):
transform (dict | :obj:`Pipeline`): A transform config or a function. transform (dict | :obj:`Pipeline`): A transform config or a function.
Returns: Returns:
bool | None: Whether it is a loading function. None means can't judge. bool: Whether it is a loading function. None means can't judge.
When transform is `MultiScaleFlipAug3D`, we return None. When transform is `MultiScaleFlipAug3D`, we return None.
""" """
# TODO: use more elegant way to distinguish loading modules # TODO: use more elegant way to distinguish loading modules
...@@ -40,12 +40,12 @@ def is_loading_function(transform): ...@@ -40,12 +40,12 @@ def is_loading_function(transform):
return False return False
if obj_cls in loading_functions: if obj_cls in loading_functions:
return True return True
if obj_cls in (MultiScaleFlipAug3D, ): if obj_cls in (MultiScaleFlipAug3D, MultiScaleFlipAug):
return None return None
elif callable(transform): elif callable(transform):
if isinstance(transform, loading_functions): if isinstance(transform, loading_functions):
return True return True
if isinstance(transform, MultiScaleFlipAug3D): if isinstance(transform, (MultiScaleFlipAug3D, MultiScaleFlipAug)):
return None return None
return False return False
...@@ -92,7 +92,7 @@ def get_loading_pipeline(pipeline): ...@@ -92,7 +92,7 @@ def get_loading_pipeline(pipeline):
... dict(type='Collect3D', ... dict(type='Collect3D',
... keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']) ... keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d'])
... ] ... ]
>>> assert expected_pipelines ==\ >>> assert expected_pipelines == \
... get_loading_pipeline(pipelines) ... get_loading_pipeline(pipelines)
""" """
loading_pipeline = [] loading_pipeline = []
...@@ -126,7 +126,7 @@ def extract_result_dict(results, key): ...@@ -126,7 +126,7 @@ def extract_result_dict(results, key):
key (str): Key of the desired data. key (str): Key of the desired data.
Returns: Returns:
np.ndarray | torch.Tensor | None: Data term. np.ndarray | torch.Tensor: Data term.
""" """
if key not in results.keys(): if key not in results.keys():
return None return None
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import os import os
import tempfile import tempfile
from os import path as osp
import mmcv
import numpy as np
import torch import torch
from mmcv.utils import print_log from mmcv.utils import print_log
from os import path as osp
from mmdet.datasets import DATASETS from mmdet.datasets import DATASETS
from ..core.bbox import Box3DMode, points_cam2img from ..core.bbox import Box3DMode, points_cam2img
...@@ -46,8 +47,9 @@ class WaymoDataset(KittiDataset): ...@@ -46,8 +47,9 @@ class WaymoDataset(KittiDataset):
Defaults to True. Defaults to True.
test_mode (bool, optional): Whether the dataset is in test mode. test_mode (bool, optional): Whether the dataset is in test mode.
Defaults to False. Defaults to False.
pcd_limit_range (list): The range of point cloud used to filter pcd_limit_range (list(float), optional): The range of point cloud used
invalid predicted boxes. Default: [-85, -85, -5, 85, 85, 5]. to filter invalid predicted boxes.
Default: [-85, -85, -5, 85, 85, 5].
""" """
CLASSES = ('Car', 'Cyclist', 'Pedestrian') CLASSES = ('Car', 'Cyclist', 'Pedestrian')
...@@ -100,7 +102,7 @@ class WaymoDataset(KittiDataset): ...@@ -100,7 +102,7 @@ class WaymoDataset(KittiDataset):
- sample_idx (str): sample index - sample_idx (str): sample index
- pts_filename (str): filename of point clouds - pts_filename (str): filename of point clouds
- img_prefix (str | None): prefix of image files - img_prefix (str): prefix of image files
- img_info (dict): image info - img_info (dict): image info
- lidar2img (list[np.ndarray], optional): transformations from - lidar2img (list[np.ndarray], optional): transformations from
lidar to different cameras lidar to different cameras
...@@ -140,15 +142,15 @@ class WaymoDataset(KittiDataset): ...@@ -140,15 +142,15 @@ class WaymoDataset(KittiDataset):
Args: Args:
outputs (list[dict]): Testing results of the dataset. outputs (list[dict]): Testing results of the dataset.
pklfile_prefix (str | None): The prefix of pkl files. It includes pklfile_prefix (str): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
submission_prefix (str | None): The prefix of submitted files. It submission_prefix (str): The prefix of submitted files. It
includes the file path and the prefix of filename, e.g., includes the file path and the prefix of filename, e.g.,
"a/b/prefix". If not specified, a temp file will be created. "a/b/prefix". If not specified, a temp file will be created.
Default: None. Default: None.
data_format (str | None): Output data format. Default: 'waymo'. data_format (str, optional): Output data format.
Another supported choice is 'kitti'. Default: 'waymo'. Another supported choice is 'kitti'.
Returns: Returns:
tuple: (result_files, tmp_dir), result_files is a dict containing tuple: (result_files, tmp_dir), result_files is a dict containing
...@@ -226,18 +228,18 @@ class WaymoDataset(KittiDataset): ...@@ -226,18 +228,18 @@ class WaymoDataset(KittiDataset):
Args: Args:
results (list[dict]): Testing results of the dataset. results (list[dict]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str], optional): Metrics to be evaluated.
Default: 'waymo'. Another supported metric is 'kitti'. Default: 'waymo'. Another supported metric is 'kitti'.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
pklfile_prefix (str | None): The prefix of pkl files. It includes pklfile_prefix (str, optional): The prefix of pkl files including
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
submission_prefix (str | None): The prefix of submission datas. submission_prefix (str, optional): The prefix of submission data.
If not specified, the submission data will not be generated. If not specified, the submission data will not be generated.
show (bool): Whether to visualize. show (bool, optional): Whether to visualize.
Default: False. Default: False.
out_dir (str): Path to save the visualization results. out_dir (str, optional): Path to save the visualization results.
Default: None. Default: None.
pipeline (list[dict], optional): raw data loading for showing. pipeline (list[dict], optional): raw data loading for showing.
Default: None. Default: None.
...@@ -349,8 +351,8 @@ class WaymoDataset(KittiDataset): ...@@ -349,8 +351,8 @@ class WaymoDataset(KittiDataset):
if tmp_dir is not None: if tmp_dir is not None:
tmp_dir.cleanup() tmp_dir.cleanup()
if show: if show or out_dir:
self.show(results, out_dir, pipeline=pipeline) self.show(results, out_dir, show=show, pipeline=pipeline)
return ap_dict return ap_dict
def bbox2result_kitti(self, def bbox2result_kitti(self,
...@@ -364,8 +366,8 @@ class WaymoDataset(KittiDataset): ...@@ -364,8 +366,8 @@ class WaymoDataset(KittiDataset):
net_outputs (List[np.ndarray]): list of array storing the net_outputs (List[np.ndarray]): list of array storing the
bbox and score bbox and score
class_nanes (List[String]): A list of class names class_nanes (List[String]): A list of class names
pklfile_prefix (str | None): The prefix of pkl file. pklfile_prefix (str): The prefix of pkl file.
submission_prefix (str | None): The prefix of submission file. submission_prefix (str): The prefix of submission file.
Returns: Returns:
List[dict]: A list of dict have the kitti 3d format List[dict]: A list of dict have the kitti 3d format
...@@ -494,7 +496,6 @@ class WaymoDataset(KittiDataset): ...@@ -494,7 +496,6 @@ class WaymoDataset(KittiDataset):
scores = box_dict['scores_3d'] scores = box_dict['scores_3d']
labels = box_dict['labels_3d'] labels = box_dict['labels_3d']
sample_idx = info['image']['image_idx'] sample_idx = info['image']['image_idx']
# TODO: remove the hack of yaw
box_preds.limit_yaw(offset=0.5, period=np.pi * 2) box_preds.limit_yaw(offset=0.5, period=np.pi * 2)
if len(box_preds) == 0: if len(box_preds) == 0:
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
from .dgcnn import DGCNNBackbone
from .dla import DLANet
from .multi_backbone import MultiBackbone from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG from .pointnet2_sa_msg import PointNet2SAMSG
...@@ -8,5 +10,6 @@ from .second import SECOND ...@@ -8,5 +10,6 @@ from .second import SECOND
__all__ = [ __all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'PointNet2SASSG', 'PointNet2SAMSG', 'MultiBackbone' 'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone', 'DLANet'
] ]
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from abc import ABCMeta from abc import ABCMeta
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn
from mmdet3d.ops import DGCNNFAModule, DGCNNGFModule
from mmdet.models import BACKBONES
@BACKBONES.register_module()
class DGCNNBackbone(BaseModule):
    """Backbone network for DGCNN.

    Stacks several graph feature (GF) modules and fuses their outputs with
    a single feature aggregation (FA) module.

    Args:
        in_channels (int): Input channels of point cloud.
        num_samples (tuple[int], optional): The number of samples for knn or
            ball query in each graph feature (GF) module.
            Defaults to (20, 20, 20).
        knn_modes (tuple[str], optional): Mode of KNN of each knn module.
            Defaults to ('D-KNN', 'F-KNN', 'F-KNN').
        radius (tuple[float], optional): Sampling radii of each GF module.
            Defaults to (None, None, None).
        gf_channels (tuple[tuple[int]], optional): Out channels of each mlp in
            GF module. Defaults to ((64, 64), (64, 64), (64, )).
        fa_channels (tuple[int], optional): Out channels of each mlp in FA
            module. Defaults to (1024, ).
        act_cfg (dict, optional): Config of activation layer.
            Defaults to dict(type='ReLU').
        init_cfg (dict, optional): Initialization config.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 num_samples=(20, 20, 20),
                 knn_modes=('D-KNN', 'F-KNN', 'F-KNN'),
                 radius=(None, None, None),
                 gf_channels=((64, 64), (64, 64), (64, )),
                 fa_channels=(1024, ),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_gf = len(gf_channels)
        assert len(num_samples) == len(knn_modes) == len(radius) == len(
            gf_channels), 'Num_samples, knn_modes, radius and gf_channels \
            should have the same length.'

        self.GF_modules = nn.ModuleList()
        # Each GF module consumes edge features, i.e. concatenated point
        # pairs, so its input width is twice the previous output width.
        feat_channel = in_channels * 2
        skip_channel_list = [feat_channel]
        for n_sample, knn_mode, cur_radius, mlps in zip(
                num_samples, knn_modes, radius, gf_channels):
            mlp_spec = [feat_channel, *mlps]
            self.GF_modules.append(
                DGCNNGFModule(
                    mlp_channels=mlp_spec,
                    num_sample=n_sample,
                    knn_mode=knn_mode,
                    radius=cur_radius,
                    act_cfg=act_cfg))
            skip_channel_list.append(mlp_spec[-1])
            feat_channel = mlp_spec[-1] * 2

        # The FA module concatenates the outputs of every GF module
        # (skip connections), hence the summed input width.
        self.FA_module = DGCNNFAModule(
            mlp_channels=[sum(skip_channel_list[1:]), *fa_channels],
            act_cfg=act_cfg)

    @auto_fp16(apply_to=('points', ))
    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, in_channels).

        Returns:
            dict[str, list[torch.Tensor]]: Outputs after graph feature (GF) and
                feature aggregation (FA) modules.

                - gf_points (list[torch.Tensor]): Outputs after each GF module.
                - fa_points (torch.Tensor): Outputs after FA module.
        """
        gf_points = [points]
        for gf_module in self.GF_modules:
            gf_points.append(gf_module(gf_points[-1]))
        fa_points = self.FA_module(gf_points)
        return dict(gf_points=gf_points, fa_points=fa_points)
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from torch import nn
from mmdet.models.builder import BACKBONES
def dla_build_norm_layer(cfg, num_features):
    """Build a normalization layer specially designed for DLANet.

    Identical to mmcv's ``build_norm_layer`` except for group normalization:
    when ``num_features`` is not divisible by 32, the configured
    ``num_groups`` is halved so the layer can still be constructed.

    Args:
        cfg (dict): The norm layer config, which should contain:

            - type (str): Layer type.
            - layer args: Args needed to instantiate a norm layer.
            - requires_grad (bool, optional): Whether stop gradient updates.

        num_features (int): Number of input channels.

    Returns:
        tuple: The (name, layer) pair produced by ``build_norm_layer``.
    """
    cfg_ = cfg.copy()
    if cfg_['type'] == 'GN' and num_features % 32 != 0:
        assert 'num_groups' in cfg_
        # Halve the group count so it can divide the channel number.
        cfg_['num_groups'] = cfg_['num_groups'] // 2
    return build_norm_layer(cfg_, num_features)
class BasicBlock(BaseModule):
    """BasicBlock in DLANet.

    Two 3x3 conv-norm layers with a residual connection; the residual may
    be supplied externally via the ``identity`` argument of :meth:`forward`.

    Args:
        in_channels (int): Input feature channel.
        out_channels (int): Output feature channel.
        norm_cfg (dict): Dictionary to construct and config
            norm layer.
        conv_cfg (dict): Dictionary to construct and config
            conv layer.
        stride (int, optional): Conv stride. Default: 1.
        dilation (int, optional): Conv dilation. Default: 1.
        init_cfg (dict, optional): Initialization config.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_cfg,
                 conv_cfg,
                 stride=1,
                 dilation=1,
                 init_cfg=None):
        super(BasicBlock, self).__init__(init_cfg)
        # First conv may downsample spatially via `stride`.
        self.conv1 = build_conv_layer(
            conv_cfg,
            in_channels,
            out_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.norm1 = dla_build_norm_layer(norm_cfg, out_channels)[1]
        self.relu = nn.ReLU(inplace=True)
        # Second conv always keeps the spatial size.
        self.conv2 = build_conv_layer(
            conv_cfg,
            out_channels,
            out_channels,
            3,
            stride=1,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.norm2 = dla_build_norm_layer(norm_cfg, out_channels)[1]
        self.stride = stride

    def forward(self, x, identity=None):
        """Forward pass with an optional external residual input."""
        residual = x if identity is None else identity
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        out = out + residual
        return self.relu(out)
class Root(BaseModule):
    """Root in DLANet.

    Fuses the features produced by one tree level with a 1x1 conv over
    their channel-wise concatenation.

    Args:
        in_channels (int): Input feature channel.
        out_channels (int): Output feature channel.
        norm_cfg (dict): Dictionary to construct and config
            norm layer.
        conv_cfg (dict): Dictionary to construct and config
            conv layer.
        kernel_size (int): Size of convolution kernel.
        add_identity (bool): Whether to add identity in root.
        init_cfg (dict, optional): Initialization config.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_cfg,
                 conv_cfg,
                 kernel_size,
                 add_identity,
                 init_cfg=None):
        super(Root, self).__init__(init_cfg)
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            out_channels,
            1,
            stride=1,
            padding=(kernel_size - 1) // 2,
            bias=False)
        self.norm = dla_build_norm_layer(norm_cfg, out_channels)[1]
        self.relu = nn.ReLU(inplace=True)
        self.add_identity = add_identity

    def forward(self, feat_list):
        """Forward function.

        Args:
            feat_list (list[torch.Tensor]): Output features from
                multiple layers, concatenated along the channel dim.
        """
        out = self.norm(self.conv(torch.cat(feat_list, 1)))
        if self.add_identity:
            # Residual from the first (deepest) input feature.
            out = out + feat_list[0]
        return self.relu(out)
class Tree(BaseModule):
    """Tree in DLANet.

    A recursively-defined aggregation structure: at ``levels == 1`` it is
    two blocks fused by a :class:`Root`; for deeper levels each child is
    itself a ``Tree`` and aggregation is delegated to the second child.

    Args:
        levels (int): The level of the tree.
        block (nn.Module): The block module in tree.
        in_channels (int): Input feature channel.
        out_channels (int): Output feature channel.
        norm_cfg (dict): Dictionary to construct and config
            norm layer.
        conv_cfg (dict): Dictionary to construct and config
            conv layer.
        stride (int, optional): Convolution stride.
            Default: 1.
        level_root (bool, optional): whether belongs to the
            root layer.
        root_dim (int, optional): Root input feature channel.
        root_kernel_size (int, optional): Size of root
            convolution kernel. Default: 1.
        dilation (int, optional): Conv dilation. Default: 1.
        add_identity (bool, optional): Whether to add
            identity in root. Default: False.
        init_cfg (dict, optional): Initialization config.
            Default: None.
    """

    def __init__(self,
                 levels,
                 block,
                 in_channels,
                 out_channels,
                 norm_cfg,
                 conv_cfg,
                 stride=1,
                 level_root=False,
                 root_dim=None,
                 root_kernel_size=1,
                 dilation=1,
                 add_identity=False,
                 init_cfg=None):
        super(Tree, self).__init__(init_cfg)
        # Default root input: the concatenation of tree1 and tree2 outputs.
        if root_dim is None:
            root_dim = 2 * out_channels
        # A level root additionally receives the (downsampled) tree input.
        if level_root:
            root_dim += in_channels
        if levels == 1:
            # Leaf case: two plain blocks fused by a Root module.
            self.root = Root(root_dim, out_channels, norm_cfg, conv_cfg,
                             root_kernel_size, add_identity)
            self.tree1 = block(
                in_channels,
                out_channels,
                norm_cfg,
                conv_cfg,
                stride,
                dilation=dilation)
            self.tree2 = block(
                out_channels,
                out_channels,
                norm_cfg,
                conv_cfg,
                1,
                dilation=dilation)
        else:
            # Recursive case: both children are one-level-shallower trees;
            # tree2's root also fuses tree1's output (hence the widened
            # root_dim) and no Root is created at this level.
            self.tree1 = Tree(
                levels - 1,
                block,
                in_channels,
                out_channels,
                norm_cfg,
                conv_cfg,
                stride,
                root_dim=None,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                add_identity=add_identity)
            self.tree2 = Tree(
                levels - 1,
                block,
                out_channels,
                out_channels,
                norm_cfg,
                conv_cfg,
                root_dim=root_dim + out_channels,
                root_kernel_size=root_kernel_size,
                dilation=dilation,
                add_identity=add_identity)
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels
        # Downsample the identity path when the tree strides spatially.
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        # Project the identity path when channel counts differ.
        if in_channels != out_channels:
            self.project = nn.Sequential(
                build_conv_layer(
                    conv_cfg,
                    in_channels,
                    out_channels,
                    1,
                    stride=1,
                    bias=False),
                dla_build_norm_layer(norm_cfg, out_channels)[1])

    def forward(self, x, identity=None, children=None):
        """Forward function.

        Args:
            x (torch.Tensor): Input feature map.
            identity (torch.Tensor, optional): Residual fed to ``tree1``;
                recomputed here from the (downsampled, projected) input.
            children (list[torch.Tensor], optional): Features collected by
                ancestor trees, fused at the deepest Root. Default: None.
        """
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        # NOTE: the passed-in `identity` is overwritten here; the residual
        # always comes from the current input.
        identity = self.project(bottom) if self.project else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, identity)
        if self.levels == 1:
            x2 = self.tree2(x1)
            feat_list = [x2, x1] + children
            x = self.root(feat_list)
        else:
            # Defer fusion: tree2's own root will consume x1 and children.
            children.append(x1)
            x = self.tree2(x1, children=children)
        return x
@BACKBONES.register_module()
class DLANet(BaseModule):
    r"""`DLA backbone <https://arxiv.org/abs/1707.06484>`_.

    Args:
        depth (int): Depth of DLA. Default: 34.
        in_channels (int, optional): Number of input image channels.
            Default: 3.
        out_indices (tuple[int], optional): Indices of the levels whose
            outputs are returned by :meth:`forward`.
            Default: (0, 1, 2, 3, 4, 5).
        frozen_stages (int, optional): Number of leading stages to freeze
            (set to eval mode and stop gradients). -1 means no stage is
            frozen. Default: -1.
        norm_cfg (dict, optional): Dictionary to construct and config
            norm layer. Default: None.
        conv_cfg (dict, optional): Dictionary to construct and config
            conv layer. Default: None.
        layer_with_level_root (list[bool], optional): Whether to apply
            level_root in each DLA layer, this is only used for
            tree levels. Default: (False, True, True, True).
        with_identity_root (bool, optional): Whether to add identity
            in root layer. Default: False.
        pretrained (str, optional): model pretrained path.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization
            config dict. Default: None
    """

    # depth -> (block type, convs/tree-levels per stage, channels per stage)
    arch_settings = {
        34: (BasicBlock, (1, 1, 1, 2, 2, 1), (16, 32, 64, 128, 256, 512)),
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 out_indices=(0, 1, 2, 3, 4, 5),
                 frozen_stages=-1,
                 norm_cfg=None,
                 conv_cfg=None,
                 layer_with_level_root=(False, True, True, True),
                 with_identity_root=False,
                 pretrained=None,
                 init_cfg=None):
        super(DLANet, self).__init__(init_cfg)
        if depth not in self.arch_settings:
            # Fixed typo in the error message ('invalida' -> 'invalid').
            raise KeyError(f'invalid depth {depth} for DLA')

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be setting at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            # Previously an invalid `pretrained` value fell through silently.
            raise TypeError('pretrained must be a str or None')

        block, levels, channels = self.arch_settings[depth]
        self.channels = channels
        self.num_levels = len(levels)
        self.frozen_stages = frozen_stages
        self.out_indices = out_indices
        assert max(out_indices) < self.num_levels
        self.base_layer = nn.Sequential(
            build_conv_layer(
                conv_cfg,
                in_channels,
                channels[0],
                7,
                stride=1,
                padding=3,
                bias=False),
            dla_build_norm_layer(norm_cfg, channels[0])[1],
            nn.ReLU(inplace=True))

        # DLANet first uses two conv layers then uses several
        # Tree layers
        for i in range(2):
            level_layer = self._make_conv_level(
                channels[0],
                channels[i],
                levels[i],
                norm_cfg,
                conv_cfg,
                stride=i + 1)
            layer_name = f'level{i}'
            self.add_module(layer_name, level_layer)

        for i in range(2, self.num_levels):
            dla_layer = Tree(
                levels[i],
                block,
                channels[i - 1],
                channels[i],
                norm_cfg,
                conv_cfg,
                2,
                level_root=layer_with_level_root[i - 2],
                add_identity=with_identity_root)
            layer_name = f'level{i}'
            self.add_module(layer_name, dla_layer)

        self._freeze_stages()

    def _make_conv_level(self,
                         in_channels,
                         out_channels,
                         num_convs,
                         norm_cfg,
                         conv_cfg,
                         stride=1,
                         dilation=1):
        """Stack ``num_convs`` conv-norm-ReLU modules.

        Args:
            in_channels (int): Input feature channel.
            out_channels (int): Output feature channel.
            num_convs (int): Number of Conv module.
            norm_cfg (dict): Dictionary to construct and config
                norm layer.
            conv_cfg (dict): Dictionary to construct and config
                conv layer.
            stride (int, optional): Conv stride. Default: 1.
            dilation (int, optional): Conv dilation. Default: 1.
        """
        modules = []
        for i in range(num_convs):
            modules.extend([
                build_conv_layer(
                    conv_cfg,
                    in_channels,
                    out_channels,
                    3,
                    # Only the first conv in a level may downsample.
                    stride=stride if i == 0 else 1,
                    padding=dilation,
                    bias=False,
                    dilation=dilation),
                dla_build_norm_layer(norm_cfg, out_channels)[1],
                nn.ReLU(inplace=True)
            ])
            in_channels = out_channels
        return nn.Sequential(*modules)

    def _freeze_stages(self):
        """Freeze base layer and the first ``frozen_stages`` stages."""
        if self.frozen_stages >= 0:
            self.base_layer.eval()
            for param in self.base_layer.parameters():
                param.requires_grad = False

            # level0 and level1 (the plain conv levels) are always frozen
            # together with the base layer.
            for i in range(2):
                m = getattr(self, f'level{i}')
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

            # NOTE(review): this freezes level2..level{frozen_stages + 1};
            # frozen_stages > num_levels - 2 would raise AttributeError —
            # callers are expected to pass a valid stage count.
            for i in range(1, self.frozen_stages + 1):
                m = getattr(self, f'level{i+1}')
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Run the base layer and all levels, collecting ``out_indices``.

        Returns:
            tuple[torch.Tensor]: Feature maps of the selected levels.
        """
        outs = []
        x = self.base_layer(x)
        for i in range(self.num_levels):
            # f-string for consistency with the rest of the class.
            x = getattr(self, f'level{i}')(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import torch
import warnings import warnings
import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16 from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn from torch import nn as nn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment