Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
679741d0
"...ops/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "1a3a1adea3975adbd2e27770bf174cff3c7a3df3"
Commit
679741d0
authored
Jun 10, 2020
by
zhangwenwei
Browse files
use box3dstructure in augmentation
parent
4b4cc7e8
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
254 additions
and
47 deletions
+254
-47
configs/kitti/dv_pointpillars_secfpn_6x8_160e_kitti-3d-car.py
...igs/kitti/dv_pointpillars_secfpn_6x8_160e_kitti-3d-car.py
+1
-0
configs/kitti/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py
.../kitti/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py
+1
-0
configs/kitti/dv_second_secfpn_6x8_80e_kitti-3d-car.py
configs/kitti/dv_second_secfpn_6x8_80e_kitti-3d-car.py
+1
-0
configs/kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-3class.py
.../kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-3class.py
+1
-0
configs/kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-car.py
...igs/kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-car.py
+1
-0
configs/kitti/hv_pointpillars_secfpn_6x8_160e_kitti-3d-car.py
...igs/kitti/hv_pointpillars_secfpn_6x8_160e_kitti-3d-car.py
+1
-0
configs/kitti/hv_second_secfpn_6x8_80e_kitti-3d-car.py
configs/kitti/hv_second_secfpn_6x8_80e_kitti-3d-car.py
+1
-0
mmdet3d/core/bbox/__init__.py
mmdet3d/core/bbox/__init__.py
+3
-2
mmdet3d/core/bbox/structures/__init__.py
mmdet3d/core/bbox/structures/__init__.py
+3
-2
mmdet3d/core/bbox/structures/base_box3d.py
mmdet3d/core/bbox/structures/base_box3d.py
+18
-0
mmdet3d/core/bbox/structures/box_3d_mode.py
mmdet3d/core/bbox/structures/box_3d_mode.py
+1
-1
mmdet3d/core/bbox/structures/cam_box3d.py
mmdet3d/core/bbox/structures/cam_box3d.py
+19
-0
mmdet3d/core/bbox/structures/depth_box3d.py
mmdet3d/core/bbox/structures/depth_box3d.py
+19
-0
mmdet3d/core/bbox/structures/lidar_box3d.py
mmdet3d/core/bbox/structures/lidar_box3d.py
+23
-0
mmdet3d/datasets/kitti_dataset.py
mmdet3d/datasets/kitti_dataset.py
+6
-5
mmdet3d/datasets/pipelines/formating.py
mmdet3d/datasets/pipelines/formating.py
+11
-3
mmdet3d/datasets/pipelines/train_aug.py
mmdet3d/datasets/pipelines/train_aug.py
+21
-34
mmdet3d/models/dense_heads/train_mixins.py
mmdet3d/models/dense_heads/train_mixins.py
+2
-0
tests/test_pipeline/test_outdoor_pipeline.py
tests/test_pipeline/test_outdoor_pipeline.py
+121
-0
No files found.
configs/kitti/dv_pointpillars_secfpn_6x8_160e_kitti-3d-car.py
View file @
679741d0
...
@@ -189,6 +189,7 @@ momentum_config = dict(
...
@@ -189,6 +189,7 @@ momentum_config = dict(
step_ratio_up
=
0.4
,
step_ratio_up
=
0.4
,
)
)
checkpoint_config
=
dict
(
interval
=
1
)
checkpoint_config
=
dict
(
interval
=
1
)
evaluation
=
dict
(
interval
=
2
)
# yapf:disable
# yapf:disable
log_config
=
dict
(
log_config
=
dict
(
interval
=
50
,
interval
=
50
,
...
...
configs/kitti/dv_second_secfpn_2x8_cosine_80e_kitti-3d-3class.py
View file @
679741d0
...
@@ -213,6 +213,7 @@ lr_config = dict(
...
@@ -213,6 +213,7 @@ lr_config = dict(
min_lr_ratio
=
1e-5
)
min_lr_ratio
=
1e-5
)
momentum_config
=
None
momentum_config
=
None
checkpoint_config
=
dict
(
interval
=
1
)
checkpoint_config
=
dict
(
interval
=
1
)
evaluation
=
dict
(
interval
=
2
)
# yapf:disable
# yapf:disable
log_config
=
dict
(
log_config
=
dict
(
interval
=
50
,
interval
=
50
,
...
...
configs/kitti/dv_second_secfpn_6x8_80e_kitti-3d-car.py
View file @
679741d0
...
@@ -183,6 +183,7 @@ momentum_config = dict(
...
@@ -183,6 +183,7 @@ momentum_config = dict(
step_ratio_up
=
0.4
,
step_ratio_up
=
0.4
,
)
)
checkpoint_config
=
dict
(
interval
=
1
)
checkpoint_config
=
dict
(
interval
=
1
)
evaluation
=
dict
(
interval
=
2
)
# yapf:disable
# yapf:disable
log_config
=
dict
(
log_config
=
dict
(
interval
=
50
,
interval
=
50
,
...
...
configs/kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-3class.py
View file @
679741d0
...
@@ -302,6 +302,7 @@ momentum_config = dict(
...
@@ -302,6 +302,7 @@ momentum_config = dict(
cyclic_times
=
1
,
cyclic_times
=
1
,
step_ratio_up
=
0.4
)
step_ratio_up
=
0.4
)
checkpoint_config
=
dict
(
interval
=
1
)
checkpoint_config
=
dict
(
interval
=
1
)
evaluation
=
dict
(
interval
=
2
)
# yapf:disable
# yapf:disable
log_config
=
dict
(
log_config
=
dict
(
interval
=
50
,
interval
=
50
,
...
...
configs/kitti/hv_PartA2_secfpn_4x8_cyclic_80e_kitti-3d-car.py
View file @
679741d0
...
@@ -261,6 +261,7 @@ momentum_config = dict(
...
@@ -261,6 +261,7 @@ momentum_config = dict(
cyclic_times
=
1
,
cyclic_times
=
1
,
step_ratio_up
=
0.4
)
step_ratio_up
=
0.4
)
checkpoint_config
=
dict
(
interval
=
1
)
checkpoint_config
=
dict
(
interval
=
1
)
evaluation
=
dict
(
interval
=
2
)
# yapf:disable
# yapf:disable
log_config
=
dict
(
log_config
=
dict
(
interval
=
50
,
interval
=
50
,
...
...
configs/kitti/hv_pointpillars_secfpn_6x8_160e_kitti-3d-car.py
View file @
679741d0
...
@@ -192,6 +192,7 @@ momentum_config = dict(
...
@@ -192,6 +192,7 @@ momentum_config = dict(
step_ratio_up
=
0.4
,
step_ratio_up
=
0.4
,
)
)
checkpoint_config
=
dict
(
interval
=
1
)
checkpoint_config
=
dict
(
interval
=
1
)
evaluation
=
dict
(
interval
=
2
)
# yapf:disable
# yapf:disable
log_config
=
dict
(
log_config
=
dict
(
interval
=
50
,
interval
=
50
,
...
...
configs/kitti/hv_second_secfpn_6x8_80e_kitti-3d-car.py
View file @
679741d0
...
@@ -197,6 +197,7 @@ momentum_config = dict(
...
@@ -197,6 +197,7 @@ momentum_config = dict(
step_ratio_up
=
0.4
,
step_ratio_up
=
0.4
,
)
)
checkpoint_config
=
dict
(
interval
=
1
)
checkpoint_config
=
dict
(
interval
=
1
)
evaluation
=
dict
(
interval
=
2
)
# yapf:disable
# yapf:disable
log_config
=
dict
(
log_config
=
dict
(
interval
=
50
,
interval
=
50
,
...
...
mmdet3d/core/bbox/__init__.py
View file @
679741d0
...
@@ -7,7 +7,7 @@ from .iou_calculators import (BboxOverlaps3D, BboxOverlapsNearest3D,
...
@@ -7,7 +7,7 @@ from .iou_calculators import (BboxOverlaps3D, BboxOverlapsNearest3D,
from
.samplers
import
(
BaseSampler
,
CombinedSampler
,
from
.samplers
import
(
BaseSampler
,
CombinedSampler
,
InstanceBalancedPosSampler
,
IoUBalancedNegSampler
,
InstanceBalancedPosSampler
,
IoUBalancedNegSampler
,
PseudoSampler
,
RandomSampler
,
SamplingResult
)
PseudoSampler
,
RandomSampler
,
SamplingResult
)
from
.structures
import
(
Box3DMode
,
CameraInstance3DBoxes
,
from
.structures
import
(
BaseInstance3DBoxes
,
Box3DMode
,
CameraInstance3DBoxes
,
DepthInstance3DBoxes
,
LiDARInstance3DBoxes
)
DepthInstance3DBoxes
,
LiDARInstance3DBoxes
)
from
.transforms
import
(
bbox3d2result
,
bbox3d2roi
,
from
.transforms
import
(
bbox3d2result
,
bbox3d2roi
,
box3d_to_corner3d_upright_depth
,
box3d_to_corner3d_upright_depth
,
...
@@ -26,5 +26,6 @@ __all__ = [
...
@@ -26,5 +26,6 @@ __all__ = [
'BboxOverlapsNearest3D'
,
'BboxOverlaps3D'
,
'bbox_overlaps_nearest_3d'
,
'BboxOverlapsNearest3D'
,
'BboxOverlaps3D'
,
'bbox_overlaps_nearest_3d'
,
'bbox_overlaps_3d'
,
'Box3DMode'
,
'LiDARInstance3DBoxes'
,
'bbox_overlaps_3d'
,
'Box3DMode'
,
'LiDARInstance3DBoxes'
,
'CameraInstance3DBoxes'
,
'bbox3d2roi'
,
'bbox3d2result'
,
'CameraInstance3DBoxes'
,
'bbox3d2roi'
,
'bbox3d2result'
,
'box3d_to_corner3d_upright_depth'
,
'DepthInstance3DBoxes'
'box3d_to_corner3d_upright_depth'
,
'DepthInstance3DBoxes'
,
'BaseInstance3DBoxes'
]
]
mmdet3d/core/bbox/structures/__init__.py
View file @
679741d0
from
.base_box3d
import
BaseInstance3DBoxes
from
.box_3d_mode
import
Box3DMode
from
.box_3d_mode
import
Box3DMode
from
.cam_box3d
import
CameraInstance3DBoxes
from
.cam_box3d
import
CameraInstance3DBoxes
from
.depth_box3d
import
DepthInstance3DBoxes
from
.depth_box3d
import
DepthInstance3DBoxes
from
.lidar_box3d
import
LiDARInstance3DBoxes
from
.lidar_box3d
import
LiDARInstance3DBoxes
__all__
=
[
__all__
=
[
'Box3DMode'
,
'
LiDAR
Instance3DBoxes'
,
'
Camera
Instance3DBoxes'
,
'Box3DMode'
,
'
Base
Instance3DBoxes'
,
'
LiDAR
Instance3DBoxes'
,
'DepthInstance3DBoxes'
'CameraInstance3DBoxes'
,
'DepthInstance3DBoxes'
]
]
mmdet3d/core/bbox/structures/base_box3d.py
View file @
679741d0
...
@@ -199,6 +199,24 @@ class BaseInstance3DBoxes(object):
...
@@ -199,6 +199,24 @@ class BaseInstance3DBoxes(object):
"""
"""
pass
pass
@
abstractmethod
def
convert_to
(
self
,
dst
,
rt_mat
=
None
):
"""Convert self to `dst` mode.
Args:
dst (BoxMode): the target Box mode
rt_mat (np.ndarray | torch.Tensor): The rotation and translation
matrix between different coordinates. Defaults to None.
The conversion from `src` coordinates to `dst` coordinates
usually comes along the change of sensors, e.g., from camera
to LiDAR. This requires a transformation matrix.
Returns:
BaseInstance3DBoxes:
The converted box of the same type in the `dst` mode.
"""
pass
def
scale
(
self
,
scale_factor
):
def
scale
(
self
,
scale_factor
):
"""Scale the box with horizontal and vertical scaling factors
"""Scale the box with horizontal and vertical scaling factors
...
...
mmdet3d/core/bbox/structures/box_3d_mode.py
View file @
679741d0
...
@@ -74,7 +74,7 @@ class Box3DMode(IntEnum):
...
@@ -74,7 +74,7 @@ class Box3DMode(IntEnum):
to LiDAR. This requires a transformation matrix.
to LiDAR. This requires a transformation matrix.
Returns:
Returns:
(tuple | list | np.ndarray | torch.Tensor):
(tuple | list | np.ndarray | torch.Tensor
| BaseInstance3DBoxes
):
The converted box of the same type.
The converted box of the same type.
"""
"""
if
src
==
dst
:
if
src
==
dst
:
...
...
mmdet3d/core/bbox/structures/cam_box3d.py
View file @
679741d0
...
@@ -240,3 +240,22 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
...
@@ -240,3 +240,22 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
lowest_of_top
=
torch
.
max
(
boxes1_top_height
,
boxes2_top_height
)
lowest_of_top
=
torch
.
max
(
boxes1_top_height
,
boxes2_top_height
)
overlaps_h
=
torch
.
clamp
(
heighest_of_bottom
-
lowest_of_top
,
min
=
0
)
overlaps_h
=
torch
.
clamp
(
heighest_of_bottom
-
lowest_of_top
,
min
=
0
)
return
overlaps_h
return
overlaps_h
def
convert_to
(
self
,
dst
,
rt_mat
=
None
):
"""Convert self to `dst` mode.
Args:
dst (BoxMode): the target Box mode
rt_mat (np.ndarray | torch.Tensor): The rotation and translation
matrix between different coordinates. Defaults to None.
The conversion from `src` coordinates to `dst` coordinates
usually comes along the change of sensors, e.g., from camera
to LiDAR. This requires a transformation matrix.
Returns:
BaseInstance3DBoxes:
The converted box of the same type in the `dst` mode.
"""
from
.box_3d_mode
import
Box3DMode
return
Box3DMode
.
convert
(
box
=
self
,
src
=
Box3DMode
.
CAM
,
dst
=
dst
,
rt_mat
=
rt_mat
)
mmdet3d/core/bbox/structures/depth_box3d.py
View file @
679741d0
...
@@ -182,3 +182,22 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
...
@@ -182,3 +182,22 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
&
(
self
.
tensor
[:,
0
]
<
box_range
[
2
])
&
(
self
.
tensor
[:,
0
]
<
box_range
[
2
])
&
(
self
.
tensor
[:,
1
]
<
box_range
[
3
]))
&
(
self
.
tensor
[:,
1
]
<
box_range
[
3
]))
return
in_range_flags
return
in_range_flags
def
convert_to
(
self
,
dst
,
rt_mat
=
None
):
"""Convert self to `dst` mode.
Args:
dst (BoxMode): the target Box mode
rt_mat (np.ndarray | torch.Tensor): The rotation and translation
matrix between different coordinates. Defaults to None.
The conversion from `src` coordinates to `dst` coordinates
usually comes along the change of sensors, e.g., from camera
to LiDAR. This requires a transformation matrix.
Returns:
BaseInstance3DBoxes:
The converted box of the same type in the `dst` mode.
"""
from
.box_3d_mode
import
Box3DMode
return
Box3DMode
.
convert
(
box
=
self
,
src
=
Box3DMode
.
DEPTH
,
dst
=
dst
,
rt_mat
=
rt_mat
)
mmdet3d/core/bbox/structures/lidar_box3d.py
View file @
679741d0
...
@@ -133,6 +133,10 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
...
@@ -133,6 +133,10 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
self
.
tensor
[:,
:
3
]
=
self
.
tensor
[:,
:
3
]
@
rot_mat_T
self
.
tensor
[:,
:
3
]
=
self
.
tensor
[:,
:
3
]
@
rot_mat_T
self
.
tensor
[:,
6
]
+=
angle
self
.
tensor
[:,
6
]
+=
angle
if
self
.
tensor
.
shape
[
1
]
==
9
:
# rotate velo vector
self
.
tensor
[:,
7
:
9
]
=
self
.
tensor
[:,
7
:
9
]
@
rot_mat_T
[:
2
,
:
2
]
def
flip
(
self
,
bev_direction
=
'horizontal'
):
def
flip
(
self
,
bev_direction
=
'horizontal'
):
"""Flip the boxes in BEV along given BEV direction
"""Flip the boxes in BEV along given BEV direction
...
@@ -173,3 +177,22 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
...
@@ -173,3 +177,22 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
&
(
self
.
tensor
[:,
0
]
<
box_range
[
2
])
&
(
self
.
tensor
[:,
0
]
<
box_range
[
2
])
&
(
self
.
tensor
[:,
1
]
<
box_range
[
3
]))
&
(
self
.
tensor
[:,
1
]
<
box_range
[
3
]))
return
in_range_flags
return
in_range_flags
def
convert_to
(
self
,
dst
,
rt_mat
=
None
):
"""Convert self to `dst` mode.
Args:
dst (BoxMode): the target Box mode
rt_mat (np.ndarray | torch.Tensor): The rotation and translation
matrix between different coordinates. Defaults to None.
The conversion from `src` coordinates to `dst` coordinates
usually comes along the change of sensors, e.g., from camera
to LiDAR. This requires a transformation matrix.
Returns:
BaseInstance3DBoxes:
The converted box of the same type in the `dst` mode.
"""
from
.box_3d_mode
import
Box3DMode
return
Box3DMode
.
convert
(
box
=
self
,
src
=
Box3DMode
.
LIDAR
,
dst
=
dst
,
rt_mat
=
rt_mat
)
mmdet3d/datasets/kitti_dataset.py
View file @
679741d0
...
@@ -9,7 +9,7 @@ import torch
...
@@ -9,7 +9,7 @@ import torch
from
mmcv.utils
import
print_log
from
mmcv.utils
import
print_log
from
mmdet.datasets
import
DATASETS
from
mmdet.datasets
import
DATASETS
from
..core.bbox
import
box_np_ops
from
..core.bbox
import
Box3DMode
,
CameraInstance3DBoxes
,
box_np_ops
from
.custom_3d
import
Custom3DDataset
from
.custom_3d
import
Custom3DDataset
from
.utils
import
remove_dontcare
from
.utils
import
remove_dontcare
...
@@ -87,13 +87,14 @@ class KittiDataset(Custom3DDataset):
...
@@ -87,13 +87,14 @@ class KittiDataset(Custom3DDataset):
# print(gt_names, len(loc))
# print(gt_names, len(loc))
gt_bboxes_3d
=
np
.
concatenate
([
loc
,
dims
,
rots
[...,
np
.
newaxis
]],
gt_bboxes_3d
=
np
.
concatenate
([
loc
,
dims
,
rots
[...,
np
.
newaxis
]],
axis
=
1
).
astype
(
np
.
float32
)
axis
=
1
).
astype
(
np
.
float32
)
# this change gt_bboxes_3d to velodyne coordinates
gt_bboxes_3d
=
box_np_ops
.
box_camera_to_lidar
(
gt_bboxes_3d
,
rect
,
# convert gt_bboxes_3d to velodyne coordinates
Trv2c
)
gt_bboxes_3d
=
CameraInstance3DBoxes
(
gt_bboxes_3d
).
convert_to
(
Box3DMode
.
LIDAR
,
np
.
linalg
.
inv
(
rect
@
Trv2c
))
gt_bboxes
=
annos
[
'bbox'
]
gt_bboxes
=
annos
[
'bbox'
]
selected
=
self
.
drop_arrays_by_name
(
gt_names
,
[
'DontCare'
])
selected
=
self
.
drop_arrays_by_name
(
gt_names
,
[
'DontCare'
])
gt_bboxes_3d
=
gt_bboxes_3d
[
selected
].
astype
(
'float32'
)
#
gt_bboxes_3d = gt_bboxes_3d[selected].astype('float32')
gt_bboxes
=
gt_bboxes
[
selected
].
astype
(
'float32'
)
gt_bboxes
=
gt_bboxes
[
selected
].
astype
(
'float32'
)
gt_names
=
gt_names
[
selected
]
gt_names
=
gt_names
[
selected
]
...
...
mmdet3d/datasets/pipelines/formating.py
View file @
679741d0
import
numpy
as
np
import
numpy
as
np
from
mmcv.parallel
import
DataContainer
as
DC
from
mmcv.parallel
import
DataContainer
as
DC
from
mmdet3d.core.bbox
import
BaseInstance3DBoxes
from
mmdet.datasets.builder
import
PIPELINES
from
mmdet.datasets.builder
import
PIPELINES
from
mmdet.datasets.pipelines
import
to_tensor
from
mmdet.datasets.pipelines
import
to_tensor
...
@@ -39,9 +40,8 @@ class DefaultFormatBundle(object):
...
@@ -39,9 +40,8 @@ class DefaultFormatBundle(object):
img
=
np
.
ascontiguousarray
(
results
[
'img'
].
transpose
(
2
,
0
,
1
))
img
=
np
.
ascontiguousarray
(
results
[
'img'
].
transpose
(
2
,
0
,
1
))
results
[
'img'
]
=
DC
(
to_tensor
(
img
),
stack
=
True
)
results
[
'img'
]
=
DC
(
to_tensor
(
img
),
stack
=
True
)
for
key
in
[
for
key
in
[
'proposals'
,
'gt_bboxes'
,
'gt_bboxes_3d'
,
'gt_bboxes_ignore'
,
'proposals'
,
'gt_bboxes'
,
'gt_bboxes_ignore'
,
'gt_labels'
,
'gt_labels'
,
'gt_labels_3d'
,
'pts_instance_mask'
,
'gt_labels_3d'
,
'pts_instance_mask'
,
'pts_semantic_mask'
'pts_semantic_mask'
]:
]:
if
key
not
in
results
:
if
key
not
in
results
:
continue
continue
...
@@ -49,6 +49,14 @@ class DefaultFormatBundle(object):
...
@@ -49,6 +49,14 @@ class DefaultFormatBundle(object):
results
[
key
]
=
DC
([
to_tensor
(
res
)
for
res
in
results
[
key
]])
results
[
key
]
=
DC
([
to_tensor
(
res
)
for
res
in
results
[
key
]])
else
:
else
:
results
[
key
]
=
DC
(
to_tensor
(
results
[
key
]))
results
[
key
]
=
DC
(
to_tensor
(
results
[
key
]))
if
'gt_bboxes_3d'
in
results
:
if
isinstance
(
results
[
'gt_bboxes_3d'
],
BaseInstance3DBoxes
):
results
[
'gt_bboxes_3d'
]
=
DC
(
results
[
'gt_bboxes_3d'
],
cpu_only
=
True
)
else
:
results
[
'gt_bboxes_3d'
]
=
DC
(
to_tensor
(
results
[
'gt_bboxes_3d'
]))
if
'gt_masks'
in
results
:
if
'gt_masks'
in
results
:
results
[
'gt_masks'
]
=
DC
(
results
[
'gt_masks'
],
cpu_only
=
True
)
results
[
'gt_masks'
]
=
DC
(
results
[
'gt_masks'
],
cpu_only
=
True
)
if
'gt_semantic_seg'
in
results
:
if
'gt_semantic_seg'
in
results
:
...
...
mmdet3d/datasets/pipelines/train_aug.py
View file @
679741d0
...
@@ -26,12 +26,8 @@ class RandomFlip3D(RandomFlip):
...
@@ -26,12 +26,8 @@ class RandomFlip3D(RandomFlip):
self
.
sync_2d
=
sync_2d
self
.
sync_2d
=
sync_2d
def
random_flip_points
(
self
,
gt_bboxes_3d
,
points
):
def
random_flip_points
(
self
,
gt_bboxes_3d
,
points
):
gt_bboxes_3d
[:,
1
]
=
-
gt_bboxes_3d
[:,
1
]
gt_bboxes_3d
.
flip
()
gt_bboxes_3d
[:,
6
]
=
-
gt_bboxes_3d
[:,
6
]
+
np
.
pi
points
[:,
1
]
=
-
points
[:,
1
]
points
[:,
1
]
=
-
points
[:,
1
]
if
gt_bboxes_3d
.
shape
[
1
]
==
9
:
# flip velocitys at the same time
gt_bboxes_3d
[:,
8
]
=
-
gt_bboxes_3d
[:,
8
]
return
gt_bboxes_3d
,
points
return
gt_bboxes_3d
,
points
def
__call__
(
self
,
input_dict
):
def
__call__
(
self
,
input_dict
):
...
@@ -121,10 +117,13 @@ class ObjectSample(object):
...
@@ -121,10 +117,13 @@ class ObjectSample(object):
gt_bboxes_2d
=
input_dict
[
'gt_bboxes'
]
gt_bboxes_2d
=
input_dict
[
'gt_bboxes'
]
# Assume for now 3D & 2D bboxes are the same
# Assume for now 3D & 2D bboxes are the same
sampled_dict
=
self
.
db_sampler
.
sample_all
(
sampled_dict
=
self
.
db_sampler
.
sample_all
(
gt_bboxes_3d
,
gt_labels_3d
,
gt_bboxes_2d
=
gt_bboxes_2d
,
img
=
img
)
gt_bboxes_3d
.
tensor
.
numpy
(),
gt_labels_3d
,
gt_bboxes_2d
=
gt_bboxes_2d
,
img
=
img
)
else
:
else
:
sampled_dict
=
self
.
db_sampler
.
sample_all
(
sampled_dict
=
self
.
db_sampler
.
sample_all
(
gt_bboxes_3d
,
gt_labels_3d
,
img
=
None
)
gt_bboxes_3d
.
tensor
.
numpy
()
,
gt_labels_3d
,
img
=
None
)
if
sampled_dict
is
not
None
:
if
sampled_dict
is
not
None
:
sampled_gt_bboxes_3d
=
sampled_dict
[
'gt_bboxes_3d'
]
sampled_gt_bboxes_3d
=
sampled_dict
[
'gt_bboxes_3d'
]
...
@@ -133,8 +132,9 @@ class ObjectSample(object):
...
@@ -133,8 +132,9 @@ class ObjectSample(object):
gt_labels_3d
=
np
.
concatenate
([
gt_labels_3d
,
sampled_gt_labels
],
gt_labels_3d
=
np
.
concatenate
([
gt_labels_3d
,
sampled_gt_labels
],
axis
=
0
)
axis
=
0
)
gt_bboxes_3d
=
np
.
concatenate
([
gt_bboxes_3d
,
sampled_gt_bboxes_3d
gt_bboxes_3d
=
gt_bboxes_3d
.
new_box
(
]).
astype
(
np
.
float32
)
np
.
concatenate
(
[
gt_bboxes_3d
.
tensor
.
numpy
(),
sampled_gt_bboxes_3d
]))
points
=
self
.
remove_points_in_boxes
(
points
,
sampled_gt_bboxes_3d
)
points
=
self
.
remove_points_in_boxes
(
points
,
sampled_gt_bboxes_3d
)
# check the points dimension
# check the points dimension
...
@@ -178,14 +178,16 @@ class ObjectNoise(object):
...
@@ -178,14 +178,16 @@ class ObjectNoise(object):
points
=
input_dict
[
'points'
]
points
=
input_dict
[
'points'
]
# TODO: check this inplace function
# TODO: check this inplace function
numpy_box
=
gt_bboxes_3d
.
tensor
.
numpy
()
noise_per_object_v3_
(
noise_per_object_v3_
(
gt
_
b
box
es_3d
,
numpy
_box
,
points
,
points
,
rotation_perturb
=
self
.
rot_uniform_noise
,
rotation_perturb
=
self
.
rot_uniform_noise
,
center_noise_std
=
self
.
loc_noise_std
,
center_noise_std
=
self
.
loc_noise_std
,
global_random_rot_range
=
self
.
global_rot_range
,
global_random_rot_range
=
self
.
global_rot_range
,
num_try
=
self
.
num_try
)
num_try
=
self
.
num_try
)
input_dict
[
'gt_bboxes_3d'
]
=
gt_bboxes_3d
.
astype
(
'float32'
)
input_dict
[
'gt_bboxes_3d'
]
=
gt_bboxes_3d
.
new_box
(
numpy_box
)
input_dict
[
'points'
]
=
points
input_dict
[
'points'
]
=
points
return
input_dict
return
input_dict
...
@@ -212,7 +214,7 @@ class GlobalRotScale(object):
...
@@ -212,7 +214,7 @@ class GlobalRotScale(object):
def
_trans_bbox_points
(
self
,
gt_boxes
,
points
):
def
_trans_bbox_points
(
self
,
gt_boxes
,
points
):
noise_trans
=
np
.
random
.
normal
(
0
,
self
.
trans_normal_noise
[
0
],
3
).
T
noise_trans
=
np
.
random
.
normal
(
0
,
self
.
trans_normal_noise
[
0
],
3
).
T
points
[:,
:
3
]
+=
noise_trans
points
[:,
:
3
]
+=
noise_trans
gt_boxes
[:,
:
3
]
+=
noise_trans
gt_boxes
.
translate
(
noise_trans
)
return
gt_boxes
,
points
,
noise_trans
return
gt_boxes
,
points
,
noise_trans
def
_rot_bbox_points
(
self
,
gt_boxes
,
points
,
rotation
=
np
.
pi
/
4
):
def
_rot_bbox_points
(
self
,
gt_boxes
,
points
,
rotation
=
np
.
pi
/
4
):
...
@@ -221,16 +223,8 @@ class GlobalRotScale(object):
...
@@ -221,16 +223,8 @@ class GlobalRotScale(object):
noise_rotation
=
np
.
random
.
uniform
(
rotation
[
0
],
rotation
[
1
])
noise_rotation
=
np
.
random
.
uniform
(
rotation
[
0
],
rotation
[
1
])
points
[:,
:
3
],
rot_mat_T
=
box_np_ops
.
rotation_points_single_angle
(
points
[:,
:
3
],
rot_mat_T
=
box_np_ops
.
rotation_points_single_angle
(
points
[:,
:
3
],
noise_rotation
,
axis
=
2
)
points
[:,
:
3
],
noise_rotation
,
axis
=
2
)
gt_boxes
[:,
:
3
],
_
=
box_np_ops
.
rotation_points_single_angle
(
gt_boxes
.
rotate
(
noise_rotation
)
gt_boxes
[:,
:
3
],
noise_rotation
,
axis
=
2
)
gt_boxes
[:,
6
]
+=
noise_rotation
if
gt_boxes
.
shape
[
1
]
==
9
:
# rotate velo vector
rot_cos
=
np
.
cos
(
noise_rotation
)
rot_sin
=
np
.
sin
(
noise_rotation
)
rot_mat_T_bev
=
np
.
array
([[
rot_cos
,
-
rot_sin
],
[
rot_sin
,
rot_cos
]],
dtype
=
points
.
dtype
)
gt_boxes
[:,
7
:
9
]
=
gt_boxes
[:,
7
:
9
]
@
rot_mat_T_bev
return
gt_boxes
,
points
,
rot_mat_T
return
gt_boxes
,
points
,
rot_mat_T
def
_scale_bbox_points
(
self
,
def
_scale_bbox_points
(
self
,
...
@@ -240,9 +234,7 @@ class GlobalRotScale(object):
...
@@ -240,9 +234,7 @@ class GlobalRotScale(object):
max_scale
=
1.05
):
max_scale
=
1.05
):
noise_scale
=
np
.
random
.
uniform
(
min_scale
,
max_scale
)
noise_scale
=
np
.
random
.
uniform
(
min_scale
,
max_scale
)
points
[:,
:
3
]
*=
noise_scale
points
[:,
:
3
]
*=
noise_scale
gt_boxes
[:,
:
6
]
*=
noise_scale
gt_boxes
.
scale
(
noise_scale
)
if
gt_boxes
.
shape
[
1
]
==
9
:
gt_boxes
[:,
7
:]
*=
noise_scale
return
gt_boxes
,
points
,
noise_scale
return
gt_boxes
,
points
,
noise_scale
def
__call__
(
self
,
input_dict
):
def
__call__
(
self
,
input_dict
):
...
@@ -256,7 +248,7 @@ class GlobalRotScale(object):
...
@@ -256,7 +248,7 @@ class GlobalRotScale(object):
gt_bboxes_3d
,
points
,
trans_factor
=
self
.
_trans_bbox_points
(
gt_bboxes_3d
,
points
,
trans_factor
=
self
.
_trans_bbox_points
(
gt_bboxes_3d
,
points
)
gt_bboxes_3d
,
points
)
input_dict
[
'gt_bboxes_3d'
]
=
gt_bboxes_3d
.
astype
(
'float32'
)
input_dict
[
'gt_bboxes_3d'
]
=
gt_bboxes_3d
input_dict
[
'points'
]
=
points
input_dict
[
'points'
]
=
points
input_dict
[
'pcd_scale_factor'
]
=
scale_factor
input_dict
[
'pcd_scale_factor'
]
=
scale_factor
input_dict
[
'pcd_rotation'
]
=
rotation_factor
input_dict
[
'pcd_rotation'
]
=
rotation_factor
...
@@ -290,10 +282,6 @@ class ObjectRangeFilter(object):
...
@@ -290,10 +282,6 @@ class ObjectRangeFilter(object):
self
.
pcd_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
self
.
pcd_range
=
np
.
array
(
point_cloud_range
,
dtype
=
np
.
float32
)
self
.
bev_range
=
self
.
pcd_range
[[
0
,
1
,
3
,
4
]]
self
.
bev_range
=
self
.
pcd_range
[[
0
,
1
,
3
,
4
]]
@
staticmethod
def
limit_period
(
val
,
offset
=
0.5
,
period
=
np
.
pi
):
return
val
-
np
.
floor
(
val
/
period
+
offset
)
*
period
@
staticmethod
@
staticmethod
def
filter_gt_box_outside_range
(
gt_bboxes_3d
,
limit_range
):
def
filter_gt_box_outside_range
(
gt_bboxes_3d
,
limit_range
):
"""remove gtbox outside training range.
"""remove gtbox outside training range.
...
@@ -314,14 +302,13 @@ class ObjectRangeFilter(object):
...
@@ -314,14 +302,13 @@ class ObjectRangeFilter(object):
def
__call__
(
self
,
input_dict
):
def
__call__
(
self
,
input_dict
):
gt_bboxes_3d
=
input_dict
[
'gt_bboxes_3d'
]
gt_bboxes_3d
=
input_dict
[
'gt_bboxes_3d'
]
gt_labels_3d
=
input_dict
[
'gt_labels_3d'
]
gt_labels_3d
=
input_dict
[
'gt_labels_3d'
]
mask
=
self
.
filter_
gt_box
_outside_range
(
gt_bboxes_3d
,
self
.
bev_range
)
mask
=
gt_
b
box
es_3d
.
in_range_bev
(
self
.
bev_range
)
gt_bboxes_3d
=
gt_bboxes_3d
[
mask
]
gt_bboxes_3d
=
gt_bboxes_3d
[
mask
]
gt_labels_3d
=
gt_labels_3d
[
mask
]
gt_labels_3d
=
gt_labels_3d
[
mask
]
# limit rad to [-pi, pi]
# limit rad to [-pi, pi]
gt_bboxes_3d
[:,
6
]
=
self
.
limit_period
(
gt_bboxes_3d
.
limit_yaw
(
offset
=
0.5
,
period
=
2
*
np
.
pi
)
gt_bboxes_3d
[:,
6
],
offset
=
0.5
,
period
=
2
*
np
.
pi
)
input_dict
[
'gt_bboxes_3d'
]
=
gt_bboxes_3d
input_dict
[
'gt_bboxes_3d'
]
=
gt_bboxes_3d
.
astype
(
'float32'
)
input_dict
[
'gt_labels_3d'
]
=
gt_labels_3d
input_dict
[
'gt_labels_3d'
]
=
gt_labels_3d
return
input_dict
return
input_dict
...
...
mmdet3d/models/dense_heads/train_mixins.py
View file @
679741d0
...
@@ -168,6 +168,8 @@ class AnchorTrainMixin(object):
...
@@ -168,6 +168,8 @@ class AnchorTrainMixin(object):
labels
=
anchors
.
new_zeros
(
num_valid_anchors
,
dtype
=
torch
.
long
)
labels
=
anchors
.
new_zeros
(
num_valid_anchors
,
dtype
=
torch
.
long
)
label_weights
=
anchors
.
new_zeros
(
num_valid_anchors
,
dtype
=
torch
.
float
)
label_weights
=
anchors
.
new_zeros
(
num_valid_anchors
,
dtype
=
torch
.
float
)
if
len
(
gt_bboxes
)
>
0
:
if
len
(
gt_bboxes
)
>
0
:
if
not
isinstance
(
gt_bboxes
,
torch
.
Tensor
):
gt_bboxes
=
gt_bboxes
.
tensor
.
to
(
anchors
.
device
)
assign_result
=
bbox_assigner
.
assign
(
anchors
,
gt_bboxes
,
assign_result
=
bbox_assigner
.
assign
(
anchors
,
gt_bboxes
,
gt_bboxes_ignore
,
gt_labels
)
gt_bboxes_ignore
,
gt_labels
)
sampling_result
=
self
.
bbox_sampler
.
sample
(
assign_result
,
anchors
,
sampling_result
=
self
.
bbox_sampler
.
sample
(
assign_result
,
anchors
,
...
...
tests/test_pipeline/test_outdoor_pipeline.py
0 → 100644
View file @
679741d0
import
numpy
as
np
import
torch
from
mmdet3d.core.bbox
import
LiDARInstance3DBoxes
from
mmdet3d.datasets.pipelines
import
Compose
def
test_outdoor_pipeline
():
point_cloud_range
=
[
0
,
-
40
,
-
3
,
70.4
,
40
,
1
]
class_names
=
[
'Car'
]
np
.
random
.
seed
(
0
)
train_pipeline
=
[
dict
(
type
=
'LoadPointsFromFile'
,
load_dim
=
4
,
use_dim
=
4
),
dict
(
type
=
'LoadAnnotations3D'
,
with_bbox_3d
=
True
,
with_label_3d
=
True
),
dict
(
type
=
'ObjectNoise'
,
num_try
=
100
,
loc_noise_std
=
[
1.0
,
1.0
,
0.5
],
global_rot_range
=
[
0.0
,
0.0
],
rot_uniform_noise
=
[
-
0.78539816
,
0.78539816
]),
dict
(
type
=
'RandomFlip3D'
,
flip_ratio
=
0.5
),
dict
(
type
=
'GlobalRotScale'
,
rot_uniform_noise
=
[
-
0.78539816
,
0.78539816
],
scaling_uniform_noise
=
[
0.95
,
1.05
]),
dict
(
type
=
'PointsRangeFilter'
,
point_cloud_range
=
point_cloud_range
),
dict
(
type
=
'ObjectRangeFilter'
,
point_cloud_range
=
point_cloud_range
),
dict
(
type
=
'PointShuffle'
),
dict
(
type
=
'DefaultFormatBundle3D'
,
class_names
=
class_names
),
dict
(
type
=
'Collect3D'
,
keys
=
[
'points'
,
'gt_bboxes_3d'
,
'gt_labels_3d'
])
]
pipeline
=
Compose
(
train_pipeline
)
gt_bboxes_3d
=
LiDARInstance3DBoxes
(
torch
.
tensor
([
[
2.16902428e+01
,
-
4.06038128e-02
,
-
1.61906636e+00
,
1.65999997e+00
,
3.20000005e+00
,
1.61000001e+00
,
-
1.53999996e+00
],
[
7.05006886e+00
,
-
6.57459593e+00
,
-
1.60107934e+00
,
2.27999997e+00
,
1.27799997e+01
,
3.66000009e+00
,
1.54999995e+00
],
[
2.24698811e+01
,
-
6.69203758e+00
,
-
1.50118136e+00
,
2.31999993e+00
,
1.47299995e+01
,
3.64000010e+00
,
1.59000003e+00
],
[
3.48291969e+01
,
-
7.09058380e+00
,
-
1.36622977e+00
,
2.31999993e+00
,
1.00400000e+01
,
3.60999990e+00
,
1.61000001e+00
],
[
4.62394600e+01
,
-
7.75838804e+00
,
-
1.32405007e+00
,
2.33999991e+00
,
1.28299999e+01
,
3.63000011e+00
,
1.63999999e+00
],
[
2.82966995e+01
,
-
5.55755794e-01
,
-
1.30332506e+00
,
1.47000003e+00
,
2.23000002e+00
,
1.48000002e+00
,
-
1.57000005e+00
],
[
2.66690197e+01
,
2.18230209e+01
,
-
1.73605704e+00
,
1.55999994e+00
,
3.48000002e+00
,
1.39999998e+00
,
-
1.69000006e+00
],
[
3.13197803e+01
,
8.16214371e+00
,
-
1.62177873e+00
,
1.74000001e+00
,
3.76999998e+00
,
1.48000002e+00
,
2.78999996e+00
],
[
4.34395561e+01
,
-
1.95209332e+01
,
-
1.20757008e+00
,
1.69000006e+00
,
4.09999990e+00
,
1.40999997e+00
,
-
1.53999996e+00
],
[
3.29882965e+01
,
-
3.79360509e+00
,
-
1.69245458e+00
,
1.74000001e+00
,
4.09000015e+00
,
1.49000001e+00
,
-
1.52999997e+00
],
[
3.85469360e+01
,
8.35060215e+00
,
-
1.31423414e+00
,
1.59000003e+00
,
4.28000021e+00
,
1.45000005e+00
,
1.73000002e+00
],
[
2.22492104e+01
,
-
1.13536005e+01
,
-
1.38272512e+00
,
1.62000000e+00
,
3.55999994e+00
,
1.71000004e+00
,
2.48000002e+00
],
[
3.36115799e+01
,
-
1.97708054e+01
,
-
4.92827654e-01
,
1.64999998e+00
,
3.54999995e+00
,
1.79999995e+00
,
-
1.57000005e+00
],
[
9.85029602e+00
,
-
1.51294518e+00
,
-
1.66834795e+00
,
1.59000003e+00
,
3.17000008e+00
,
1.38999999e+00
,
-
8.39999974e-01
]
],
dtype
=
torch
.
float32
))
gt_labels_3d
=
np
.
array
([
0
,
-
1
,
-
1
,
-
1
,
-
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
])
results
=
dict
(
pts_filename
=
'tests/data/kitti/a.bin'
,
ann_info
=
dict
(
gt_bboxes_3d
=
gt_bboxes_3d
,
gt_labels_3d
=
gt_labels_3d
),
bbox3d_fields
=
[],
)
output
=
pipeline
(
results
)
expected_tensor
=
torch
.
tensor
(
[[
20.6514
,
-
8.8250
,
-
1.0816
,
1.5893
,
3.0637
,
1.5414
,
-
1.9216
],
[
7.9374
,
4.9457
,
-
1.2008
,
2.1829
,
12.2357
,
3.5041
,
1.6629
],
[
20.8115
,
-
2.0273
,
-
1.8893
,
2.2212
,
14.1026
,
3.4850
,
2.6513
],
[
32.3850
,
-
5.2135
,
-
1.1321
,
2.2212
,
9.6124
,
3.4562
,
2.6498
],
[
43.7022
,
-
7.8316
,
-
0.5090
,
2.2403
,
12.2836
,
3.4754
,
2.0146
],
[
25.3300
,
-
9.6670
,
-
1.0855
,
1.4074
,
2.1350
,
1.4170
,
-
0.7141
],
[
16.5414
,
-
29.0583
,
-
0.9768
,
1.4936
,
3.3318
,
1.3404
,
-
0.7153
],
[
24.6548
,
-
18.9226
,
-
1.3567
,
1.6659
,
3.6094
,
1.4170
,
1.3970
],
[
45.8403
,
1.8183
,
-
1.1626
,
1.6180
,
3.9254
,
1.3499
,
-
0.6886
],
[
30.6288
,
-
8.4497
,
-
1.4881
,
1.6659
,
3.9158
,
1.4265
,
-
0.7241
],
[
32.3316
,
-
22.4611
,
-
1.3131
,
1.5223
,
4.0977
,
1.3882
,
2.4186
],
[
22.4492
,
3.2944
,
-
2.1674
,
1.5510
,
3.4084
,
1.6372
,
0.3928
],
[
37.3824
,
5.0472
,
-
0.6579
,
1.5797
,
3.3988
,
1.7233
,
-
1.4862
],
[
8.9259
,
-
1.2578
,
-
1.6081
,
1.5223
,
3.0350
,
1.3308
,
-
1.7212
]])
assert
torch
.
allclose
(
output
[
'gt_bboxes_3d'
].
_data
.
tensor
,
expected_tensor
,
atol
=
1e-3
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment