Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
333536f6
Unverified
Commit
333536f6
authored
Apr 06, 2022
by
Wenwei Zhang
Committed by
GitHub
Apr 06, 2022
Browse files
Release v1.0.0rc1
parents
9c7270d0
f747daab
Changes
219
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
804 additions
and
40 deletions
+804
-40
docs/en/tutorials/index.rst
docs/en/tutorials/index.rst
+1
-0
docs/zh_cn/1_exist_data_model.md
docs/zh_cn/1_exist_data_model.md
+3
-2
docs/zh_cn/getting_started.md
docs/zh_cn/getting_started.md
+27
-8
docs/zh_cn/model_zoo.md
docs/zh_cn/model_zoo.md
+1
-1
docs/zh_cn/tutorials/backends_support.md
docs/zh_cn/tutorials/backends_support.md
+151
-0
docs/zh_cn/tutorials/customize_dataset.md
docs/zh_cn/tutorials/customize_dataset.md
+2
-2
docs/zh_cn/tutorials/index.rst
docs/zh_cn/tutorials/index.rst
+1
-0
mmdet3d/__init__.py
mmdet3d/__init__.py
+1
-1
mmdet3d/core/bbox/structures/base_box3d.py
mmdet3d/core/bbox/structures/base_box3d.py
+4
-5
mmdet3d/core/evaluation/__init__.py
mmdet3d/core/evaluation/__init__.py
+2
-1
mmdet3d/core/evaluation/instance_seg_eval.py
mmdet3d/core/evaluation/instance_seg_eval.py
+128
-0
mmdet3d/core/evaluation/scannet_utils/evaluate_semantic_instance.py
...re/evaluation/scannet_utils/evaluate_semantic_instance.py
+347
-0
mmdet3d/core/evaluation/scannet_utils/util_3d.py
mmdet3d/core/evaluation/scannet_utils/util_3d.py
+84
-0
mmdet3d/core/points/base_points.py
mmdet3d/core/points/base_points.py
+1
-1
mmdet3d/core/post_processing/box3d_nms.py
mmdet3d/core/post_processing/box3d_nms.py
+2
-2
mmdet3d/core/post_processing/merge_augs.py
mmdet3d/core/post_processing/merge_augs.py
+2
-1
mmdet3d/datasets/__init__.py
mmdet3d/datasets/__init__.py
+8
-7
mmdet3d/datasets/custom_3d.py
mmdet3d/datasets/custom_3d.py
+19
-4
mmdet3d/datasets/custom_3d_seg.py
mmdet3d/datasets/custom_3d_seg.py
+16
-3
mmdet3d/datasets/kitti_dataset.py
mmdet3d/datasets/kitti_dataset.py
+4
-2
No files found.
docs/en/tutorials/index.rst
View file @
333536f6
...
@@ -7,3 +7,4 @@
...
@@ -7,3 +7,4 @@
customize_models.md
customize_models.md
customize_runtime.md
customize_runtime.md
coord_sys_tutorial.md
coord_sys_tutorial.md
backends_support.md
docs/zh_cn/1_exist_data_model.md
View file @
333536f6
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
### 在标准数据集上测试已有模型
### 在标准数据集上测试已有模型
-
单显卡
-
单显卡
-
CPU
-
单节点多显卡
-
单节点多显卡
-
多节点
-
多节点
...
@@ -65,7 +66,7 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [-
...
@@ -65,7 +66,7 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [-
--eval
mAP
--eval
mAP
```
```
4.
使用8块显卡测试 SECOND,计算 mAP
4.
使用8块显卡
在 KITTI 数据集上
测试 SECOND,计算 mAP
```
shell
```
shell
./tools/slurm_test.sh
${
PARTITION
}
${
JOB_NAME
}
configs/second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py
\
./tools/slurm_test.sh
${
PARTITION
}
${
JOB_NAME
}
configs/second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py
\
...
@@ -83,7 +84,7 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [-
...
@@ -83,7 +84,7 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [-
生成的结果会保存在
`./pointpillars_nuscenes_results`
目录。
生成的结果会保存在
`./pointpillars_nuscenes_results`
目录。
6.
使用8块显卡在 KITTI 数据集上测试
PointPillars
,生成提交给官方评测服务器的
json
文件
6.
使用8块显卡在 KITTI 数据集上测试
SECOND
,生成提交给官方评测服务器的
txt
文件
```
shell
```
shell
./tools/slurm_test.sh
${
PARTITION
}
${
JOB_NAME
}
configs/second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py
\
./tools/slurm_test.sh
${
PARTITION
}
${
JOB_NAME
}
configs/second/hv_second_secfpn_6x8_80e_kitti-3d-3class.py
\
...
...
docs/zh_cn/getting_started.md
View file @
333536f6
# 依赖
# 依赖
-
Linux
or
macOS (Windows
is not currently officially supported
)
-
Linux
或者
macOS (
实验性支持
Windows)
-
Python 3.6+
-
Python 3.6+
-
PyTorch 1.3+
-
PyTorch 1.3+
-
CUDA 9.2+ (
If you build PyTorch from source, CUDA 9.0 is also compatible
)
-
CUDA 9.2+ (
如果你从源码编译 PyTorch, CUDA 9.0 也是兼容的。
)
-
GCC 5+
-
GCC 5+
-
[
MMCV
](
https://mmcv.readthedocs.io/en/latest/#installation
)
-
[
MMCV
](
https://mmcv.readthedocs.io/en/latest/#installation
)
| MMDetection3D
version
| MMDetection
version
| MMSegmentation
version
| MMCV
version
|
|
MMDetection3D
版本
|
MMDetection
版本
|
MMSegmentation
版本
|
MMCV
版本
|
|:-------------------:|:-------------------:|:-------------------:|:-------------------:|
|:-------------------:|:-------------------:|:-------------------:|:-------------------:|
| master | mmdet>=2.19.0,
<
=3.0.0|
mmseg
>
=0.20.0,
<
=1.0.0
|
mmcv-full
>
=1.3.17, <=1.5.0|
| master | mmdet>=2.19.0,
<
=3.0.0|
mmseg
>
=0.20.0,
<
=1.0.0
|
mmcv-full
>
=1.4.8, <=1.5.0|
| v1.0.0rc1 | mmdet>=2.19.0,
<
=3.0.0|
mmseg
>
=0.20.0,
<
=1.0.0
|
mmcv-full
>
=1.4.8, <=1.5.0|
| v1.0.0rc0 | mmdet>=2.19.0,
<
=3.0.0|
mmseg
>
=0.20.0,
<
=1.0.0
|
mmcv-full
>
=1.3.17, <=1.5.0|
| v1.0.0rc0 | mmdet>=2.19.0,
<
=3.0.0|
mmseg
>
=0.20.0,
<
=1.0.0
|
mmcv-full
>
=1.3.17, <=1.5.0|
| 0.18.1 | mmdet>=2.19.0,
<
=3.0.0|
mmseg
>
=0.20.0,
<
=1.0.0
|
mmcv-full
>
=1.3.17, <=1.5.0|
| 0.18.1 | mmdet>=2.19.0,
<
=3.0.0|
mmseg
>
=0.20.0,
<
=1.0.0
|
mmcv-full
>
=1.3.17, <=1.5.0|
| 0.18.0 | mmdet>=2.19.0,
<
=3.0.0|
mmseg
>
=0.20.0,
<
=1.0.0
|
mmcv-full
>
=1.3.17, <=1.5.0|
| 0.18.0 | mmdet>=2.19.0,
<
=3.0.0|
mmseg
>
=0.20.0,
<
=1.0.0
|
mmcv-full
>
=1.3.17, <=1.5.0|
...
@@ -34,6 +35,24 @@
...
@@ -34,6 +35,24 @@
## MMdetection3D 安装流程
## MMdetection3D 安装流程
### 快速安装脚本
如果你已经成功安装 CUDA 11.0,那么你可以使用这个快速安装命令进行 MMDetection3D 的安装。 否则,则参考下一小节的详细安装流程。
```
shell
conda create
-n
open-mmlab
python
=
3.7
pytorch
=
1.9
cudatoolkit
=
11.0 torchvision
-c
pytorch
-y
conda activate open-mmlab
pip3
install
openmim
mim
install
mmcv-full
mim
install
mmdet
mim
install
mmsegmentation
git clone https://github.com/open-mmlab/mmdetection3d.git
cd
mmdetection3d
pip3
install
-e
.
```
### 详细安装流程
**a. 使用 conda 新建虚拟环境,并进入该虚拟环境。**
**a. 使用 conda 新建虚拟环境,并进入该虚拟环境。**
```
shell
```
shell
...
@@ -102,7 +121,7 @@ pip install mmcv-full
...
@@ -102,7 +121,7 @@ pip install mmcv-full
**d. 安装 [MMDetection](https://github.com/open-mmlab/mmdetection).**
**d. 安装 [MMDetection](https://github.com/open-mmlab/mmdetection).**
```
shell
```
shell
pip
install
mmdet
==
2.14.0
pip
install
mmdet
```
```
同时,如果你想修改这部分的代码,也可以通过以下命令从源码编译 MMDetection:
同时,如果你想修改这部分的代码,也可以通过以下命令从源码编译 MMDetection:
...
@@ -110,7 +129,7 @@ pip install mmdet==2.14.0
...
@@ -110,7 +129,7 @@ pip install mmdet==2.14.0
```
shell
```
shell
git clone https://github.com/open-mmlab/mmdetection.git
git clone https://github.com/open-mmlab/mmdetection.git
cd
mmdetection
cd
mmdetection
git checkout v2.1
4
.0
# 转到 v2.1
4
.0 分支
git checkout v2.1
9
.0
# 转到 v2.1
9
.0 分支
pip
install
-r
requirements/build.txt
pip
install
-r
requirements/build.txt
pip
install
-v
-e
.
# or "python setup.py develop"
pip
install
-v
-e
.
# or "python setup.py develop"
```
```
...
@@ -118,14 +137,14 @@ pip install -v -e . # or "python setup.py develop"
...
@@ -118,14 +137,14 @@ pip install -v -e . # or "python setup.py develop"
**e. 安装 [MMSegmentation](https://github.com/open-mmlab/mmsegmentation).**
**e. 安装 [MMSegmentation](https://github.com/open-mmlab/mmsegmentation).**
```
shell
```
shell
pip
install
mmsegmentation
==
0.14.1
pip
install
mmsegmentation
```
```
同时,如果你想修改这部分的代码,也可以通过以下命令从源码编译 MMSegmentation:
同时,如果你想修改这部分的代码,也可以通过以下命令从源码编译 MMSegmentation:
```
shell
```
shell
git clone https://github.com/open-mmlab/mmsegmentation.git
git clone https://github.com/open-mmlab/mmsegmentation.git
cd
mmsegmentation
cd
mmsegmentation
git checkout v0.
14.1
# switch to v0.
14.1
branch
git checkout v0.
20.0
# switch to v0.
20.0
branch
pip
install
-e
.
# or "python setup.py develop"
pip
install
-e
.
# or "python setup.py develop"
```
```
...
...
docs/zh_cn/model_zoo.md
View file @
333536f6
...
@@ -102,4 +102,4 @@
...
@@ -102,4 +102,4 @@
### Mixed Precision (FP16) Training
### Mixed Precision (FP16) Training
细节请参考 [Mixed Precision (FP16) Training
]
在 PointPillars 训练的样例
(https://github.com/open-mmlab/mmdetection3d/tree/v1.0.0.dev0/configs/pointpillars/hv_pointpillars_fpn_sbn-all_fp16_2x8_2x_nus-3d.py).
细节请参考
[
Mixed Precision (FP16) Training 在 PointPillars 训练的样例
]
(
https://github.com/open-mmlab/mmdetection3d/tree/v1.0.0.dev0/configs/pointpillars/hv_pointpillars_fpn_sbn-all_fp16_2x8_2x_nus-3d.py
)
.
docs/zh_cn/tutorials/backends_support.md
0 → 100644
View file @
333536f6
# Tutorial 7: 后端支持
我们支持不同的文件客户端后端:磁盘、Ceph 和 LMDB 等。下面是修改配置使之从 Ceph 加载和保存数据的示例。
## 从 Ceph 读取数据和标注文件
我们支持从 Ceph 加载数据和生成的标注信息文件(pkl 和 json):
```
python
# set file client backends as Ceph
file_client_args
=
dict
(
backend
=
'petrel'
,
path_mapping
=
dict
({
'./data/nuscenes/'
:
's3://openmmlab/datasets/detection3d/nuscenes/'
,
# replace the path with your data path on Ceph
'data/nuscenes/'
:
's3://openmmlab/datasets/detection3d/nuscenes/'
# replace the path with your data path on Ceph
}))
db_sampler
=
dict
(
data_root
=
data_root
,
info_path
=
data_root
+
'kitti_dbinfos_train.pkl'
,
rate
=
1.0
,
prepare
=
dict
(
filter_by_difficulty
=
[
-
1
],
filter_by_min_points
=
dict
(
Car
=
5
)),
sample_groups
=
dict
(
Car
=
15
),
classes
=
class_names
,
# set file client for points loader to load training data
points_loader
=
dict
(
type
=
'LoadPointsFromFile'
,
coord_type
=
'LIDAR'
,
load_dim
=
4
,
use_dim
=
4
,
file_client_args
=
file_client_args
),
# set file client for data base sampler to load db info file
file_client_args
=
file_client_args
)
train_pipeline
=
[
# set file client for loading training data
dict
(
type
=
'LoadPointsFromFile'
,
coord_type
=
'LIDAR'
,
load_dim
=
4
,
use_dim
=
4
,
file_client_args
=
file_client_args
),
# set file client for loading training data annotations
dict
(
type
=
'LoadAnnotations3D'
,
with_bbox_3d
=
True
,
with_label_3d
=
True
,
file_client_args
=
file_client_args
),
dict
(
type
=
'ObjectSample'
,
db_sampler
=
db_sampler
),
dict
(
type
=
'ObjectNoise'
,
num_try
=
100
,
translation_std
=
[
0.25
,
0.25
,
0.25
],
global_rot_range
=
[
0.0
,
0.0
],
rot_range
=
[
-
0.15707963267
,
0.15707963267
]),
dict
(
type
=
'RandomFlip3D'
,
flip_ratio_bev_horizontal
=
0.5
),
dict
(
type
=
'GlobalRotScaleTrans'
,
rot_range
=
[
-
0.78539816
,
0.78539816
],
scale_ratio_range
=
[
0.95
,
1.05
]),
dict
(
type
=
'PointsRangeFilter'
,
point_cloud_range
=
point_cloud_range
),
dict
(
type
=
'ObjectRangeFilter'
,
point_cloud_range
=
point_cloud_range
),
dict
(
type
=
'PointShuffle'
),
dict
(
type
=
'DefaultFormatBundle3D'
,
class_names
=
class_names
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'gt_bboxes_3d'
,
'gt_labels_3d'
])
]
test_pipeline
=
[
# set file client for loading validation/testing data
dict
(
type
=
'LoadPointsFromFile'
,
coord_type
=
'LIDAR'
,
load_dim
=
4
,
use_dim
=
4
,
file_client_args
=
file_client_args
),
dict
(
type
=
'MultiScaleFlipAug3D'
,
img_scale
=
(
1333
,
800
),
pts_scale_ratio
=
1
,
flip
=
False
,
transforms
=
[
dict
(
type
=
'GlobalRotScaleTrans'
,
rot_range
=
[
0
,
0
],
scale_ratio_range
=
[
1.
,
1.
],
translation_std
=
[
0
,
0
,
0
]),
dict
(
type
=
'RandomFlip3D'
),
dict
(
type
=
'PointsRangeFilter'
,
point_cloud_range
=
point_cloud_range
),
dict
(
type
=
'DefaultFormatBundle3D'
,
class_names
=
class_names
,
with_label
=
False
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
])
])
]
data
=
dict
(
# set file client for loading training info files (.pkl)
train
=
dict
(
type
=
'RepeatDataset'
,
times
=
2
,
dataset
=
dict
(
pipeline
=
train_pipeline
,
classes
=
class_names
,
file_client_args
=
file_client_args
)),
# set file client for loading validation info files (.pkl)
val
=
dict
(
pipeline
=
test_pipeline
,
classes
=
class_names
,
file_client_args
=
file_client_args
),
# set file client for loading testing info files (.pkl)
test
=
dict
(
pipeline
=
test_pipeline
,
classes
=
class_names
,
file_client_args
=
file_client_args
))
```
## 从 Ceph 读取预训练模型
```
python
model
=
dict
(
pts_backbone
=
dict
(
_delete_
=
True
,
type
=
'NoStemRegNet'
,
arch
=
'regnetx_1.6gf'
,
init_cfg
=
dict
(
type
=
'Pretrained'
,
checkpoint
=
's3://openmmlab/checkpoints/mmdetection3d/regnetx_1.6gf'
),
# replace the path with your pretrained model path on Ceph
...
```
## 从 Ceph 读取模型权重文件
```
python
# replace the path with your checkpoint path on Ceph
load_from
=
's3://openmmlab/checkpoints/mmdetection3d/v0.1.0_models/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-car/hv_pointpillars_secfpn_6x8_160e_kitti-3d-car_20200620_230614-77663cd6.pth.pth'
resume_from
=
None
workflow
=
[(
'train'
,
1
)]
```
## 保存模型权重文件至 Ceph
```
python
# checkpoint saving
# replace the path with your checkpoint saving path on Ceph
checkpoint_config
=
dict
(
interval
=
1
,
max_keep_ckpts
=
2
,
out_dir
=
's3://openmmlab/mmdetection3d'
)
```
## EvalHook 保存最优模型权重文件至 Ceph
```
python
# replace the path with your checkpoint saving path on Ceph
evaluation
=
dict
(
interval
=
1
,
save_best
=
'bbox'
,
out_dir
=
's3://openmmlab/mmdetection3d'
)
```
## 训练日志保存至 Ceph
训练后的训练日志会备份到指定的 Ceph 路径。
```
python
log_config
=
dict
(
interval
=
50
,
hooks
=
[
dict
(
type
=
'TextLoggerHook'
,
out_dir
=
's3://openmmlab/mmdetection3d'
),
])
```
您还可以通过设置
`keep_local = False`
备份到指定的 Ceph 路径后删除本地训练日志。
```
python
log_config
=
dict
(
interval
=
50
,
hooks
=
[
dict
(
type
=
'TextLoggerHook'
,
out_dir
=
's3://openmmlab/mmdetection3d'', keep_local=False),
])
```
docs/zh_cn/tutorials/customize_dataset.md
View file @
333536f6
...
@@ -130,10 +130,10 @@ class MyDataset(Custom3DDataset):
...
@@ -130,10 +130,10 @@ class MyDataset(Custom3DDataset):
if
info
[
'annos'
][
'gt_num'
]
!=
0
:
if
info
[
'annos'
][
'gt_num'
]
!=
0
:
gt_bboxes_3d
=
info
[
'annos'
][
'gt_boxes_upright_depth'
].
astype
(
gt_bboxes_3d
=
info
[
'annos'
][
'gt_boxes_upright_depth'
].
astype
(
np
.
float32
)
# k, 6
np
.
float32
)
# k, 6
gt_labels_3d
=
info
[
'annos'
][
'class'
].
astype
(
np
.
long
)
gt_labels_3d
=
info
[
'annos'
][
'class'
].
astype
(
np
.
int64
)
else
:
else
:
gt_bboxes_3d
=
np
.
zeros
((
0
,
6
),
dtype
=
np
.
float32
)
gt_bboxes_3d
=
np
.
zeros
((
0
,
6
),
dtype
=
np
.
float32
)
gt_labels_3d
=
np
.
zeros
((
0
,
),
dtype
=
np
.
long
)
gt_labels_3d
=
np
.
zeros
((
0
,
),
dtype
=
np
.
int64
)
# 转换为目标标注框的结构
# 转换为目标标注框的结构
gt_bboxes_3d
=
DepthInstance3DBoxes
(
gt_bboxes_3d
=
DepthInstance3DBoxes
(
...
...
docs/zh_cn/tutorials/index.rst
View file @
333536f6
...
@@ -7,3 +7,4 @@
...
@@ -7,3 +7,4 @@
customize_models.md
customize_models.md
customize_runtime.md
customize_runtime.md
coord_sys_tutorial.md
coord_sys_tutorial.md
backends_support.md
mmdet3d/__init__.py
View file @
333536f6
...
@@ -18,7 +18,7 @@ def digit_version(version_str):
...
@@ -18,7 +18,7 @@ def digit_version(version_str):
return
digit_version
return
digit_version
mmcv_minimum_version
=
'1.
3.17
'
mmcv_minimum_version
=
'1.
4.8
'
mmcv_maximum_version
=
'1.5.0'
mmcv_maximum_version
=
'1.5.0'
mmcv_version
=
digit_version
(
mmcv
.
__version__
)
mmcv_version
=
digit_version
(
mmcv
.
__version__
)
...
...
mmdet3d/core/bbox/structures/base_box3d.py
View file @
333536f6
...
@@ -4,9 +4,9 @@ from abc import abstractmethod
...
@@ -4,9 +4,9 @@ from abc import abstractmethod
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
mmcv._ext
import
iou3d_boxes_overlap_bev_forward
as
boxes_overlap_bev_gpu
from
mmcv.ops
import
points_in_boxes_all
,
points_in_boxes_part
from
mmdet3d.ops
import
points_in_boxes_all
,
points_in_boxes_part
from
mmdet3d.ops.iou3d
import
iou3d_cuda
from
.utils
import
limit_period
,
xywhr2xyxyr
from
.utils
import
limit_period
,
xywhr2xyxyr
...
@@ -471,9 +471,8 @@ class BaseInstance3DBoxes(object):
...
@@ -471,9 +471,8 @@ class BaseInstance3DBoxes(object):
# bev overlap
# bev overlap
overlaps_bev
=
boxes1_bev
.
new_zeros
(
overlaps_bev
=
boxes1_bev
.
new_zeros
(
(
boxes1_bev
.
shape
[
0
],
boxes2_bev
.
shape
[
0
])).
cuda
()
# (N, M)
(
boxes1_bev
.
shape
[
0
],
boxes2_bev
.
shape
[
0
])).
cuda
()
# (N, M)
iou3d_cuda
.
boxes_overlap_bev_gpu
(
boxes1_bev
.
contiguous
().
cuda
(),
boxes_overlap_bev_gpu
(
boxes1_bev
.
contiguous
().
cuda
(),
boxes2_bev
.
contiguous
().
cuda
(),
boxes2_bev
.
contiguous
().
cuda
(),
overlaps_bev
)
overlaps_bev
)
# 3d overlaps
# 3d overlaps
overlaps_3d
=
overlaps_bev
.
to
(
boxes1
.
device
)
*
overlaps_h
overlaps_3d
=
overlaps_bev
.
to
(
boxes1
.
device
)
*
overlaps_h
...
...
mmdet3d/core/evaluation/__init__.py
View file @
333536f6
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
.indoor_eval
import
indoor_eval
from
.indoor_eval
import
indoor_eval
from
.instance_seg_eval
import
instance_seg_eval
from
.kitti_utils
import
kitti_eval
,
kitti_eval_coco_style
from
.kitti_utils
import
kitti_eval
,
kitti_eval_coco_style
from
.lyft_eval
import
lyft_eval
from
.lyft_eval
import
lyft_eval
from
.seg_eval
import
seg_eval
from
.seg_eval
import
seg_eval
__all__
=
[
__all__
=
[
'kitti_eval_coco_style'
,
'kitti_eval'
,
'indoor_eval'
,
'lyft_eval'
,
'kitti_eval_coco_style'
,
'kitti_eval'
,
'indoor_eval'
,
'lyft_eval'
,
'seg_eval'
'seg_eval'
,
'instance_seg_eval'
]
]
mmdet3d/core/evaluation/instance_seg_eval.py
0 → 100644
View file @
333536f6
# Copyright (c) OpenMMLab. All rights reserved.
import
numpy
as
np
from
mmcv.utils
import
print_log
from
terminaltables
import
AsciiTable
from
.scannet_utils.evaluate_semantic_instance
import
scannet_eval
def
aggregate_predictions
(
masks
,
labels
,
scores
,
valid_class_ids
):
"""Maps predictions to ScanNet evaluator format.
Args:
masks (list[torch.Tensor]): Per scene predicted instance masks.
labels (list[torch.Tensor]): Per scene predicted instance labels.
scores (list[torch.Tensor]): Per scene predicted instance scores.
valid_class_ids (tuple[int]): Ids of valid categories.
Returns:
list[dict]: Per scene aggregated predictions.
"""
infos
=
[]
for
id
,
(
mask
,
label
,
score
)
in
enumerate
(
zip
(
masks
,
labels
,
scores
)):
mask
=
mask
.
clone
().
numpy
()
label
=
label
.
clone
().
numpy
()
score
=
score
.
clone
().
numpy
()
info
=
dict
()
n_instances
=
mask
.
max
()
+
1
for
i
in
range
(
n_instances
):
# match pred_instance['filename'] from assign_instances_for_scan
file_name
=
f
'
{
id
}
_
{
i
}
'
info
[
file_name
]
=
dict
()
info
[
file_name
][
'mask'
]
=
(
mask
==
i
).
astype
(
np
.
int
)
info
[
file_name
][
'label_id'
]
=
valid_class_ids
[
label
[
i
]]
info
[
file_name
][
'conf'
]
=
score
[
i
]
infos
.
append
(
info
)
return
infos
def
rename_gt
(
gt_semantic_masks
,
gt_instance_masks
,
valid_class_ids
):
"""Maps gt instance and semantic masks to instance masks for ScanNet
evaluator.
Args:
gt_semantic_masks (list[torch.Tensor]): Per scene gt semantic masks.
gt_instance_masks (list[torch.Tensor]): Per scene gt instance masks.
valid_class_ids (tuple[int]): Ids of valid categories.
Returns:
list[np.array]: Per scene instance masks.
"""
renamed_instance_masks
=
[]
for
semantic_mask
,
instance_mask
in
zip
(
gt_semantic_masks
,
gt_instance_masks
):
semantic_mask
=
semantic_mask
.
clone
().
numpy
()
instance_mask
=
instance_mask
.
clone
().
numpy
()
unique
=
np
.
unique
(
instance_mask
)
assert
len
(
unique
)
<
1000
for
i
in
unique
:
semantic_instance
=
semantic_mask
[
instance_mask
==
i
]
semantic_unique
=
np
.
unique
(
semantic_instance
)
assert
len
(
semantic_unique
)
==
1
if
semantic_unique
[
0
]
<
len
(
valid_class_ids
):
instance_mask
[
instance_mask
==
i
]
=
1000
*
valid_class_ids
[
semantic_unique
[
0
]]
+
i
renamed_instance_masks
.
append
(
instance_mask
)
return
renamed_instance_masks
def
instance_seg_eval
(
gt_semantic_masks
,
gt_instance_masks
,
pred_instance_masks
,
pred_instance_labels
,
pred_instance_scores
,
valid_class_ids
,
class_labels
,
options
=
None
,
logger
=
None
):
"""Instance Segmentation Evaluation.
Evaluate the result of the instance segmentation.
Args:
gt_semantic_masks (list[torch.Tensor]): Ground truth semantic masks.
gt_instance_masks (list[torch.Tensor]): Ground truth instance masks.
pred_instance_masks (list[torch.Tensor]): Predicted instance masks.
pred_instance_labels (list[torch.Tensor]): Predicted instance labels.
pred_instance_scores (list[torch.Tensor]): Predicted instance labels.
valid_class_ids (tuple[int]): Ids of valid categories.
class_labels (tuple[str]): Names of valid categories.
options (dict, optional): Additional options. Keys may contain:
`overlaps`, `min_region_sizes`, `distance_threshes`,
`distance_confs`. Default: None.
logger (logging.Logger | str, optional): The way to print the mAP
summary. See `mmdet.utils.print_log()` for details. Default: None.
Returns:
dict[str, float]: Dict of results.
"""
assert
len
(
valid_class_ids
)
==
len
(
class_labels
)
id_to_label
=
{
valid_class_ids
[
i
]:
class_labels
[
i
]
for
i
in
range
(
len
(
valid_class_ids
))
}
preds
=
aggregate_predictions
(
masks
=
pred_instance_masks
,
labels
=
pred_instance_labels
,
scores
=
pred_instance_scores
,
valid_class_ids
=
valid_class_ids
)
gts
=
rename_gt
(
gt_semantic_masks
,
gt_instance_masks
,
valid_class_ids
)
metrics
=
scannet_eval
(
preds
=
preds
,
gts
=
gts
,
options
=
options
,
valid_class_ids
=
valid_class_ids
,
class_labels
=
class_labels
,
id_to_label
=
id_to_label
)
header
=
[
'classes'
,
'AP_0.25'
,
'AP_0.50'
,
'AP'
]
rows
=
[]
for
label
,
data
in
metrics
[
'classes'
].
items
():
aps
=
[
data
[
'ap25%'
],
data
[
'ap50%'
],
data
[
'ap'
]]
rows
.
append
([
label
]
+
[
f
'
{
ap
:.
4
f
}
'
for
ap
in
aps
])
aps
=
metrics
[
'all_ap_25%'
],
metrics
[
'all_ap_50%'
],
metrics
[
'all_ap'
]
footer
=
[
'Overall'
]
+
[
f
'
{
ap
:.
4
f
}
'
for
ap
in
aps
]
table
=
AsciiTable
([
header
]
+
rows
+
[
footer
])
table
.
inner_footing_row_border
=
True
print_log
(
'
\n
'
+
table
.
table
,
logger
=
logger
)
return
metrics
mmdet3d/core/evaluation/scannet_utils/evaluate_semantic_instance.py
0 → 100644
View file @
333536f6
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/3d_evaluation/evaluate_semantic_instance.py # noqa
from
copy
import
deepcopy
import
numpy
as
np
from
.
import
util_3d
def
evaluate_matches
(
matches
,
class_labels
,
options
):
"""Evaluate instance segmentation from matched gt and predicted instances
for all scenes.
Args:
matches (dict): Contains gt2pred and pred2gt infos for every scene.
class_labels (tuple[str]): Class names.
options (dict): ScanNet evaluator options. See get_options.
Returns:
np.array: Average precision scores for all thresholds and categories.
"""
overlaps
=
options
[
'overlaps'
]
min_region_sizes
=
[
options
[
'min_region_sizes'
][
0
]]
dist_threshes
=
[
options
[
'distance_threshes'
][
0
]]
dist_confs
=
[
options
[
'distance_confs'
][
0
]]
# results: class x overlap
ap
=
np
.
zeros
((
len
(
dist_threshes
),
len
(
class_labels
),
len
(
overlaps
)),
np
.
float
)
for
di
,
(
min_region_size
,
distance_thresh
,
distance_conf
)
in
enumerate
(
zip
(
min_region_sizes
,
dist_threshes
,
dist_confs
)):
for
oi
,
overlap_th
in
enumerate
(
overlaps
):
pred_visited
=
{}
for
m
in
matches
:
for
label_name
in
class_labels
:
for
p
in
matches
[
m
][
'pred'
][
label_name
]:
if
'filename'
in
p
:
pred_visited
[
p
[
'filename'
]]
=
False
for
li
,
label_name
in
enumerate
(
class_labels
):
y_true
=
np
.
empty
(
0
)
y_score
=
np
.
empty
(
0
)
hard_false_negatives
=
0
has_gt
=
False
has_pred
=
False
for
m
in
matches
:
pred_instances
=
matches
[
m
][
'pred'
][
label_name
]
gt_instances
=
matches
[
m
][
'gt'
][
label_name
]
# filter groups in ground truth
gt_instances
=
[
gt
for
gt
in
gt_instances
if
gt
[
'instance_id'
]
>=
1000
and
gt
[
'vert_count'
]
>=
min_region_size
and
gt
[
'med_dist'
]
<=
distance_thresh
and
gt
[
'dist_conf'
]
>=
distance_conf
]
if
gt_instances
:
has_gt
=
True
if
pred_instances
:
has_pred
=
True
cur_true
=
np
.
ones
(
len
(
gt_instances
))
cur_score
=
np
.
ones
(
len
(
gt_instances
))
*
(
-
float
(
'inf'
))
cur_match
=
np
.
zeros
(
len
(
gt_instances
),
dtype
=
np
.
bool
)
# collect matches
for
(
gti
,
gt
)
in
enumerate
(
gt_instances
):
found_match
=
False
for
pred
in
gt
[
'matched_pred'
]:
# greedy assignments
if
pred_visited
[
pred
[
'filename'
]]:
continue
overlap
=
float
(
pred
[
'intersection'
])
/
(
gt
[
'vert_count'
]
+
pred
[
'vert_count'
]
-
pred
[
'intersection'
])
if
overlap
>
overlap_th
:
confidence
=
pred
[
'confidence'
]
# if already have a prediction for this gt,
# the prediction with the lower score is automatically a false positive # noqa
if
cur_match
[
gti
]:
max_score
=
max
(
cur_score
[
gti
],
confidence
)
min_score
=
min
(
cur_score
[
gti
],
confidence
)
cur_score
[
gti
]
=
max_score
# append false positive
cur_true
=
np
.
append
(
cur_true
,
0
)
cur_score
=
np
.
append
(
cur_score
,
min_score
)
cur_match
=
np
.
append
(
cur_match
,
True
)
# otherwise set score
else
:
found_match
=
True
cur_match
[
gti
]
=
True
cur_score
[
gti
]
=
confidence
pred_visited
[
pred
[
'filename'
]]
=
True
if
not
found_match
:
hard_false_negatives
+=
1
# remove non-matched ground truth instances
cur_true
=
cur_true
[
cur_match
]
cur_score
=
cur_score
[
cur_match
]
# collect non-matched predictions as false positive
for
pred
in
pred_instances
:
found_gt
=
False
for
gt
in
pred
[
'matched_gt'
]:
overlap
=
float
(
gt
[
'intersection'
])
/
(
gt
[
'vert_count'
]
+
pred
[
'vert_count'
]
-
gt
[
'intersection'
])
if
overlap
>
overlap_th
:
found_gt
=
True
break
if
not
found_gt
:
num_ignore
=
pred
[
'void_intersection'
]
for
gt
in
pred
[
'matched_gt'
]:
# group?
if
gt
[
'instance_id'
]
<
1000
:
num_ignore
+=
gt
[
'intersection'
]
# small ground truth instances
if
gt
[
'vert_count'
]
<
min_region_size
or
gt
[
'med_dist'
]
>
distance_thresh
or
gt
[
'dist_conf'
]
<
distance_conf
:
num_ignore
+=
gt
[
'intersection'
]
proportion_ignore
=
float
(
num_ignore
)
/
pred
[
'vert_count'
]
# if not ignored append false positive
if
proportion_ignore
<=
overlap_th
:
cur_true
=
np
.
append
(
cur_true
,
0
)
confidence
=
pred
[
'confidence'
]
cur_score
=
np
.
append
(
cur_score
,
confidence
)
# append to overall results
y_true
=
np
.
append
(
y_true
,
cur_true
)
y_score
=
np
.
append
(
y_score
,
cur_score
)
# compute average precision
if
has_gt
and
has_pred
:
# compute precision recall curve first
# sorting and cumsum
score_arg_sort
=
np
.
argsort
(
y_score
)
y_score_sorted
=
y_score
[
score_arg_sort
]
y_true_sorted
=
y_true
[
score_arg_sort
]
y_true_sorted_cumsum
=
np
.
cumsum
(
y_true_sorted
)
# unique thresholds
(
thresholds
,
unique_indices
)
=
np
.
unique
(
y_score_sorted
,
return_index
=
True
)
num_prec_recall
=
len
(
unique_indices
)
+
1
# prepare precision recall
num_examples
=
len
(
y_score_sorted
)
# follow https://github.com/ScanNet/ScanNet/pull/26 ? # noqa
num_true_examples
=
y_true_sorted_cumsum
[
-
1
]
if
len
(
y_true_sorted_cumsum
)
>
0
else
0
precision
=
np
.
zeros
(
num_prec_recall
)
recall
=
np
.
zeros
(
num_prec_recall
)
# deal with the first point
y_true_sorted_cumsum
=
np
.
append
(
y_true_sorted_cumsum
,
0
)
# deal with remaining
for
idx_res
,
idx_scores
in
enumerate
(
unique_indices
):
cumsum
=
y_true_sorted_cumsum
[
idx_scores
-
1
]
tp
=
num_true_examples
-
cumsum
fp
=
num_examples
-
idx_scores
-
tp
fn
=
cumsum
+
hard_false_negatives
p
=
float
(
tp
)
/
(
tp
+
fp
)
r
=
float
(
tp
)
/
(
tp
+
fn
)
precision
[
idx_res
]
=
p
recall
[
idx_res
]
=
r
# first point in curve is artificial
precision
[
-
1
]
=
1.
recall
[
-
1
]
=
0.
# compute average of precision-recall curve
recall_for_conv
=
np
.
copy
(
recall
)
recall_for_conv
=
np
.
append
(
recall_for_conv
[
0
],
recall_for_conv
)
recall_for_conv
=
np
.
append
(
recall_for_conv
,
0.
)
stepWidths
=
np
.
convolve
(
recall_for_conv
,
[
-
0.5
,
0
,
0.5
],
'valid'
)
# integrate is now simply a dot product
ap_current
=
np
.
dot
(
precision
,
stepWidths
)
elif
has_gt
:
ap_current
=
0.0
else
:
ap_current
=
float
(
'nan'
)
ap
[
di
,
li
,
oi
]
=
ap_current
return
ap
def
compute_averages
(
aps
,
options
,
class_labels
):
"""Averages AP scores for all categories.
Args:
aps (np.array): AP scores for all thresholds and categories.
options (dict): ScanNet evaluator options. See get_options.
class_labels (tuple[str]): Class names.
Returns:
dict: Overall and per-category AP scores.
"""
d_inf
=
0
o50
=
np
.
where
(
np
.
isclose
(
options
[
'overlaps'
],
0.5
))
o25
=
np
.
where
(
np
.
isclose
(
options
[
'overlaps'
],
0.25
))
o_all_but25
=
np
.
where
(
np
.
logical_not
(
np
.
isclose
(
options
[
'overlaps'
],
0.25
)))
avg_dict
=
{}
avg_dict
[
'all_ap'
]
=
np
.
nanmean
(
aps
[
d_inf
,
:,
o_all_but25
])
avg_dict
[
'all_ap_50%'
]
=
np
.
nanmean
(
aps
[
d_inf
,
:,
o50
])
avg_dict
[
'all_ap_25%'
]
=
np
.
nanmean
(
aps
[
d_inf
,
:,
o25
])
avg_dict
[
'classes'
]
=
{}
for
(
li
,
label_name
)
in
enumerate
(
class_labels
):
avg_dict
[
'classes'
][
label_name
]
=
{}
avg_dict
[
'classes'
][
label_name
][
'ap'
]
=
np
.
average
(
aps
[
d_inf
,
li
,
o_all_but25
])
avg_dict
[
'classes'
][
label_name
][
'ap50%'
]
=
np
.
average
(
aps
[
d_inf
,
li
,
o50
])
avg_dict
[
'classes'
][
label_name
][
'ap25%'
]
=
np
.
average
(
aps
[
d_inf
,
li
,
o25
])
return
avg_dict
def
assign_instances_for_scan
(
pred_info
,
gt_ids
,
options
,
valid_class_ids
,
class_labels
,
id_to_label
):
"""Assign gt and predicted instances for a single scene.
Args:
pred_info (dict): Predicted masks, labels and scores.
gt_ids (np.array): Ground truth instance masks.
options (dict): ScanNet evaluator options. See get_options.
valid_class_ids (tuple[int]): Ids of valid categories.
class_labels (tuple[str]): Class names.
id_to_label (dict[int, str]): Mapping of valid class id to class label.
Returns:
dict: Per class assigned gt to predicted instances.
dict: Per class assigned predicted to gt instances.
"""
# get gt instances
gt_instances
=
util_3d
.
get_instances
(
gt_ids
,
valid_class_ids
,
class_labels
,
id_to_label
)
# associate
gt2pred
=
deepcopy
(
gt_instances
)
for
label
in
gt2pred
:
for
gt
in
gt2pred
[
label
]:
gt
[
'matched_pred'
]
=
[]
pred2gt
=
{}
for
label
in
class_labels
:
pred2gt
[
label
]
=
[]
num_pred_instances
=
0
# mask of void labels in the ground truth
bool_void
=
np
.
logical_not
(
np
.
in1d
(
gt_ids
//
1000
,
valid_class_ids
))
# go through all prediction masks
for
pred_mask_file
in
pred_info
:
label_id
=
int
(
pred_info
[
pred_mask_file
][
'label_id'
])
conf
=
pred_info
[
pred_mask_file
][
'conf'
]
if
not
label_id
in
id_to_label
:
# noqa E713
continue
label_name
=
id_to_label
[
label_id
]
# read the mask
pred_mask
=
pred_info
[
pred_mask_file
][
'mask'
]
if
len
(
pred_mask
)
!=
len
(
gt_ids
):
raise
ValueError
(
'len(pred_mask) != len(gt_ids)'
)
# convert to binary
pred_mask
=
np
.
not_equal
(
pred_mask
,
0
)
num
=
np
.
count_nonzero
(
pred_mask
)
if
num
<
options
[
'min_region_sizes'
][
0
]:
continue
# skip if empty
pred_instance
=
{}
pred_instance
[
'filename'
]
=
pred_mask_file
pred_instance
[
'pred_id'
]
=
num_pred_instances
pred_instance
[
'label_id'
]
=
label_id
pred_instance
[
'vert_count'
]
=
num
pred_instance
[
'confidence'
]
=
conf
pred_instance
[
'void_intersection'
]
=
np
.
count_nonzero
(
np
.
logical_and
(
bool_void
,
pred_mask
))
# matched gt instances
matched_gt
=
[]
# go through all gt instances with matching label
for
(
gt_num
,
gt_inst
)
in
enumerate
(
gt2pred
[
label_name
]):
intersection
=
np
.
count_nonzero
(
np
.
logical_and
(
gt_ids
==
gt_inst
[
'instance_id'
],
pred_mask
))
if
intersection
>
0
:
gt_copy
=
gt_inst
.
copy
()
pred_copy
=
pred_instance
.
copy
()
gt_copy
[
'intersection'
]
=
intersection
pred_copy
[
'intersection'
]
=
intersection
matched_gt
.
append
(
gt_copy
)
gt2pred
[
label_name
][
gt_num
][
'matched_pred'
].
append
(
pred_copy
)
pred_instance
[
'matched_gt'
]
=
matched_gt
num_pred_instances
+=
1
pred2gt
[
label_name
].
append
(
pred_instance
)
return
gt2pred
,
pred2gt
def
scannet_eval
(
preds
,
gts
,
options
,
valid_class_ids
,
class_labels
,
id_to_label
):
"""Evaluate instance segmentation in ScanNet protocol.
Args:
preds (list[dict]): Per scene predictions of mask, label and
confidence.
gts (list[np.array]): Per scene ground truth instance masks.
options (dict): ScanNet evaluator options. See get_options.
valid_class_ids (tuple[int]): Ids of valid categories.
class_labels (tuple[str]): Class names.
id_to_label (dict[int, str]): Mapping of valid class id to class label.
Returns:
dict: Overall and per-category AP scores.
"""
options
=
get_options
(
options
)
matches
=
{}
for
i
,
(
pred
,
gt
)
in
enumerate
(
zip
(
preds
,
gts
)):
matches_key
=
i
# assign gt to predictions
gt2pred
,
pred2gt
=
assign_instances_for_scan
(
pred
,
gt
,
options
,
valid_class_ids
,
class_labels
,
id_to_label
)
matches
[
matches_key
]
=
{}
matches
[
matches_key
][
'gt'
]
=
gt2pred
matches
[
matches_key
][
'pred'
]
=
pred2gt
ap_scores
=
evaluate_matches
(
matches
,
class_labels
,
options
)
avgs
=
compute_averages
(
ap_scores
,
options
,
class_labels
)
return
avgs
def
get_options
(
options
=
None
):
"""Set ScanNet evaluator options.
Args:
options (dict, optional): Not default options. Default: None.
Returns:
dict: Updated options with all 4 keys.
"""
assert
options
is
None
or
isinstance
(
options
,
dict
)
_options
=
dict
(
overlaps
=
np
.
append
(
np
.
arange
(
0.5
,
0.95
,
0.05
),
0.25
),
min_region_sizes
=
np
.
array
([
100
]),
distance_threshes
=
np
.
array
([
float
(
'inf'
)]),
distance_confs
=
np
.
array
([
-
float
(
'inf'
)]))
if
options
is
not
None
:
_options
.
update
(
options
)
return
_options
mmdet3d/core/evaluation/scannet_utils/util_3d.py
0 → 100644
View file @
333536f6
# Copyright (c) OpenMMLab. All rights reserved.
# adapted from https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/util_3d.py # noqa
import
json
import
numpy
as
np
class
Instance
:
"""Single instance for ScanNet evaluator.
Args:
mesh_vert_instances (np.array): Instance ids for each point.
instance_id: Id of single instance.
"""
instance_id
=
0
label_id
=
0
vert_count
=
0
med_dist
=
-
1
dist_conf
=
0.0
def
__init__
(
self
,
mesh_vert_instances
,
instance_id
):
if
instance_id
==
-
1
:
return
self
.
instance_id
=
int
(
instance_id
)
self
.
label_id
=
int
(
self
.
get_label_id
(
instance_id
))
self
.
vert_count
=
int
(
self
.
get_instance_verts
(
mesh_vert_instances
,
instance_id
))
@
staticmethod
def
get_label_id
(
instance_id
):
return
int
(
instance_id
//
1000
)
@
staticmethod
def
get_instance_verts
(
mesh_vert_instances
,
instance_id
):
return
(
mesh_vert_instances
==
instance_id
).
sum
()
def
to_json
(
self
):
return
json
.
dumps
(
self
,
default
=
lambda
o
:
o
.
__dict__
,
sort_keys
=
True
,
indent
=
4
)
def
to_dict
(
self
):
dict
=
{}
dict
[
'instance_id'
]
=
self
.
instance_id
dict
[
'label_id'
]
=
self
.
label_id
dict
[
'vert_count'
]
=
self
.
vert_count
dict
[
'med_dist'
]
=
self
.
med_dist
dict
[
'dist_conf'
]
=
self
.
dist_conf
return
dict
def
from_json
(
self
,
data
):
self
.
instance_id
=
int
(
data
[
'instance_id'
])
self
.
label_id
=
int
(
data
[
'label_id'
])
self
.
vert_count
=
int
(
data
[
'vert_count'
])
if
'med_dist'
in
data
:
self
.
med_dist
=
float
(
data
[
'med_dist'
])
self
.
dist_conf
=
float
(
data
[
'dist_conf'
])
def
__str__
(
self
):
return
'('
+
str
(
self
.
instance_id
)
+
')'
def
get_instances
(
ids
,
class_ids
,
class_labels
,
id2label
):
"""Transform gt instance mask to Instance objects.
Args:
ids (np.array): Instance ids for each point.
class_ids: (tuple[int]): Ids of valid categories.
class_labels (tuple[str]): Class names.
id2label: (dict[int, str]): Mapping of valid class id to class label.
Returns:
dict [str, list]: Instance objects grouped by class label.
"""
instances
=
{}
for
label
in
class_labels
:
instances
[
label
]
=
[]
instance_ids
=
np
.
unique
(
ids
)
for
id
in
instance_ids
:
if
id
==
0
:
continue
inst
=
Instance
(
ids
,
id
)
if
inst
.
label_id
in
class_ids
:
instances
[
id2label
[
inst
.
label_id
]].
append
(
inst
.
to_dict
())
return
instances
mmdet3d/core/points/base_points.py
View file @
333536f6
...
@@ -242,7 +242,7 @@ class BasePoints(object):
...
@@ -242,7 +242,7 @@ class BasePoints(object):
"""
"""
in_range_flags
=
((
self
.
bev
[:,
0
]
>
point_range
[
0
])
in_range_flags
=
((
self
.
bev
[:,
0
]
>
point_range
[
0
])
&
(
self
.
bev
[:,
1
]
>
point_range
[
1
])
&
(
self
.
bev
[:,
1
]
>
point_range
[
1
])
&
(
self
.
bev
[:,
1
]
<
point_range
[
2
])
&
(
self
.
bev
[:,
0
]
<
point_range
[
2
])
&
(
self
.
bev
[:,
1
]
<
point_range
[
3
]))
&
(
self
.
bev
[:,
1
]
<
point_range
[
3
]))
return
in_range_flags
return
in_range_flags
...
...
mmdet3d/core/post_processing/box3d_nms.py
View file @
333536f6
...
@@ -2,8 +2,8 @@
...
@@ -2,8 +2,8 @@
import
numba
import
numba
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
mmcv.ops
import
nms_bev
as
nms_gpu
from
mm
det3d.ops.iou3d.iou3d_utils
import
nms_gpu
,
nms_normal_gpu
from
mm
cv.ops
import
nms_normal_bev
as
nms_normal_gpu
def
box3d_multiclass_nms
(
mlvl_bboxes
,
def
box3d_multiclass_nms
(
mlvl_bboxes
,
...
...
mmdet3d/core/post_processing/merge_augs.py
View file @
333536f6
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch
from
mmcv.ops
import
nms_bev
as
nms_gpu
from
mmcv.ops
import
nms_normal_bev
as
nms_normal_gpu
from
mmdet3d.ops.iou3d.iou3d_utils
import
nms_gpu
,
nms_normal_gpu
from
..bbox
import
bbox3d2result
,
bbox3d_mapping_back
,
xywhr2xyxyr
from
..bbox
import
bbox3d2result
,
bbox3d_mapping_back
,
xywhr2xyxyr
...
...
mmdet3d/datasets/__init__.py
View file @
333536f6
...
@@ -21,7 +21,8 @@ from .pipelines import (AffineResize, BackgroundPointsFilter, GlobalAlignment,
...
@@ -21,7 +21,8 @@ from .pipelines import (AffineResize, BackgroundPointsFilter, GlobalAlignment,
VoxelBasedPointSampler
)
VoxelBasedPointSampler
)
# yapf: enable
# yapf: enable
from
.s3dis_dataset
import
S3DISDataset
,
S3DISSegDataset
from
.s3dis_dataset
import
S3DISDataset
,
S3DISSegDataset
from
.scannet_dataset
import
ScanNetDataset
,
ScanNetSegDataset
from
.scannet_dataset
import
(
ScanNetDataset
,
ScanNetInstanceSegDataset
,
ScanNetSegDataset
)
from
.semantickitti_dataset
import
SemanticKITTIDataset
from
.semantickitti_dataset
import
SemanticKITTIDataset
from
.sunrgbd_dataset
import
SUNRGBDDataset
from
.sunrgbd_dataset
import
SUNRGBDDataset
from
.utils
import
get_loading_pipeline
from
.utils
import
get_loading_pipeline
...
@@ -35,10 +36,10 @@ __all__ = [
...
@@ -35,10 +36,10 @@ __all__ = [
'LoadPointsFromFile'
,
'S3DISSegDataset'
,
'S3DISDataset'
,
'LoadPointsFromFile'
,
'S3DISSegDataset'
,
'S3DISDataset'
,
'NormalizePointsColor'
,
'IndoorPatchPointSample'
,
'IndoorPointSample'
,
'NormalizePointsColor'
,
'IndoorPatchPointSample'
,
'IndoorPointSample'
,
'PointSample'
,
'LoadAnnotations3D'
,
'GlobalAlignment'
,
'SUNRGBDDataset'
,
'PointSample'
,
'LoadAnnotations3D'
,
'GlobalAlignment'
,
'SUNRGBDDataset'
,
'ScanNetDataset'
,
'ScanNetSegDataset'
,
'S
emanticKITTI
Dataset'
,
'ScanNetDataset'
,
'ScanNetSegDataset'
,
'S
canNetInstanceSeg
Dataset'
,
'Custom3DDataset'
,
'Custom3DSegDataset'
,
'LoadPointsFromMultiSweeps'
,
'SemanticKITTIDataset'
,
'Custom3DDataset'
,
'Custom3DSegDataset'
,
'WaymoDataset'
,
'BackgroundPointsFilter'
,
'VoxelBasedPointSampler'
,
'LoadPointsFromMultiSweeps'
,
'WaymoDataset'
,
'BackgroundPointsFilter'
,
'get_loading_pipeline'
,
'RandomDropPointsColor'
,
'RandomJitterPoints'
,
'VoxelBasedPointSampler'
,
'get_loading_pipeline'
,
'RandomDropPointsColor'
,
'ObjectNameFilter'
,
'AffineResize'
,
'RandomShiftScale'
,
'RandomJitterPoints'
,
'ObjectNameFilter'
,
'AffineResize'
,
'LoadPointsFromDict'
'RandomShiftScale'
,
'LoadPointsFromDict'
]
]
mmdet3d/datasets/custom_3d.py
View file @
333536f6
...
@@ -51,7 +51,8 @@ class Custom3DDataset(Dataset):
...
@@ -51,7 +51,8 @@ class Custom3DDataset(Dataset):
modality
=
None
,
modality
=
None
,
box_type_3d
=
'LiDAR'
,
box_type_3d
=
'LiDAR'
,
filter_empty_gt
=
True
,
filter_empty_gt
=
True
,
test_mode
=
False
):
test_mode
=
False
,
file_client_args
=
dict
(
backend
=
'disk'
)):
super
().
__init__
()
super
().
__init__
()
self
.
data_root
=
data_root
self
.
data_root
=
data_root
self
.
ann_file
=
ann_file
self
.
ann_file
=
ann_file
...
@@ -61,13 +62,26 @@ class Custom3DDataset(Dataset):
...
@@ -61,13 +62,26 @@ class Custom3DDataset(Dataset):
self
.
box_type_3d
,
self
.
box_mode_3d
=
get_box_type
(
box_type_3d
)
self
.
box_type_3d
,
self
.
box_mode_3d
=
get_box_type
(
box_type_3d
)
self
.
CLASSES
=
self
.
get_classes
(
classes
)
self
.
CLASSES
=
self
.
get_classes
(
classes
)
self
.
file_client
=
mmcv
.
FileClient
(
**
file_client_args
)
self
.
cat2id
=
{
name
:
i
for
i
,
name
in
enumerate
(
self
.
CLASSES
)}
self
.
cat2id
=
{
name
:
i
for
i
,
name
in
enumerate
(
self
.
CLASSES
)}
self
.
data_infos
=
self
.
load_annotations
(
self
.
ann_file
)
# load annotations
if
hasattr
(
self
.
file_client
,
'get_local_path'
):
with
self
.
file_client
.
get_local_path
(
self
.
ann_file
)
as
local_path
:
self
.
data_infos
=
self
.
load_annotations
(
open
(
local_path
,
'rb'
))
else
:
warnings
.
warn
(
'The used MMCV version does not have get_local_path. '
f
'We treat the
{
self
.
ann_file
}
as local paths and it '
'might cause errors if the path is not a local path. '
'Please use MMCV>= 1.3.16 if you meet errors.'
)
self
.
data_infos
=
self
.
load_annotations
(
self
.
ann_file
)
# process pipeline
if
pipeline
is
not
None
:
if
pipeline
is
not
None
:
self
.
pipeline
=
Compose
(
pipeline
)
self
.
pipeline
=
Compose
(
pipeline
)
# set group flag for the sampler
# set group flag for the sampler
s
if
not
self
.
test_mode
:
if
not
self
.
test_mode
:
self
.
_set_group_flag
()
self
.
_set_group_flag
()
...
@@ -80,7 +94,8 @@ class Custom3DDataset(Dataset):
...
@@ -80,7 +94,8 @@ class Custom3DDataset(Dataset):
Returns:
Returns:
list[dict]: List of annotations.
list[dict]: List of annotations.
"""
"""
return
mmcv
.
load
(
ann_file
)
# loading data from a file-like object needs file format
return
mmcv
.
load
(
ann_file
,
file_format
=
'pkl'
)
def
get_data_info
(
self
,
index
):
def
get_data_info
(
self
,
index
):
"""Get data info according to the given index.
"""Get data info according to the given index.
...
...
mmdet3d/datasets/custom_3d_seg.py
View file @
333536f6
...
@@ -62,14 +62,26 @@ class Custom3DSegDataset(Dataset):
...
@@ -62,14 +62,26 @@ class Custom3DSegDataset(Dataset):
modality
=
None
,
modality
=
None
,
test_mode
=
False
,
test_mode
=
False
,
ignore_index
=
None
,
ignore_index
=
None
,
scene_idxs
=
None
):
scene_idxs
=
None
,
file_client_args
=
dict
(
backend
=
'disk'
)):
super
().
__init__
()
super
().
__init__
()
self
.
data_root
=
data_root
self
.
data_root
=
data_root
self
.
ann_file
=
ann_file
self
.
ann_file
=
ann_file
self
.
test_mode
=
test_mode
self
.
test_mode
=
test_mode
self
.
modality
=
modality
self
.
modality
=
modality
self
.
file_client
=
mmcv
.
FileClient
(
**
file_client_args
)
self
.
data_infos
=
self
.
load_annotations
(
self
.
ann_file
)
# load annotations
if
hasattr
(
self
.
file_client
,
'get_local_path'
):
with
self
.
file_client
.
get_local_path
(
self
.
ann_file
)
as
local_path
:
self
.
data_infos
=
self
.
load_annotations
(
open
(
local_path
,
'rb'
))
else
:
warnings
.
warn
(
'The used MMCV version does not have get_local_path. '
f
'We treat the
{
self
.
ann_file
}
as local paths and it '
'might cause errors if the path is not a local path. '
'Please use MMCV>= 1.3.16 if you meet errors.'
)
self
.
data_infos
=
self
.
load_annotations
(
self
.
ann_file
)
if
pipeline
is
not
None
:
if
pipeline
is
not
None
:
self
.
pipeline
=
Compose
(
pipeline
)
self
.
pipeline
=
Compose
(
pipeline
)
...
@@ -94,7 +106,8 @@ class Custom3DSegDataset(Dataset):
...
@@ -94,7 +106,8 @@ class Custom3DSegDataset(Dataset):
Returns:
Returns:
list[dict]: List of annotations.
list[dict]: List of annotations.
"""
"""
return
mmcv
.
load
(
ann_file
)
# loading data from a file-like object needs file format
return
mmcv
.
load
(
ann_file
,
file_format
=
'pkl'
)
def
get_data_info
(
self
,
index
):
def
get_data_info
(
self
,
index
):
"""Get data info according to the given index.
"""Get data info according to the given index.
...
...
mmdet3d/datasets/kitti_dataset.py
View file @
333536f6
...
@@ -65,7 +65,8 @@ class KittiDataset(Custom3DDataset):
...
@@ -65,7 +65,8 @@ class KittiDataset(Custom3DDataset):
box_type_3d
=
'LiDAR'
,
box_type_3d
=
'LiDAR'
,
filter_empty_gt
=
True
,
filter_empty_gt
=
True
,
test_mode
=
False
,
test_mode
=
False
,
pcd_limit_range
=
[
0
,
-
40
,
-
3
,
70.4
,
40
,
0.0
]):
pcd_limit_range
=
[
0
,
-
40
,
-
3
,
70.4
,
40
,
0.0
],
**
kwargs
):
super
().
__init__
(
super
().
__init__
(
data_root
=
data_root
,
data_root
=
data_root
,
ann_file
=
ann_file
,
ann_file
=
ann_file
,
...
@@ -74,7 +75,8 @@ class KittiDataset(Custom3DDataset):
...
@@ -74,7 +75,8 @@ class KittiDataset(Custom3DDataset):
modality
=
modality
,
modality
=
modality
,
box_type_3d
=
box_type_3d
,
box_type_3d
=
box_type_3d
,
filter_empty_gt
=
filter_empty_gt
,
filter_empty_gt
=
filter_empty_gt
,
test_mode
=
test_mode
)
test_mode
=
test_mode
,
**
kwargs
)
self
.
split
=
split
self
.
split
=
split
self
.
root_split
=
os
.
path
.
join
(
self
.
data_root
,
split
)
self
.
root_split
=
os
.
path
.
join
(
self
.
data_root
,
split
)
...
...
Prev
1
2
3
4
5
6
7
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment