Unverified Commit 5b88c7b8 authored by Sun Jiahao's avatar Sun Jiahao Committed by GitHub
Browse files

[Refactor] Refactor Waymo dataset_converter/dataset/evaluator (#2836)

Co-authored-by: sjh <sunjiahao1999>
parent 395b86d7
...@@ -134,3 +134,4 @@ data/sunrgbd/OFFICIAL_SUNRGBD/ ...@@ -134,3 +134,4 @@ data/sunrgbd/OFFICIAL_SUNRGBD/
# Waymo evaluation # Waymo evaluation
mmdet3d/evaluation/functional/waymo_utils/compute_detection_metrics_main mmdet3d/evaluation/functional/waymo_utils/compute_detection_metrics_main
mmdet3d/evaluation/functional/waymo_utils/compute_detection_let_metrics_main mmdet3d/evaluation/functional/waymo_utils/compute_detection_let_metrics_main
mmdet3d/evaluation/functional/waymo_utils/compute_segmentation_metrics_main
...@@ -89,7 +89,10 @@ test_pipeline = [ ...@@ -89,7 +89,10 @@ test_pipeline = [
dict( dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range) type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]), ]),
dict(type='Pack3DDetInputs', keys=['points']) dict(
type='Pack3DDetInputs',
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
] ]
# construct a pipeline for data and gt loading in show function # construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client) # please keep its loading function consistent with test_pipeline (e.g. client)
...@@ -100,7 +103,10 @@ eval_pipeline = [ ...@@ -100,7 +103,10 @@ eval_pipeline = [
load_dim=6, load_dim=6,
use_dim=5, use_dim=5,
backend_args=backend_args), backend_args=backend_args),
dict(type='Pack3DDetInputs', keys=['points']), dict(
type='Pack3DDetInputs',
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
] ]
train_dataloader = dict( train_dataloader = dict(
...@@ -164,12 +170,7 @@ test_dataloader = dict( ...@@ -164,12 +170,7 @@ test_dataloader = dict(
backend_args=backend_args)) backend_args=backend_args))
val_evaluator = dict( val_evaluator = dict(
type='WaymoMetric', type='WaymoMetric', waymo_bin_file='./data/waymo/waymo_format/gt.bin')
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
backend_args=backend_args,
convert_kitti_format=False)
test_evaluator = val_evaluator test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')] vis_backends = [dict(type='LocalVisBackend')]
......
...@@ -62,7 +62,8 @@ train_pipeline = [ ...@@ -62,7 +62,8 @@ train_pipeline = [
dict(type='PointShuffle'), dict(type='PointShuffle'),
dict( dict(
type='Pack3DDetInputs', type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
] ]
test_pipeline = [ test_pipeline = [
dict( dict(
...@@ -86,7 +87,10 @@ test_pipeline = [ ...@@ -86,7 +87,10 @@ test_pipeline = [
dict( dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range) type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]), ]),
dict(type='Pack3DDetInputs', keys=['points']) dict(
type='Pack3DDetInputs',
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
] ]
# construct a pipeline for data and gt loading in show function # construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client) # please keep its loading function consistent with test_pipeline (e.g. client)
...@@ -161,12 +165,7 @@ test_dataloader = dict( ...@@ -161,12 +165,7 @@ test_dataloader = dict(
backend_args=backend_args)) backend_args=backend_args))
val_evaluator = dict( val_evaluator = dict(
type='WaymoMetric', type='WaymoMetric', waymo_bin_file='./data/waymo/waymo_format/gt.bin')
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
convert_kitti_format=False,
backend_args=backend_args)
test_evaluator = val_evaluator test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')] vis_backends = [dict(type='LocalVisBackend')]
......
...@@ -7,12 +7,7 @@ This page provides specific tutorials about the usage of MMDetection3D for Waymo ...@@ -7,12 +7,7 @@ This page provides specific tutorials about the usage of MMDetection3D for Waymo
Before preparing Waymo dataset, if you only installed requirements in `requirements/build.txt` and `requirements/runtime.txt` before, please install the official package for this dataset at first by running Before preparing Waymo dataset, if you only installed requirements in `requirements/build.txt` and `requirements/runtime.txt` before, please install the official package for this dataset at first by running
``` ```
# tf 2.1.0. pip install waymo-open-dataset-tf-2-6-0
pip install waymo-open-dataset-tf-2-1-0==1.2.0
# tf 2.0.0
# pip install waymo-open-dataset-tf-2-0-0==1.2.0
# tf 1.15.0
# pip install waymo-open-dataset-tf-1-15-0==1.2.0
``` ```
or or
...@@ -38,15 +33,19 @@ mmdetection3d ...@@ -38,15 +33,19 @@ mmdetection3d
│ │ │ ├── validation │ │ │ ├── validation
│ │ │ ├── testing │ │ │ ├── testing
│ │ │ ├── gt.bin │ │ │ ├── gt.bin
│ │ │ ├── cam_gt.bin
│ │ │ ├── fov_gt.bin
│ │ ├── kitti_format │ │ ├── kitti_format
│ │ │ ├── ImageSets │ │ │ ├── ImageSets
``` ```
You can download Waymo open dataset V1.2 [HERE](https://waymo.com/open/download/) and its data split [HERE](https://drive.google.com/drive/folders/18BVuF_RYJF0NjZpt8SnfzANiakoRMf0o?usp=sharing). Then put `tfrecord` files into corresponding folders in `data/waymo/waymo_format/` and put the data split txt files into `data/waymo/kitti_format/ImageSets`. Download ground truth bin files for validation set [HERE](https://console.cloud.google.com/storage/browser/waymo_open_dataset_v_1_2_0/validation/ground_truth_objects) and put it into `data/waymo/waymo_format/`. A tip is that you can use `gsutil` to download the large-scale dataset with commands. You can take this [tool](https://github.com/RalphMao/Waymo-Dataset-Tool) as an example for more details. Subsequently, prepare Waymo data by running You can download Waymo open dataset V1.4 [HERE](https://waymo.com/open/download/) and its data split [HERE](https://drive.google.com/drive/folders/18BVuF_RYJF0NjZpt8SnfzANiakoRMf0o?usp=sharing). Then put `tfrecord` files into corresponding folders in `data/waymo/waymo_format/` and put the data split txt files into `data/waymo/kitti_format/ImageSets`. Download ground truth bin files for validation set [HERE](https://console.cloud.google.com/storage/browser/waymo_open_dataset_v_1_2_0/validation/ground_truth_objects) and put it into `data/waymo/waymo_format/`. A tip is that you can use `gsutil` to download the large-scale dataset with commands. You can take this [tool](https://github.com/RalphMao/Waymo-Dataset-Tool) as an example for more details. Subsequently, prepare Waymo data by running
```bash ```bash
python tools/create_data.py waymo --root-path ./data/waymo/ --out-dir ./data/waymo/ --workers 128 --extra-tag waymo # TF_CPP_MIN_LOG_LEVEL=3 will disable all logging output from TensorFlow.
# The number of `--workers` depends on the maximum number of cores in your CPU.
TF_CPP_MIN_LOG_LEVEL=3 python tools/create_data.py waymo --root-path ./data/waymo --out-dir ./data/waymo --workers 128 --extra-tag waymo --version v1.4
``` ```
Note that if your local disk does not have enough space for saving converted data, you can change the `--out-dir` to anywhere else. Just remember to create folders and prepare data there in advance and link them back to `data/waymo/kitti_format` after the data conversion. Note that if your local disk does not have enough space for saving converted data, you can change the `--out-dir` to anywhere else. Just remember to create folders and prepare data there in advance and link them back to `data/waymo/kitti_format` after the data conversion.
...@@ -65,22 +64,16 @@ mmdetection3d ...@@ -65,22 +64,16 @@ mmdetection3d
│ │ │ ├── validation │ │ │ ├── validation
│ │ │ ├── testing │ │ │ ├── testing
│ │ │ ├── gt.bin │ │ │ ├── gt.bin
│ │ │ ├── cam_gt.bin
│ │ │ ├── fov_gt.bin
│ │ ├── kitti_format │ │ ├── kitti_format
│ │ │ ├── ImageSets │ │ │ ├── ImageSets
│ │ │ ├── training │ │ │ ├── training
│ │ │ │ ├── calib
│ │ │ │ ├── image_0 │ │ │ │ ├── image_0
│ │ │ │ ├── image_1 │ │ │ │ ├── image_1
│ │ │ │ ├── image_2 │ │ │ │ ├── image_2
│ │ │ │ ├── image_3 │ │ │ │ ├── image_3
│ │ │ │ ├── image_4 │ │ │ │ ├── image_4
│ │ │ │ ├── label_0
│ │ │ │ ├── label_1
│ │ │ │ ├── label_2
│ │ │ │ ├── label_3
│ │ │ │ ├── label_4
│ │ │ │ ├── label_all
│ │ │ │ ├── pose
│ │ │ │ ├── velodyne │ │ │ │ ├── velodyne
│ │ │ ├── testing │ │ │ ├── testing
│ │ │ │ ├── (the same as training) │ │ │ │ ├── (the same as training)
...@@ -93,7 +86,48 @@ mmdetection3d ...@@ -93,7 +86,48 @@ mmdetection3d
``` ```
Here because there are several cameras, we store the corresponding image and labels that can be projected to that camera respectively and save pose for further usage of consecutive frames point clouds. We use a coding way `{a}{bbb}{ccc}` to name the data for each frame, where `a` is the prefix for different split (`0` for training, `1` for validation and `2` for testing), `bbb` for segment index and `ccc` for frame index. You can easily locate the required frame according to this naming rule. We gather the data for training and validation together as KITTI and store the indices for different set in the `ImageSet` files. - `kitti_format/training/image_{0-4}/{a}{bbb}{ccc}.jpg` Here because there are several cameras, we store the corresponding images. We use a coding way `{a}{bbb}{ccc}` to name the data for each frame, where `a` is the prefix for different split (`0` for training, `1` for validation and `2` for testing), `bbb` for segment index and `ccc` for frame index. You can easily locate the required frame according to this naming rule. We gather the data for training and validation together as KITTI and store the indices for different set in the `ImageSet` files.
- `kitti_format/training/velodyne/{a}{bbb}{ccc}.bin` point cloud data for each frame.
- `kitti_format/waymo_gt_database/xxx_{Car/Pedestrian/Cyclist}_x.bin`. point cloud data included in each 3D bounding box of the training dataset. These point clouds will be used in data augmentation e.g. `ObjectSample`. `xxx` is the index of training samples and `x` is the index of objects in this frame.
- `kitti_format/waymo_infos_train.pkl`. training dataset information, a dict contains two keys: `metainfo` and `data_list`.`metainfo` contains the basic information for the dataset itself, such as `dataset`, `version` and `info_version`, while `data_list` is a list of dict, each dict (hereinafter referred to as `info`) contains all the detailed information of single sample as follows:
- info\['sample_idx'\]: The index of this sample in the whole dataset.
- info\['ego2global'\]: The transformation matrix from the ego vehicle to global coordinates. (4x4 list).
- info\['timestamp'\]: Timestamp of the sample data.
- info\['context_name'\]: The context name of sample indices which `*.tfrecord` segment it extracted from.
- info\['lidar_points'\]: A dict containing all the information related to the lidar points.
- info\['lidar_points'\]\['lidar_path'\]: The filename of the lidar point cloud data.
- info\['lidar_points'\]\['num_pts_feats'\]: The feature dimension of point.
- info\['lidar_sweeps'\]: A list contains sweeps information of lidar
- info\['lidar_sweeps'\]\[i\]\['lidar_points'\]\['lidar_path'\]: The lidar data path of i-th sweep.
- info\['lidar_sweeps'\]\[i\]\['ego2global'\]: The transformation matrix from the ego vehicle to global coordinates. (4x4 list)
- info\['lidar_sweeps'\]\[i\]\['timestamp'\]: Timestamp of the sweep data.
- info\['images'\]: A dict contains five keys corresponding to each camera: `'CAM_FRONT'`, `'CAM_FRONT_RIGHT'`, `'CAM_FRONT_LEFT'`, `'CAM_SIDE_LEFT'`, `'CAM_SIDE_RIGHT'`. Each dict contains all data information related to corresponding camera.
- info\['images'\]\['CAM_XXX'\]\['img_path'\]: The filename of the image.
- info\['images'\]\['CAM_XXX'\]\['height'\]: The height of the image.
- info\['images'\]\['CAM_XXX'\]\['width'\]: The width of the image.
- info\['images'\]\['CAM_XXX'\]\['cam2img'\]: The transformation matrix recording the intrinsic parameters when projecting 3D points to each image plane. (4x4 list)
- info\['images'\]\['CAM_XXX'\]\['lidar2cam'\]: The transformation matrix from lidar sensor to this camera. (4x4 list)
- info\['images'\]\['CAM_XXX'\]\['lidar2img'\]: The transformation matrix from lidar sensor to each image plane. (4x4 list)
- info\['image_sweeps'\]: A list containing sweeps information of images.
- info\['image_sweeps'\]\[i\]\['images'\]\['CAM_XXX'\]\['img_path'\]: The image path of i-th sweep.
- info\['image_sweeps'\]\[i\]\['ego2global'\]: The transformation matrix from the ego vehicle to global coordinates. (4x4 list)
- info\['image_sweeps'\]\[i\]\['timestamp'\]: Timestamp of the sweep data.
- info\['instances'\]: It is a list of dict. Each dict contains all annotation information of single instance. For the i-th instance:
- info\['instances'\]\[i\]\['bbox_3d'\]: List of 7 numbers representing the 3D bounding box of the instance, in (x, y, z, l, w, h, yaw) order.
- info\['instances'\]\[i\]\['bbox'\]: List of 4 numbers representing the 2D bounding box of the instance, in (x1, y1, x2, y2) order. (some instances may not have a corresponding 2D bounding box)
- info\['instances'\]\[i\]\['bbox_label_3d'\]: A int indicating the label of instance and the -1 indicating ignore.
- info\['instances'\]\[i\]\['bbox_label'\]: A int indicating the label of instance and the -1 indicating ignore.
- info\['instances'\]\[i\]\['num_lidar_pts'\]: Number of lidar points included in each 3D bounding box.
- info\['instances'\]\[i\]\['camera_id'\]: The index of the most visible camera for this instance.
- info\['instances'\]\[i\]\['group_id'\]: The index of this instance in this sample.
- info\['cam_sync_instances'\]: It is a list of dict. Each dict contains all annotation information of single instance. Its format is same with \['instances'\]. However, \['cam_sync_instances'\] is only for multi-view camera-based 3D Object Detection task.
- info\['cam_instances'\]: It is a dict containing keys `'CAM_FRONT'`, `'CAM_FRONT_RIGHT'`, `'CAM_FRONT_LEFT'`, `'CAM_SIDE_LEFT'`, `'CAM_SIDE_RIGHT'`. For monocular camera-based 3D Object Detection task, we split 3D annotations of the whole scenes according to the camera they belong to. For the i-th instance:
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['bbox_3d'\]: List of 7 numbers representing the 3D bounding box of the instance, in (x, y, z, l, h, w, yaw) order.
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['bbox'\]: 2D bounding box annotation (exterior rectangle of the projected 3D box), a list arrange as \[x1, y1, x2, y2\].
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['bbox_label_3d'\]: Label of instance.
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['bbox_label'\]: Label of instance.
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['center_2d'\]: Projected center location on the image, a list has shape (2,).
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['depth'\]: The depth of projected center.
## Training ## Training
...@@ -101,7 +135,7 @@ Considering there are many similar frames in the original dataset, we can basica ...@@ -101,7 +135,7 @@ Considering there are many similar frames in the original dataset, we can basica
## Evaluation ## Evaluation
For evaluation on Waymo, please follow the [instruction](https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md) to build the binary file `compute_detection_metrics_main` for metrics computation and put it into `mmdet3d/core/evaluation/waymo_utils/`. Basically, you can follow the commands below to install `bazel` and build the file. For evaluation on Waymo, please follow the [instruction](https://github.com/waymo-research/waymo-open-dataset/blob/r1.3/docs/quick_start.md) to build the binary file `compute_detection_metrics_main` for metrics computation and put it into `mmdet3d/core/evaluation/waymo_utils/`. Basically, you can follow the commands below to install `bazel` and build the file.
```shell ```shell
# download the code and enter the base directory # download the code and enter the base directory
......
...@@ -7,12 +7,7 @@ ...@@ -7,12 +7,7 @@
在准备 Waymo 数据集之前,如果您之前只安装了 `requirements/build.txt``requirements/runtime.txt` 中的依赖,请通过运行如下指令额外安装 Waymo 数据集所依赖的官方包: 在准备 Waymo 数据集之前,如果您之前只安装了 `requirements/build.txt``requirements/runtime.txt` 中的依赖,请通过运行如下指令额外安装 Waymo 数据集所依赖的官方包:
``` ```
# tf 2.1.0. pip install waymo-open-dataset-tf-2-6-0
pip install waymo-open-dataset-tf-2-1-0==1.2.0
# tf 2.0.0
# pip install waymo-open-dataset-tf-2-0-0==1.2.0
# tf 1.15.0
# pip install waymo-open-dataset-tf-1-15-0==1.2.0
``` ```
或者 或者
...@@ -38,6 +33,8 @@ mmdetection3d ...@@ -38,6 +33,8 @@ mmdetection3d
│ │ │ ├── validation │ │ │ ├── validation
│ │ │ ├── testing │ │ │ ├── testing
│ │ │ ├── gt.bin │ │ │ ├── gt.bin
│ │ │ ├── cam_gt.bin
│ │ │ ├── fov_gt.bin
│ │ ├── kitti_format │ │ ├── kitti_format
│ │ │ ├── ImageSets │ │ │ ├── ImageSets
...@@ -46,7 +43,9 @@ mmdetection3d ...@@ -46,7 +43,9 @@ mmdetection3d
您可以在[这里](https://waymo.com/open/download/)下载 1.2 版本的 Waymo 公开数据集,并在[这里](https://drive.google.com/drive/folders/18BVuF_RYJF0NjZpt8SnfzANiakoRMf0o?usp=sharing)下载其训练/验证/测试集拆分文件。接下来,请将 `tfrecord` 文件放入 `data/waymo/waymo_format/` 下的对应文件夹,并将 txt 格式的数据集拆分文件放入 `data/waymo/kitti_format/ImageSets`。在[这里](https://console.cloud.google.com/storage/browser/waymo_open_dataset_v_1_2_0/validation/ground_truth_objects)下载验证集使用的 bin 格式真实标注 (Ground Truth) 文件并放入 `data/waymo/waymo_format/`。小窍门:您可以使用 `gsutil` 来在命令行下载大规模数据集。您可以将该[工具](https://github.com/RalphMao/Waymo-Dataset-Tool) 作为一个例子来查看更多细节。之后,通过运行如下指令准备 Waymo 数据: 您可以在[这里](https://waymo.com/open/download/)下载 1.2 版本的 Waymo 公开数据集,并在[这里](https://drive.google.com/drive/folders/18BVuF_RYJF0NjZpt8SnfzANiakoRMf0o?usp=sharing)下载其训练/验证/测试集拆分文件。接下来,请将 `tfrecord` 文件放入 `data/waymo/waymo_format/` 下的对应文件夹,并将 txt 格式的数据集拆分文件放入 `data/waymo/kitti_format/ImageSets`。在[这里](https://console.cloud.google.com/storage/browser/waymo_open_dataset_v_1_2_0/validation/ground_truth_objects)下载验证集使用的 bin 格式真实标注 (Ground Truth) 文件并放入 `data/waymo/waymo_format/`。小窍门:您可以使用 `gsutil` 来在命令行下载大规模数据集。您可以将该[工具](https://github.com/RalphMao/Waymo-Dataset-Tool) 作为一个例子来查看更多细节。之后,通过运行如下指令准备 Waymo 数据:
```bash ```bash
python tools/create_data.py waymo --root-path ./data/waymo/ --out-dir ./data/waymo/ --workers 128 --extra-tag waymo # TF_CPP_MIN_LOG_LEVEL=3 will disable all logging output from TensorFlow.
# The number of `--workers` depends on the maximum number of cores in your CPU.
TF_CPP_MIN_LOG_LEVEL=3 python tools/create_data.py waymo --root-path ./data/waymo --out-dir ./data/waymo --workers 128 --extra-tag waymo --version v1.4
``` ```
请注意,如果您的本地磁盘没有足够空间保存转换后的数据,您可以将 `--out-dir` 改为其他目录;只要在创建文件夹、准备数据并转换格式后,将数据文件链接到 `data/waymo/kitti_format` 即可。 请注意,如果您的本地磁盘没有足够空间保存转换后的数据,您可以将 `--out-dir` 改为其他目录;只要在创建文件夹、准备数据并转换格式后,将数据文件链接到 `data/waymo/kitti_format` 即可。
...@@ -65,22 +64,16 @@ mmdetection3d ...@@ -65,22 +64,16 @@ mmdetection3d
│ │ │ ├── validation │ │ │ ├── validation
│ │ │ ├── testing │ │ │ ├── testing
│ │ │ ├── gt.bin │ │ │ ├── gt.bin
│ │ │ ├── cam_gt.bin
│ │ │ ├── fov_gt.bin
│ │ ├── kitti_format │ │ ├── kitti_format
│ │ │ ├── ImageSets │ │ │ ├── ImageSets
│ │ │ ├── training │ │ │ ├── training
│ │ │ │ ├── calib
│ │ │ │ ├── image_0 │ │ │ │ ├── image_0
│ │ │ │ ├── image_1 │ │ │ │ ├── image_1
│ │ │ │ ├── image_2 │ │ │ │ ├── image_2
│ │ │ │ ├── image_3 │ │ │ │ ├── image_3
│ │ │ │ ├── image_4 │ │ │ │ ├── image_4
│ │ │ │ ├── label_0
│ │ │ │ ├── label_1
│ │ │ │ ├── label_2
│ │ │ │ ├── label_3
│ │ │ │ ├── label_4
│ │ │ │ ├── label_all
│ │ │ │ ├── pose
│ │ │ │ ├── velodyne │ │ │ │ ├── velodyne
│ │ │ ├── testing │ │ │ ├── testing
│ │ │ │ ├── (the same as training) │ │ │ │ ├── (the same as training)
...@@ -93,7 +86,48 @@ mmdetection3d ...@@ -93,7 +86,48 @@ mmdetection3d
``` ```
因为 Waymo 数据的来源包含数个相机,这里我们将每个相机对应的图像和标签文件分别存储,并将相机位姿 (pose) 文件存储下来以供后续处理连续多帧的点云。我们使用 `{a}{bbb}{ccc}` 的名称编码方式为每帧数据命名,其中 `a` 是不同数据拆分的前缀(`0` 指代训练集,`1` 指代验证集,`2` 指代测试集),`bbb` 是分割部分 (segment) 的索引,而 `ccc` 是帧索引。您可以轻而易举地按照如上命名规则定位到所需的帧。我们将训练和验证所需数据按 KITTI 的方式集合在一起,然后将训练集/验证集/测试集的索引存储在 `ImageSet` 下的文件中。 - `kitti_format/training/image_{0-4}/{a}{bbb}{ccc}.jpg` 因为 Waymo 数据的来源包含数个相机,这里我们将每个相机对应的图像和标签文件分别存储,并将相机位姿 (pose) 文件存储下来以供后续处理连续多帧的点云。我们使用 `{a}{bbb}{ccc}` 的名称编码方式为每帧数据命名,其中 `a` 是不同数据拆分的前缀(`0` 指代训练集,`1` 指代验证集,`2` 指代测试集),`bbb` 是分割部分 (segment) 的索引,而 `ccc` 是帧索引。您可以轻而易举地按照如上命名规则定位到所需的帧。我们将训练和验证所需数据按 KITTI 的方式集合在一起,然后将训练集/验证集/测试集的索引存储在 `ImageSet` 下的文件中。
- `kitti_format/training/velodyne/{a}{bbb}{ccc}.bin` 当前样本的点云数据
- `kitti_format/waymo_gt_database/xxx_{Car/Pedestrian/Cyclist}_x.bin`. 训练数据集的每个 3D 包围框中包含的点云数据。这些点云会在数据增强中被使用,例如. `ObjectSample`. `xxx` 表示训练样本的索引,`x` 表示实例在当前样本中的索引。
- `kitti_format/waymo_infos_train.pkl`. 训练数据集,该字典包含了两个键值:`metainfo``data_list``metainfo` 包含数据集的基本信息,例如 `dataset``version``info_version``data_list` 是由字典组成的列表,每个字典(以下简称 `info`)包含了单个样本的所有详细信息。:
- info\['sample_idx'\]: 样本在整个数据集的索引。
- info\['ego2global'\]: 自车到全局坐标的变换矩阵。(4x4 列表)
- info\['timestamp'\]:样本数据时间戳。
- info\['context_name'\]: 语境名,表示样本从哪个 `*.tfrecord` 片段中提取的。
- info\['lidar_points'\]: 是一个字典,包含了所有与激光雷达点相关的信息。
- info\['lidar_points'\]\['lidar_path'\]: 激光雷达点云数据的文件名。
- info\['lidar_points'\]\['num_pts_feats'\]: 点的特征维度。
- info\['lidar_sweeps'\]: 是一个列表,包含了历史帧信息。
- info\['lidar_sweeps'\]\[i\]\['lidar_points'\]\['lidar_path'\]: 第 i 帧的激光雷达数据的文件路径。
- info\['lidar_sweeps'\]\[i\]\['ego2global'\]: 第 i 帧的激光雷达传感器到自车的变换矩阵。(4x4 列表)
- info\['lidar_sweeps'\]\[i\]\['timestamp'\]: 第 i 帧的样本数据时间戳。
- info\['images'\]: 是一个字典,包含与每个相机对应的六个键值:`'CAM_FRONT'`, `'CAM_FRONT_RIGHT'`, `'CAM_FRONT_LEFT'`, `'CAM_SIDE_LEFT'`, `'CAM_SIDE_RIGHT'`。每个字典包含了对应相机的所有数据信息。
- info\['images'\]\['CAM_XXX'\]\['img_path'\]: 图像的文件名。
- info\['images'\]\['CAM_XXX'\]\['height'\]: 图像的高
- info\['images'\]\['CAM_XXX'\]\['width'\]: 图像的宽
- info\['images'\]\['CAM_XXX'\]\['cam2img'\]: 当 3D 点投影到图像平面时需要的内参信息相关的变换矩阵。(3x3 列表)
- info\['images'\]\['CAM_XXX'\]\['lidar2cam'\]: 激光雷达传感器到该相机的变换矩阵。(4x4 列表)
- info\['images'\]\['CAM_XXX'\]\['lidar2img'\]: 激光雷达传感器到图像平面的变换矩阵。(4x4 列表)
- info\['image_sweeps'\]: 是一个列表,包含了历史帧信息。
- info\['image_sweeps'\]\[i\]\['images'\]\['CAM_XXX'\]\['img_path'\]: 第i帧的图像的文件名.
- info\['image_sweeps'\]\[i\]\['ego2global'\]: 第 i 帧的自车到全局坐标的变换矩阵。(4x4 列表)
- info\['image_sweeps'\]\[i\]\['timestamp'\]: 第 i 帧的样本数据时间戳。
- info\['instances'\]: 是一个字典组成的列表。每个字典包含单个实例的所有标注信息。对于其中的第 i 个实例,我们有:
- info\['instances'\]\[i\]\['bbox_3d'\]: 长度为 7 的列表,以 (x, y, z, l, w, h, yaw) 的顺序表示实例的 3D 边界框。
- info\['instances'\]\[i\]\['bbox'\]: 2D 边界框标注(,顺序为 \[x1, y1, x2, y2\] 的列表。有些实例可能没有对应的 2D 边界框标注。
- info\['instances'\]\[i\]\['bbox_label_3d'\]: 整数表示实例的标签,-1 代表忽略。
- info\['instances'\]\[i\]\['bbox_label'\]: 整数表示实例的标签,-1 代表忽略。
- info\['instances'\]\[i\]\['num_lidar_pts'\]: 每个 3D 边界框内包含的激光雷达点数。
- info\['instances'\]\[i\]\['camera_id'\]: 当前实例最可见相机的索引。
- info\['instances'\]\[i\]\['group_id'\]: 当前实例在当前样本中的索引。
- info\['cam_sync_instances'\]: 是一个字典组成的列表。每个字典包含单个实例的所有标注信息。它的形式与 \['instances'\]相同. 但是, \['cam_sync_instances'\] 专门用于基于多视角相机的三维目标检测任务。
- info\['cam_instances'\]: 是一个字典,包含以下键值: `'CAM_FRONT'`, `'CAM_FRONT_RIGHT'`, `'CAM_FRONT_LEFT'`, `'CAM_SIDE_LEFT'`, `'CAM_SIDE_RIGHT'`. 对于基于视觉的 3D 目标检测任务,我们将整个场景的 3D 标注划分至它们所属于的相应相机中。对于其中的第 i 个实例,我们有:
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['bbox_3d'\]: 长度为 7 的列表,以 (x, y, z, l, h, w, yaw) 的顺序表示实例的 3D 边界框。
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['bbox'\]: 2D 边界框标注(3D 框投影的矩形框),顺序为 \[x1, y1, x2, y2\] 的列表。
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['bbox_label_3d'\]: 实例标签。
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['bbox_label'\]: 实例标签。
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['center_2d'\]: 3D 框投影到图像上的中心点,大小为 (2, ) 的列表。
- info\['cam_instances'\]\['CAM_XXX'\]\[i\]\['depth'\]: 3D 框投影中心的深度。
## 训练 ## 训练
......
...@@ -113,7 +113,7 @@ class Det3DDataset(BaseDataset): ...@@ -113,7 +113,7 @@ class Det3DDataset(BaseDataset):
ori_label = self.METAINFO['classes'].index(name) ori_label = self.METAINFO['classes'].index(name)
self.label_mapping[ori_label] = label_idx self.label_mapping[ori_label] = label_idx
self.num_ins_per_cat = {name: 0 for name in metainfo['classes']} self.num_ins_per_cat = [0] * len(metainfo['classes'])
else: else:
self.label_mapping = { self.label_mapping = {
i: i i: i
...@@ -121,10 +121,7 @@ class Det3DDataset(BaseDataset): ...@@ -121,10 +121,7 @@ class Det3DDataset(BaseDataset):
} }
self.label_mapping[-1] = -1 self.label_mapping[-1] = -1
self.num_ins_per_cat = { self.num_ins_per_cat = [0] * len(self.METAINFO['classes'])
name: 0
for name in self.METAINFO['classes']
}
super().__init__( super().__init__(
ann_file=ann_file, ann_file=ann_file,
...@@ -146,9 +143,12 @@ class Det3DDataset(BaseDataset): ...@@ -146,9 +143,12 @@ class Det3DDataset(BaseDataset):
# show statistics of this dataset # show statistics of this dataset
print_log('-' * 30, 'current') print_log('-' * 30, 'current')
print_log(f'The length of the dataset: {len(self)}', 'current') print_log(
f'The length of {"test" if self.test_mode else "training"} dataset: {len(self)}', # noqa: E501
'current')
content_show = [['category', 'number']] content_show = [['category', 'number']]
for cat_name, num in self.num_ins_per_cat.items(): for label, num in enumerate(self.num_ins_per_cat):
cat_name = self.metainfo['classes'][label]
content_show.append([cat_name, num]) content_show.append([cat_name, num])
table = AsciiTable(content_show) table = AsciiTable(content_show)
print_log( print_log(
...@@ -256,8 +256,7 @@ class Det3DDataset(BaseDataset): ...@@ -256,8 +256,7 @@ class Det3DDataset(BaseDataset):
for label in ann_info['gt_labels_3d']: for label in ann_info['gt_labels_3d']:
if label != -1: if label != -1:
cat_name = self.metainfo['classes'][label] self.num_ins_per_cat[label] += 1
self.num_ins_per_cat[cat_name] += 1
return ann_info return ann_info
......
...@@ -3,9 +3,11 @@ import os.path as osp ...@@ -3,9 +3,11 @@ import os.path as osp
from typing import Callable, List, Union from typing import Callable, List, Union
import numpy as np import numpy as np
from mmengine import print_log
from mmengine.fileio import load
from mmdet3d.registry import DATASETS from mmdet3d.registry import DATASETS
from mmdet3d.structures import CameraInstance3DBoxes from mmdet3d.structures import CameraInstance3DBoxes, LiDARInstance3DBoxes
from .det3d_dataset import Det3DDataset from .det3d_dataset import Det3DDataset
from .kitti_dataset import KittiDataset from .kitti_dataset import KittiDataset
...@@ -163,13 +165,10 @@ class WaymoDataset(KittiDataset): ...@@ -163,13 +165,10 @@ class WaymoDataset(KittiDataset):
centers_2d = np.zeros((0, 2), dtype=np.float32) centers_2d = np.zeros((0, 2), dtype=np.float32)
depths = np.zeros((0), dtype=np.float32) depths = np.zeros((0), dtype=np.float32)
# in waymo, lidar2cam = R0_rect @ Tr_velo_to_cam if self.load_type == 'frame_based':
# convert gt_bboxes_3d to velodyne coordinates with `lidar2cam` gt_bboxes_3d = LiDARInstance3DBoxes(ann_info['gt_bboxes_3d'])
lidar2cam = np.array(info['images'][self.default_cam_key]['lidar2cam']) else:
gt_bboxes_3d = CameraInstance3DBoxes( gt_bboxes_3d = CameraInstance3DBoxes(ann_info['gt_bboxes_3d'])
ann_info['gt_bboxes_3d']).convert_to(self.box_mode_3d,
np.linalg.inv(lidar2cam))
ann_info['gt_bboxes_3d'] = gt_bboxes_3d
anns_results = dict( anns_results = dict(
gt_bboxes_3d=gt_bboxes_3d, gt_bboxes_3d=gt_bboxes_3d,
...@@ -182,9 +181,58 @@ class WaymoDataset(KittiDataset): ...@@ -182,9 +181,58 @@ class WaymoDataset(KittiDataset):
return anns_results return anns_results
def load_data_list(self) -> List[dict]: def load_data_list(self) -> List[dict]:
"""Add the load interval.""" """Add the load interval.
data_list = super().load_data_list()
data_list = data_list[::self.load_interval] Returns:
list[dict]: A list of annotation.
""" # noqa: E501
# `self.ann_file` denotes the absolute annotation file path if
# `self.root=None` or relative path if `self.root=/path/to/data/`.
annotations = load(self.ann_file)
if not isinstance(annotations, dict):
raise TypeError(f'The annotations loaded from annotation file '
f'should be a dict, but got {type(annotations)}!')
if 'data_list' not in annotations or 'metainfo' not in annotations:
raise ValueError('Annotation must have data_list and metainfo '
'keys')
metainfo = annotations['metainfo']
raw_data_list = annotations['data_list']
raw_data_list = raw_data_list[::self.load_interval]
if self.load_interval > 1:
print_log(
f'Sample size will be reduced to 1/{self.load_interval} of'
' the original data sample',
logger='current')
# Meta information load from annotation file will not influence the
# existed meta information load from `BaseDataset.METAINFO` and
# `metainfo` arguments defined in constructor.
for k, v in metainfo.items():
self._metainfo.setdefault(k, v)
# load and parse data_infos.
data_list = []
for raw_data_info in raw_data_list:
# parse raw data information to target format
data_info = self.parse_data_info(raw_data_info)
if isinstance(data_info, dict):
# For image tasks, `data_info` should information if single
# image, such as dict(img_path='xxx', width=360, ...)
data_list.append(data_info)
elif isinstance(data_info, list):
# For video tasks, `data_info` could contain image
# information of multiple frames, such as
# [dict(video_path='xxx', timestamps=...),
# dict(video_path='xxx', timestamps=...)]
for item in data_info:
if not isinstance(item, dict):
raise TypeError('data_info must be list of dict, but '
f'got {type(item)}')
data_list.extend(data_info)
else:
raise TypeError('data_info should be a dict or list of dict, '
f'but got {type(data_info)}')
return data_list return data_list
def parse_data_info(self, info: dict) -> Union[dict, List[dict]]: def parse_data_info(self, info: dict) -> Union[dict, List[dict]]:
...@@ -203,44 +251,39 @@ class WaymoDataset(KittiDataset): ...@@ -203,44 +251,39 @@ class WaymoDataset(KittiDataset):
info['images'][self.default_cam_key] info['images'][self.default_cam_key]
info['images'] = new_image_info info['images'] = new_image_info
info['instances'] = info['cam_instances'][self.default_cam_key] info['instances'] = info['cam_instances'][self.default_cam_key]
return super().parse_data_info(info) return Det3DDataset.parse_data_info(self, info)
else: else:
# in the mono3d, the instances is from cam sync. # in the mono3d, the instances is from cam sync.
# Convert frame-based infos to multi-view image-based
data_list = [] data_list = []
if self.modality['use_lidar']:
info['lidar_points']['lidar_path'] = \
osp.join(
self.data_prefix.get('pts', ''),
info['lidar_points']['lidar_path'])
if self.modality['use_camera']:
for cam_key, img_info in info['images'].items():
if 'img_path' in img_info:
cam_prefix = self.data_prefix.get(cam_key, '')
img_info['img_path'] = osp.join(
cam_prefix, img_info['img_path'])
for (cam_key, img_info) in info['images'].items(): for (cam_key, img_info) in info['images'].items():
camera_info = dict() camera_info = dict()
camera_info['sample_idx'] = info['sample_idx']
camera_info['timestamp'] = info['timestamp']
camera_info['context_name'] = info['context_name']
camera_info['images'] = dict() camera_info['images'] = dict()
camera_info['images'][cam_key] = img_info camera_info['images'][cam_key] = img_info
if 'cam_instances' in info \ if 'img_path' in img_info:
and cam_key in info['cam_instances']: cam_prefix = self.data_prefix.get(cam_key, '')
camera_info['instances'] = info['cam_instances'][cam_key] camera_info['images'][cam_key]['img_path'] = osp.join(
cam_prefix, img_info['img_path'])
if 'lidar2cam' in img_info:
camera_info['lidar2cam'] = np.array(img_info['lidar2cam'])
if 'cam2img' in img_info:
camera_info['cam2img'] = np.array(img_info['cam2img'])
if 'lidar2img' in img_info:
camera_info['lidar2img'] = np.array(img_info['lidar2img'])
else: else:
camera_info['instances'] = [] camera_info['lidar2img'] = camera_info[
camera_info['ego2global'] = info['ego2global'] 'cam2img'] @ camera_info['lidar2cam']
if 'image_sweeps' in info:
camera_info['image_sweeps'] = info['image_sweeps']
# TODO check if need to modify the sample id
# TODO check when will use it except for evaluation.
camera_info['sample_idx'] = info['sample_idx']
if not self.test_mode: if not self.test_mode:
# used in training # used in training
camera_info['instances'] = info['cam_instances'][cam_key]
camera_info['ann_info'] = self.parse_ann_info(camera_info) camera_info['ann_info'] = self.parse_ann_info(camera_info)
if self.test_mode and self.load_eval_anns: if self.test_mode and self.load_eval_anns:
info['eval_ann_info'] = self.parse_ann_info(info) camera_info['instances'] = info['cam_instances'][cam_key]
camera_info['eval_ann_info'] = self.parse_ann_info(
camera_info)
data_list.append(camera_info) data_list.append(camera_info)
return data_list return data_list
...@@ -78,11 +78,11 @@ class Det3DVisualizationHook(Hook): ...@@ -78,11 +78,11 @@ class Det3DVisualizationHook(Hook):
'needs to be excluded.') 'needs to be excluded.')
self.vis_task = vis_task self.vis_task = vis_task
if wait_time == -1: if show and wait_time == -1:
print_log( print_log(
'Manual control mode, press [Right] to next sample.', 'Manual control mode, press [Right] to next sample.',
logger='current') logger='current')
else: elif show:
print_log( print_log(
'Autoplay mode, press [SPACE] to pause.', logger='current') 'Autoplay mode, press [SPACE] to pause.', logger='current')
self.wait_time = wait_time self.wait_time = wait_time
......
...@@ -4,7 +4,6 @@ r"""Adapted from `Waymo to KITTI converter ...@@ -4,7 +4,6 @@ r"""Adapted from `Waymo to KITTI converter
""" """
try: try:
from waymo_open_dataset import dataset_pb2 as open_dataset
from waymo_open_dataset import label_pb2 from waymo_open_dataset import label_pb2
from waymo_open_dataset.protos import metrics_pb2 from waymo_open_dataset.protos import metrics_pb2
from waymo_open_dataset.protos.metrics_pb2 import Objects from waymo_open_dataset.protos.metrics_pb2 import Objects
...@@ -14,13 +13,10 @@ except ImportError: ...@@ -14,13 +13,10 @@ except ImportError:
'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" ' 'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" '
'to install the official devkit first.') 'to install the official devkit first.')
from glob import glob from typing import List
from os.path import join
from typing import List, Optional
import mmengine import mmengine
import numpy as np from mmengine import print_log
import tensorflow as tf
class Prediction2Waymo(object): class Prediction2Waymo(object):
...@@ -32,54 +28,22 @@ class Prediction2Waymo(object): ...@@ -32,54 +28,22 @@ class Prediction2Waymo(object):
Args: Args:
results (list[dict]): Prediction results. results (list[dict]): Prediction results.
waymo_tfrecords_dir (str): Directory to load waymo raw data.
waymo_results_save_dir (str): Directory to save converted predictions waymo_results_save_dir (str): Directory to save converted predictions
in waymo format (.bin files). in waymo format (.bin files).
waymo_results_final_path (str): Path to save combined waymo_results_final_path (str): Path to save combined
predictions in waymo format (.bin file), like 'a/b/c.bin'. predictions in waymo format (.bin file), like 'a/b/c.bin'.
prefix (str): Prefix of filename. In general, 0 for training, 1 for num_workers (str): Number of parallel processes. Defaults to 4.
validation and 2 for testing.
classes (dict): A list of class name.
workers (str): Number of parallel processes. Defaults to 2.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
from_kitti_format (bool, optional): Whether the reuslts are kitti
format. Defaults to False.
idx2metainfo (Optional[dict], optional): The mapping from sample_idx to
metainfo. The metainfo must contain the keys: 'idx2contextname' and
'idx2timestamp'. Defaults to None.
""" """
def __init__(self, def __init__(self,
results: List[dict], results: List[dict],
waymo_tfrecords_dir: str,
waymo_results_save_dir: str,
waymo_results_final_path: str, waymo_results_final_path: str,
prefix: str,
classes: dict, classes: dict,
workers: int = 2, num_workers: int = 4):
backend_args: Optional[dict] = None,
from_kitti_format: bool = False,
idx2metainfo: Optional[dict] = None):
self.results = results self.results = results
self.waymo_tfrecords_dir = waymo_tfrecords_dir
self.waymo_results_save_dir = waymo_results_save_dir
self.waymo_results_final_path = waymo_results_final_path self.waymo_results_final_path = waymo_results_final_path
self.prefix = prefix
self.classes = classes self.classes = classes
self.workers = int(workers) self.num_workers = num_workers
self.backend_args = backend_args
self.from_kitti_format = from_kitti_format
if idx2metainfo is not None:
self.idx2metainfo = idx2metainfo
# If ``fast_eval``, the metainfo does not need to be read from
# original data online. It's preprocessed offline.
self.fast_eval = True
else:
self.fast_eval = False
self.name2idx = {}
self.k2w_cls_map = { self.k2w_cls_map = {
'Car': label_pb2.Label.TYPE_VEHICLE, 'Car': label_pb2.Label.TYPE_VEHICLE,
...@@ -88,213 +52,23 @@ class Prediction2Waymo(object): ...@@ -88,213 +52,23 @@ class Prediction2Waymo(object):
'Cyclist': label_pb2.Label.TYPE_CYCLIST, 'Cyclist': label_pb2.Label.TYPE_CYCLIST,
} }
if self.from_kitti_format: def convert_one(self, res_idx: int):
self.T_ref_to_front_cam = np.array([[0.0, 0.0, 1.0, 0.0],
[-1.0, 0.0, 0.0, 0.0],
[0.0, -1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 1.0]])
# ``sample_idx`` of the sample in kitti-format is an array
for idx, result in enumerate(results):
if len(result['sample_idx']) > 0:
self.name2idx[str(result['sample_idx'][0])] = idx
else:
# ``sample_idx`` of the sample in the original prediction
# is an int value.
for idx, result in enumerate(results):
self.name2idx[str(result['sample_idx'])] = idx
if not self.fast_eval:
# need to read original '.tfrecord' file
self.get_file_names()
# turn on eager execution for older tensorflow versions
if int(tf.__version__.split('.')[0]) < 2:
tf.enable_eager_execution()
self.create_folder()
def get_file_names(self):
"""Get file names of waymo raw data."""
if 'path_mapping' in self.backend_args:
for path in self.backend_args['path_mapping'].keys():
if path in self.waymo_tfrecords_dir:
self.waymo_tfrecords_dir = \
self.waymo_tfrecords_dir.replace(
path, self.backend_args['path_mapping'][path])
from petrel_client.client import Client
client = Client()
contents = client.list(self.waymo_tfrecords_dir)
self.waymo_tfrecord_pathnames = list()
for content in sorted(list(contents)):
if content.endswith('tfrecord'):
self.waymo_tfrecord_pathnames.append(
join(self.waymo_tfrecords_dir, content))
else:
self.waymo_tfrecord_pathnames = sorted(
glob(join(self.waymo_tfrecords_dir, '*.tfrecord')))
print(len(self.waymo_tfrecord_pathnames), 'tfrecords found.')
def create_folder(self):
"""Create folder for data conversion."""
mmengine.mkdir_or_exist(self.waymo_results_save_dir)
def parse_objects(self, kitti_result, T_k2w, context_name,
frame_timestamp_micros):
"""Parse one prediction with several instances in kitti format and
convert them to `Object` proto.
Args:
kitti_result (dict): Predictions in kitti format.
- name (np.ndarray): Class labels of predictions.
- dimensions (np.ndarray): Height, width, length of boxes.
- location (np.ndarray): Bottom center of boxes (x, y, z).
- rotation_y (np.ndarray): Orientation of boxes.
- score (np.ndarray): Scores of predictions.
T_k2w (np.ndarray): Transformation matrix from kitti to waymo.
context_name (str): Context name of the frame.
frame_timestamp_micros (int): Frame timestamp.
Returns:
:obj:`Object`: Predictions in waymo dataset Object proto.
"""
def parse_one_object(instance_idx):
"""Parse one instance in kitti format and convert them to `Object`
proto.
Args:
instance_idx (int): Index of the instance to be converted.
Returns:
:obj:`Object`: Predicted instance in waymo dataset
Object proto.
"""
cls = kitti_result['name'][instance_idx]
length = round(kitti_result['dimensions'][instance_idx, 0], 4)
height = round(kitti_result['dimensions'][instance_idx, 1], 4)
width = round(kitti_result['dimensions'][instance_idx, 2], 4)
x = round(kitti_result['location'][instance_idx, 0], 4)
y = round(kitti_result['location'][instance_idx, 1], 4)
z = round(kitti_result['location'][instance_idx, 2], 4)
rotation_y = round(kitti_result['rotation_y'][instance_idx], 4)
score = round(kitti_result['score'][instance_idx], 4)
# y: downwards; move box origin from bottom center (kitti) to
# true center (waymo)
y -= height / 2
# frame transformation: kitti -> waymo
x, y, z = self.transform(T_k2w, x, y, z)
# different conventions
heading = -(rotation_y + np.pi / 2)
while heading < -np.pi:
heading += 2 * np.pi
while heading > np.pi:
heading -= 2 * np.pi
box = label_pb2.Label.Box()
box.center_x = x
box.center_y = y
box.center_z = z
box.length = length
box.width = width
box.height = height
box.heading = heading
o = metrics_pb2.Object()
o.object.box.CopyFrom(box)
o.object.type = self.k2w_cls_map[cls]
o.score = score
o.context_name = context_name
o.frame_timestamp_micros = frame_timestamp_micros
return o
objects = metrics_pb2.Objects()
for instance_idx in range(len(kitti_result['name'])):
o = parse_one_object(instance_idx)
objects.objects.append(o)
return objects
def convert_one(self, file_idx):
"""Convert action for single file.
Args:
file_idx (int): Index of the file to be converted.
"""
file_pathname = self.waymo_tfrecord_pathnames[file_idx]
if 's3://' in file_pathname and tf.__version__ >= '2.6.0':
try:
import tensorflow_io as tfio # noqa: F401
except ImportError:
raise ImportError(
"Please run 'pip install tensorflow-io' to install tensorflow_io first." # noqa: E501
)
file_data = tf.data.TFRecordDataset(file_pathname, compression_type='')
for frame_num, frame_data in enumerate(file_data):
frame = open_dataset.Frame()
frame.ParseFromString(bytearray(frame_data.numpy()))
filename = f'{self.prefix}{file_idx:03d}{frame_num:03d}'
context_name = frame.context.name
frame_timestamp_micros = frame.timestamp_micros
if filename in self.name2idx:
if self.from_kitti_format:
for camera in frame.context.camera_calibrations:
# FRONT = 1, see dataset.proto for details
if camera.name == 1:
T_front_cam_to_vehicle = np.array(
camera.extrinsic.transform).reshape(4, 4)
T_k2w = T_front_cam_to_vehicle @ self.T_ref_to_front_cam
kitti_result = \
self.results[self.name2idx[filename]]
objects = self.parse_objects(kitti_result, T_k2w,
context_name,
frame_timestamp_micros)
else:
index = self.name2idx[filename]
objects = self.parse_objects_from_origin(
self.results[index], context_name,
frame_timestamp_micros)
else:
print(filename, 'not found.')
objects = metrics_pb2.Objects()
with open(
join(self.waymo_results_save_dir, f'{filename}.bin'),
'wb') as f:
f.write(objects.SerializeToString())
def convert_one_fast(self, res_index: int):
"""Convert action for single file. It read the metainfo from the """Convert action for single file. It read the metainfo from the
preprocessed file offline and will be faster. preprocessed file offline and will be faster.
Args: Args:
res_index (int): The indices of the results. res_idx (int): The indices of the results.
""" """
sample_idx = self.results[res_index]['sample_idx'] sample_idx = self.results[res_idx]['sample_idx']
if len(self.results[res_index]['pred_instances_3d']) > 0: if len(self.results[res_idx]['labels_3d']) > 0:
objects = self.parse_objects_from_origin( objects = self.parse_objects_from_origin(
self.results[res_index], self.results[res_idx], self.results[res_idx]['context_name'],
self.idx2metainfo[str(sample_idx)]['contextname'], self.results[res_idx]['timestamp'])
self.idx2metainfo[str(sample_idx)]['timestamp'])
else: else:
print(sample_idx, 'not found.') print(sample_idx, 'not found.')
objects = metrics_pb2.Objects() objects = metrics_pb2.Objects()
with open( return objects
join(self.waymo_results_save_dir, f'{sample_idx}.bin'),
'wb') as f:
f.write(objects.SerializeToString())
def parse_objects_from_origin(self, result: dict, contextname: str, def parse_objects_from_origin(self, result: dict, contextname: str,
timestamp: str) -> Objects: timestamp: str) -> Objects:
...@@ -308,112 +82,56 @@ class Prediction2Waymo(object): ...@@ -308,112 +82,56 @@ class Prediction2Waymo(object):
Returns: Returns:
metrics_pb2.Objects: The parsed object. metrics_pb2.Objects: The parsed object.
""" """
lidar_boxes = result['pred_instances_3d']['bboxes_3d'].tensor lidar_boxes = result['bboxes_3d']
scores = result['pred_instances_3d']['scores_3d'] scores = result['scores_3d']
labels = result['pred_instances_3d']['labels_3d'] labels = result['labels_3d']
def parse_one_object(index):
class_name = self.classes[labels[index].item()]
objects = metrics_pb2.Objects()
for lidar_box, score, label in zip(lidar_boxes, scores, labels):
# Parse one object
box = label_pb2.Label.Box() box = label_pb2.Label.Box()
height = lidar_boxes[index][5].item() height = lidar_box[5]
heading = lidar_boxes[index][6].item() heading = lidar_box[6]
while heading < -np.pi: box.center_x = lidar_box[0]
heading += 2 * np.pi box.center_y = lidar_box[1]
while heading > np.pi: box.center_z = lidar_box[2] + height / 2
heading -= 2 * np.pi box.length = lidar_box[3]
box.width = lidar_box[4]
box.center_x = lidar_boxes[index][0].item()
box.center_y = lidar_boxes[index][1].item()
box.center_z = lidar_boxes[index][2].item() + height / 2
box.length = lidar_boxes[index][3].item()
box.width = lidar_boxes[index][4].item()
box.height = height box.height = height
box.heading = heading box.heading = heading
o = metrics_pb2.Object() object = metrics_pb2.Object()
o.object.box.CopyFrom(box) object.object.box.CopyFrom(box)
o.object.type = self.k2w_cls_map[class_name]
o.score = scores[index].item()
o.context_name = contextname
o.frame_timestamp_micros = timestamp
return o class_name = self.classes[label]
object.object.type = self.k2w_cls_map[class_name]
objects = metrics_pb2.Objects() object.score = score
for i in range(len(lidar_boxes)): object.context_name = contextname
objects.objects.append(parse_one_object(i)) object.frame_timestamp_micros = timestamp
objects.objects.append(object)
return objects return objects
def convert(self): def convert(self):
"""Convert action.""" """Convert action."""
print('Start converting ...') print_log('Start converting ...', logger='current')
convert_func = self.convert_one_fast if self.fast_eval else \
self.convert_one
# from torch.multiprocessing import set_sharing_strategy # TODO: use parallel processes.
# # Force using "file_system" sharing strategy for stability # objects_list = mmengine.track_parallel_progress(
# set_sharing_strategy("file_system") # self.convert_one, range(len(self)), self.num_workers)
# mmengine.track_parallel_progress(convert_func, range(len(self)), objects_list = mmengine.track_progress(self.convert_one,
# self.workers) range(len(self)))
# TODO: Support multiprocessing. Now, multiprocessing evaluation will combined = metrics_pb2.Objects()
# cause shared memory error in torch-1.10 and torch-1.11. Details can for objects in objects_list:
# be seen in https://github.com/pytorch/pytorch/issues/67864. for o in objects.objects:
prog_bar = mmengine.ProgressBar(len(self)) combined.objects.append(o)
for i in range(len(self)):
convert_func(i)
prog_bar.update()
print('\nFinished ...')
# combine all files into one .bin
pathnames = sorted(glob(join(self.waymo_results_save_dir, '*.bin')))
combined = self.combine(pathnames)
with open(self.waymo_results_final_path, 'wb') as f: with open(self.waymo_results_final_path, 'wb') as f:
f.write(combined.SerializeToString()) f.write(combined.SerializeToString())
def __len__(self): def __len__(self):
"""Length of the filename list.""" """Length of the filename list."""
return len(self.results) if self.fast_eval else len( return len(self.results)
self.waymo_tfrecord_pathnames)
def transform(self, T, x, y, z):
"""Transform the coordinates with matrix T.
Args:
T (np.ndarray): Transformation matrix.
x(float): Coordinate in x axis.
y(float): Coordinate in y axis.
z(float): Coordinate in z axis.
Returns:
list: Coordinates after transformation.
"""
pt_bef = np.array([x, y, z, 1.0]).reshape(4, 1)
pt_aft = np.matmul(T, pt_bef)
return pt_aft[:3].flatten().tolist()
def combine(self, pathnames):
"""Combine predictions in waymo format for each sample together.
Args:
pathnames (str): Paths to save predictions.
Returns:
:obj:`Objects`: Combined predictions in Objects proto.
"""
combined = metrics_pb2.Objects()
for pathname in pathnames:
objects = metrics_pb2.Objects()
with open(pathname, 'rb') as f:
objects.ParseFromString(f.read())
for o in objects.objects:
combined.objects.append(o)
return combined
This diff is collapsed.
...@@ -179,7 +179,10 @@ test_pipeline = [ ...@@ -179,7 +179,10 @@ test_pipeline = [
dict( dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range) type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]), ]),
dict(type='Pack3DDetInputs', keys=['points']) dict(
type='Pack3DDetInputs',
keys=['points'],
meta_keys=['box_type_3d', 'sample_idx', 'context_name', 'timestamp'])
] ]
dataset_type = 'WaymoDataset' dataset_type = 'WaymoDataset'
...@@ -223,13 +226,7 @@ val_dataloader = dict( ...@@ -223,13 +226,7 @@ val_dataloader = dict(
test_dataloader = val_dataloader test_dataloader = val_dataloader
val_evaluator = dict( val_evaluator = dict(
type='WaymoMetric', type='WaymoMetric', waymo_bin_file='./data/waymo/waymo_format/gt.bin')
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
backend_args=backend_args,
convert_kitti_format=False,
idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
test_evaluator = val_evaluator test_evaluator = val_evaluator
vis_backends = [dict(type='LocalVisBackend')] vis_backends = [dict(type='LocalVisBackend')]
......
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.transforms.base import BaseTransform
from mmengine.registry import TRANSFORMS
from mmengine.structures import InstanceData
from mmdet3d.datasets import WaymoDataset
from mmdet3d.structures import Det3DDataSample, LiDARInstance3DBoxes
def _generate_waymo_dataset_config():
data_root = 'tests/data/waymo/kitti_format'
ann_file = 'waymo_infos_train.pkl'
classes = ['Car', 'Pedestrian', 'Cyclist']
# wait for pipline refactor
if 'Identity' not in TRANSFORMS:
@TRANSFORMS.register_module()
class Identity(BaseTransform):
def transform(self, info):
if 'ann_info' in info:
info['gt_labels_3d'] = info['ann_info']['gt_labels_3d']
data_sample = Det3DDataSample()
gt_instances_3d = InstanceData()
gt_instances_3d.labels_3d = info['gt_labels_3d']
data_sample.gt_instances_3d = gt_instances_3d
info['data_samples'] = data_sample
return info
pipeline = [
dict(type='Identity'),
]
modality = dict(use_lidar=True, use_camera=True)
data_prefix = data_prefix = dict(
pts='training/velodyne', CAM_FRONT='training/image_0')
return data_root, ann_file, classes, data_prefix, pipeline, modality
def test_getitem():
data_root, ann_file, classes, data_prefix, \
pipeline, modality, = _generate_waymo_dataset_config()
waymo_dataset = WaymoDataset(
data_root,
ann_file,
data_prefix=data_prefix,
pipeline=pipeline,
metainfo=dict(classes=classes),
modality=modality)
waymo_dataset.prepare_data(0)
input_dict = waymo_dataset.get_data_info(0)
waymo_dataset[0]
# assert the the path should contains data_prefix and data_root
assert data_prefix['pts'] in input_dict['lidar_points']['lidar_path']
assert data_root in input_dict['lidar_points']['lidar_path']
for cam_id, img_info in input_dict['images'].items():
if 'img_path' in img_info:
assert data_prefix['CAM_FRONT'] in img_info['img_path']
assert data_root in img_info['img_path']
ann_info = waymo_dataset.parse_ann_info(input_dict)
# only one instance
assert 'gt_labels_3d' in ann_info
assert ann_info['gt_labels_3d'].dtype == np.int64
assert 'gt_bboxes_3d' in ann_info
assert isinstance(ann_info['gt_bboxes_3d'], LiDARInstance3DBoxes)
assert torch.allclose(ann_info['gt_bboxes_3d'].tensor.sum(),
torch.tensor(43.3103))
assert 'centers_2d' in ann_info
assert ann_info['centers_2d'].dtype == np.float32
assert 'depths' in ann_info
assert ann_info['depths'].dtype == np.float32
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
import argparse import argparse
from os import path as osp from os import path as osp
from mmengine import print_log
from tools.dataset_converters import indoor_converter as indoor from tools.dataset_converters import indoor_converter as indoor
from tools.dataset_converters import kitti_converter as kitti from tools.dataset_converters import kitti_converter as kitti
from tools.dataset_converters import lyft_converter as lyft_converter from tools.dataset_converters import lyft_converter as lyft_converter
...@@ -171,8 +173,19 @@ def waymo_data_prep(root_path, ...@@ -171,8 +173,19 @@ def waymo_data_prep(root_path,
version, version,
out_dir, out_dir,
workers, workers,
max_sweeps=5): max_sweeps=10,
"""Prepare the info file for waymo dataset. only_gt_database=False,
save_senor_data=False,
skip_cam_instances_infos=False):
"""Prepare waymo dataset. There are 3 steps as follows:
Step 1. Extract camera images and lidar point clouds from waymo raw
data in '*.tfreord' and save as kitti format.
Step 2. Generate waymo train/val/test infos and save as pickle file.
Step 3. Generate waymo ground truth database (point clouds within
each 3D bounding box) for data augmentation in training.
Steps 1 and 2 will be done in Waymo2KITTI, and step 3 will be done in
GTDatabaseCreater.
Args: Args:
root_path (str): Path of dataset root. root_path (str): Path of dataset root.
...@@ -180,44 +193,55 @@ def waymo_data_prep(root_path, ...@@ -180,44 +193,55 @@ def waymo_data_prep(root_path,
out_dir (str): Output directory of the generated info file. out_dir (str): Output directory of the generated info file.
workers (int): Number of threads to be used. workers (int): Number of threads to be used.
max_sweeps (int, optional): Number of input consecutive frames. max_sweeps (int, optional): Number of input consecutive frames.
Default: 5. Here we store pose information of these frames Default to 10. Here we store ego2global information of these
for later use. frames for later use.
only_gt_database (bool, optional): Whether to only generate ground
truth database. Default to False.
save_senor_data (bool, optional): Whether to skip saving
image and lidar. Default to False.
skip_cam_instances_infos (bool, optional): Whether to skip
gathering cam_instances infos in Step 2. Default to False.
""" """
from tools.dataset_converters import waymo_converter as waymo from tools.dataset_converters import waymo_converter as waymo
splits = [ if version == 'v1.4':
'training', 'validation', 'testing', 'testing_3d_camera_only_detection' splits = [
] 'training', 'validation', 'testing',
for i, split in enumerate(splits): 'testing_3d_camera_only_detection'
load_dir = osp.join(root_path, 'waymo_format', split) ]
if split == 'validation': elif version == 'v1.4-mini':
save_dir = osp.join(out_dir, 'kitti_format', 'training') splits = ['training', 'validation']
else: else:
save_dir = osp.join(out_dir, 'kitti_format', split) raise NotImplementedError(f'Unsupported Waymo version {version}!')
converter = waymo.Waymo2KITTI(
load_dir,
save_dir,
prefix=str(i),
workers=workers,
test_mode=(split
in ['testing', 'testing_3d_camera_only_detection']))
converter.convert()
from tools.dataset_converters.waymo_converter import \
create_ImageSets_img_ids
create_ImageSets_img_ids(osp.join(out_dir, 'kitti_format'), splits)
# Generate waymo infos
out_dir = osp.join(out_dir, 'kitti_format') out_dir = osp.join(out_dir, 'kitti_format')
kitti.create_waymo_info_file(
out_dir, info_prefix, max_sweeps=max_sweeps, workers=workers) if not only_gt_database:
info_train_path = osp.join(out_dir, f'{info_prefix}_infos_train.pkl') for i, split in enumerate(splits):
info_val_path = osp.join(out_dir, f'{info_prefix}_infos_val.pkl') load_dir = osp.join(root_path, 'waymo_format', split)
info_trainval_path = osp.join(out_dir, f'{info_prefix}_infos_trainval.pkl') if split == 'validation':
info_test_path = osp.join(out_dir, f'{info_prefix}_infos_test.pkl') save_dir = osp.join(out_dir, 'training')
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_train_path) else:
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_val_path) save_dir = osp.join(out_dir, split)
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_trainval_path) converter = waymo.Waymo2KITTI(
update_pkl_infos('waymo', out_dir=out_dir, pkl_path=info_test_path) load_dir,
save_dir,
prefix=str(i),
workers=workers,
test_mode=(split
in ['testing', 'testing_3d_camera_only_detection']),
info_prefix=info_prefix,
max_sweeps=max_sweeps,
split=split,
save_senor_data=save_senor_data,
save_cam_instances=not skip_cam_instances_infos)
converter.convert()
if split == 'validation':
converter.merge_trainval_infos()
from tools.dataset_converters.waymo_converter import \
create_ImageSets_img_ids
create_ImageSets_img_ids(out_dir, splits)
GTDatabaseCreater( GTDatabaseCreater(
'WaymoDataset', 'WaymoDataset',
out_dir, out_dir,
...@@ -227,6 +251,8 @@ def waymo_data_prep(root_path, ...@@ -227,6 +251,8 @@ def waymo_data_prep(root_path,
with_mask=False, with_mask=False,
num_worker=workers).create() num_worker=workers).create()
print_log('Successfully preparing Waymo Open Dataset')
def semantickitti_data_prep(info_prefix, out_dir): def semantickitti_data_prep(info_prefix, out_dir):
"""Prepare the info file for SemanticKITTI dataset. """Prepare the info file for SemanticKITTI dataset.
...@@ -274,12 +300,23 @@ parser.add_argument( ...@@ -274,12 +300,23 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
'--only-gt-database', '--only-gt-database',
action='store_true', action='store_true',
help='Whether to only generate ground truth database.') help='''Whether to only generate ground truth database.
Only used when dataset is NuScenes or Waymo!''')
parser.add_argument(
'--skip-cam_instances-infos',
action='store_true',
help='''Whether to skip gathering cam_instances infos.
Only used when dataset is Waymo!''')
parser.add_argument(
'--skip-saving-sensor-data',
action='store_true',
help='''Whether to skip saving image and lidar.
Only used when dataset is Waymo!''')
args = parser.parse_args() args = parser.parse_args()
if __name__ == '__main__': if __name__ == '__main__':
from mmdet3d.utils import register_all_modules from mmengine.registry import init_default_scope
register_all_modules() init_default_scope('mmdet3d')
if args.dataset == 'kitti': if args.dataset == 'kitti':
if args.only_gt_database: if args.only_gt_database:
...@@ -334,6 +371,17 @@ if __name__ == '__main__': ...@@ -334,6 +371,17 @@ if __name__ == '__main__':
dataset_name='NuScenesDataset', dataset_name='NuScenesDataset',
out_dir=args.out_dir, out_dir=args.out_dir,
max_sweeps=args.max_sweeps) max_sweeps=args.max_sweeps)
elif args.dataset == 'waymo':
waymo_data_prep(
root_path=args.root_path,
info_prefix=args.extra_tag,
version=args.version,
out_dir=args.out_dir,
workers=args.workers,
max_sweeps=args.max_sweeps,
only_gt_database=args.only_gt_database,
save_senor_data=not args.skip_saving_sensor_data,
skip_cam_instances_infos=args.skip_cam_instances_infos)
elif args.dataset == 'lyft': elif args.dataset == 'lyft':
train_version = f'{args.version}-train' train_version = f'{args.version}-train'
lyft_data_prep( lyft_data_prep(
...@@ -347,14 +395,6 @@ if __name__ == '__main__': ...@@ -347,14 +395,6 @@ if __name__ == '__main__':
info_prefix=args.extra_tag, info_prefix=args.extra_tag,
version=test_version, version=test_version,
max_sweeps=args.max_sweeps) max_sweeps=args.max_sweeps)
elif args.dataset == 'waymo':
waymo_data_prep(
root_path=args.root_path,
info_prefix=args.extra_tag,
version=args.version,
out_dir=args.out_dir,
workers=args.workers,
max_sweeps=args.max_sweeps)
elif args.dataset == 'scannet': elif args.dataset == 'scannet':
scannet_data_prep( scannet_data_prep(
root_path=args.root_path, root_path=args.root_path,
......
...@@ -6,10 +6,11 @@ export PYTHONPATH=`pwd`:$PYTHONPATH ...@@ -6,10 +6,11 @@ export PYTHONPATH=`pwd`:$PYTHONPATH
PARTITION=$1 PARTITION=$1
JOB_NAME=$2 JOB_NAME=$2
DATASET=$3 DATASET=$3
WORKERS=$4
GPUS=${GPUS:-1} GPUS=${GPUS:-1}
GPUS_PER_NODE=${GPUS_PER_NODE:-1} GPUS_PER_NODE=${GPUS_PER_NODE:-1}
SRUN_ARGS=${SRUN_ARGS:-""} SRUN_ARGS=${SRUN_ARGS:-""}
JOB_NAME=create_data PY_ARGS=${@:5}
srun -p ${PARTITION} \ srun -p ${PARTITION} \
--job-name=${JOB_NAME} \ --job-name=${JOB_NAME} \
...@@ -21,4 +22,6 @@ srun -p ${PARTITION} \ ...@@ -21,4 +22,6 @@ srun -p ${PARTITION} \
python -u tools/create_data.py ${DATASET} \ python -u tools/create_data.py ${DATASET} \
--root-path ./data/${DATASET} \ --root-path ./data/${DATASET} \
--out-dir ./data/${DATASET} \ --out-dir ./data/${DATASET} \
--extra-tag ${DATASET} --workers ${WORKERS} \
--extra-tag ${DATASET} \
${PY_ARGS}
...@@ -7,7 +7,7 @@ import mmengine ...@@ -7,7 +7,7 @@ import mmengine
import numpy as np import numpy as np
from mmcv.ops import roi_align from mmcv.ops import roi_align
from mmdet.evaluation import bbox_overlaps from mmdet.evaluation import bbox_overlaps
from mmengine import track_iter_progress from mmengine import print_log, track_iter_progress
from pycocotools import mask as maskUtils from pycocotools import mask as maskUtils
from pycocotools.coco import COCO from pycocotools.coco import COCO
...@@ -504,7 +504,9 @@ class GTDatabaseCreater: ...@@ -504,7 +504,9 @@ class GTDatabaseCreater:
return single_db_infos return single_db_infos
def create(self): def create(self):
print(f'Create GT Database of {self.dataset_class_name}') print_log(
f'Create GT Database of {self.dataset_class_name}',
logger='current')
dataset_cfg = dict( dataset_cfg = dict(
type=self.dataset_class_name, type=self.dataset_class_name,
data_root=self.data_path, data_root=self.data_path,
...@@ -610,12 +612,19 @@ class GTDatabaseCreater: ...@@ -610,12 +612,19 @@ class GTDatabaseCreater:
input_dict['box_mode_3d'] = self.dataset.box_mode_3d input_dict['box_mode_3d'] = self.dataset.box_mode_3d
return input_dict return input_dict
multi_db_infos = mmengine.track_parallel_progress( if self.num_worker == 0:
self.create_single, multi_db_infos = mmengine.track_progress(
((loop_dataset(i) self.create_single,
for i in range(len(self.dataset))), len(self.dataset)), ((loop_dataset(i)
self.num_worker) for i in range(len(self.dataset))), len(self.dataset)))
print('Make global unique group id') else:
multi_db_infos = mmengine.track_parallel_progress(
self.create_single,
((loop_dataset(i)
for i in range(len(self.dataset))), len(self.dataset)),
self.num_worker,
chunksize=1000)
print_log('Make global unique group id', logger='current')
group_counter_offset = 0 group_counter_offset = 0
all_db_infos = dict() all_db_infos = dict()
for single_db_infos in track_iter_progress(multi_db_infos): for single_db_infos in track_iter_progress(multi_db_infos):
...@@ -630,7 +639,8 @@ class GTDatabaseCreater: ...@@ -630,7 +639,8 @@ class GTDatabaseCreater:
group_counter_offset += (group_id + 1) group_counter_offset += (group_id + 1)
for k, v in all_db_infos.items(): for k, v in all_db_infos.items():
print(f'load {len(v)} {k} database infos') print_log(f'load {len(v)} {k} database infos', logger='current')
print_log(f'Saving GT database infos into {self.db_info_save_path}')
with open(self.db_info_save_path, 'wb') as f: with open(self.db_info_save_path, 'wb') as f:
pickle.dump(all_db_infos, f) pickle.dump(all_db_infos, f)
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment