Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
4f1a5e52
Commit
4f1a5e52
authored
May 08, 2020
by
liyinhao
Browse files
Merge branch 'master_temp' into indoor_augment
parents
c2c0f3d8
f584b970
Changes
111
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
311 additions
and
200 deletions
+311
-200
mmdet3d/core/bbox/box_torch_ops.py
mmdet3d/core/bbox/box_torch_ops.py
+24
-4
mmdet3d/core/bbox/coders/__init__.py
mmdet3d/core/bbox/coders/__init__.py
+3
-2
mmdet3d/core/bbox/coders/delta_xywh_bbox_coder.py
mmdet3d/core/bbox/coders/delta_xywh_bbox_coder.py
+9
-54
mmdet3d/core/bbox/iou_calculators/iou3d_calculator.py
mmdet3d/core/bbox/iou_calculators/iou3d_calculator.py
+30
-20
mmdet3d/core/evaluation/kitti_utils/eval.py
mmdet3d/core/evaluation/kitti_utils/eval.py
+13
-15
mmdet3d/core/optimizer/cocktail_constructor.py
mmdet3d/core/optimizer/cocktail_constructor.py
+2
-2
mmdet3d/core/optimizer/cocktail_optimizer.py
mmdet3d/core/optimizer/cocktail_optimizer.py
+1
-1
mmdet3d/datasets/__init__.py
mmdet3d/datasets/__init__.py
+7
-3
mmdet3d/datasets/builder.py
mmdet3d/datasets/builder.py
+2
-1
mmdet3d/datasets/dataset_wrappers.py
mmdet3d/datasets/dataset_wrappers.py
+1
-1
mmdet3d/datasets/kitti2d_dataset.py
mmdet3d/datasets/kitti2d_dataset.py
+1
-1
mmdet3d/datasets/kitti_dataset.py
mmdet3d/datasets/kitti_dataset.py
+81
-35
mmdet3d/datasets/nuscenes2d_dataset.py
mmdet3d/datasets/nuscenes2d_dataset.py
+0
-38
mmdet3d/datasets/nuscenes_dataset.py
mmdet3d/datasets/nuscenes_dataset.py
+2
-2
mmdet3d/datasets/pipelines/__init__.py
mmdet3d/datasets/pipelines/__init__.py
+8
-1
mmdet3d/datasets/pipelines/dbsampler.py
mmdet3d/datasets/pipelines/dbsampler.py
+3
-3
mmdet3d/datasets/pipelines/formating.py
mmdet3d/datasets/pipelines/formating.py
+4
-4
mmdet3d/datasets/pipelines/indoor_sample.py
mmdet3d/datasets/pipelines/indoor_sample.py
+67
-0
mmdet3d/datasets/pipelines/loading.py
mmdet3d/datasets/pipelines/loading.py
+3
-3
mmdet3d/datasets/pipelines/train_aug.py
mmdet3d/datasets/pipelines/train_aug.py
+50
-10
No files found.
mmdet3d/core/bbox/box_torch_ops.py
View file @
4f1a5e52
...
@@ -21,9 +21,9 @@ def corners_nd(dims, origin=0.5):
...
@@ -21,9 +21,9 @@ def corners_nd(dims, origin=0.5):
where x0 < x1, y0 < y1, z0 < z1
where x0 < x1, y0 < y1, z0 < z1
"""
"""
ndim
=
int
(
dims
.
shape
[
1
])
ndim
=
int
(
dims
.
shape
[
1
])
corners_norm
=
np
.
stack
(
corners_norm
=
torch
.
from_numpy
(
np
.
unravel_index
(
np
.
arange
(
2
**
ndim
),
[
2
]
*
ndim
),
np
.
stack
(
np
.
unravel_index
(
np
.
arange
(
2
**
ndim
),
[
2
]
*
ndim
),
axis
=
1
)).
to
(
axis
=
1
).
as
type
(
dims
.
dtype
)
device
=
dims
.
device
,
d
type
=
dims
.
dtype
)
# now corners_norm has format: (2d) x0y0, x0y1, x1y0, x1y1
# now corners_norm has format: (2d) x0y0, x0y1, x1y0, x1y1
# (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1
# (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1
# so need to convert to a format which is convenient to do other computing.
# so need to convert to a format which is convenient to do other computing.
...
@@ -34,7 +34,7 @@ def corners_nd(dims, origin=0.5):
...
@@ -34,7 +34,7 @@ def corners_nd(dims, origin=0.5):
corners_norm
=
corners_norm
[[
0
,
1
,
3
,
2
]]
corners_norm
=
corners_norm
[[
0
,
1
,
3
,
2
]]
elif
ndim
==
3
:
elif
ndim
==
3
:
corners_norm
=
corners_norm
[[
0
,
1
,
3
,
2
,
4
,
5
,
7
,
6
]]
corners_norm
=
corners_norm
[[
0
,
1
,
3
,
2
,
4
,
5
,
7
,
6
]]
corners_norm
=
corners_norm
-
np
.
array
(
origin
,
dtype
=
dims
.
dtype
)
corners_norm
=
corners_norm
-
dims
.
new_tensor
(
origin
)
corners
=
dims
.
reshape
([
-
1
,
1
,
ndim
])
*
corners_norm
.
reshape
(
corners
=
dims
.
reshape
([
-
1
,
1
,
ndim
])
*
corners_norm
.
reshape
(
[
1
,
2
**
ndim
,
ndim
])
[
1
,
2
**
ndim
,
ndim
])
return
corners
return
corners
...
@@ -190,3 +190,23 @@ def rotation_2d(points, angles):
...
@@ -190,3 +190,23 @@ def rotation_2d(points, angles):
rot_cos
=
torch
.
cos
(
angles
)
rot_cos
=
torch
.
cos
(
angles
)
rot_mat_T
=
torch
.
stack
([[
rot_cos
,
-
rot_sin
],
[
rot_sin
,
rot_cos
]])
rot_mat_T
=
torch
.
stack
([[
rot_cos
,
-
rot_sin
],
[
rot_sin
,
rot_cos
]])
return
torch
.
einsum
(
'aij,jka->aik'
,
points
,
rot_mat_T
)
return
torch
.
einsum
(
'aij,jka->aik'
,
points
,
rot_mat_T
)
def
enlarge_box3d_lidar
(
boxes3d
,
extra_width
):
"""Enlarge the length, width and height of input boxes
Args:
boxes3d (torch.float32 or numpy.float32): bottom_center with
shape [N, 7], (x, y, z, w, l, h, ry) in LiDAR coords
extra_width (float): a fix number to add
Returns:
torch.float32 or numpy.float32: enlarged boxes
"""
if
isinstance
(
boxes3d
,
np
.
ndarray
):
large_boxes3d
=
boxes3d
.
copy
()
else
:
large_boxes3d
=
boxes3d
.
clone
()
large_boxes3d
[:,
3
:
6
]
+=
extra_width
*
2
large_boxes3d
[:,
2
]
-=
extra_width
# bottom center z minus extra_width
return
large_boxes3d
mmdet3d/core/bbox/coders/__init__.py
View file @
4f1a5e52
from
.box_coder
import
Residual3DBoxCoder
from
mmdet.core.bbox
import
build_bbox_coder
from
.delta_xywh_bbox_coder
import
DeltaXYZWLHRBBoxCoder
__all__
=
[
'
Residual3D
BoxCoder'
]
__all__
=
[
'
build_bbox_coder'
,
'DeltaXYZWLHRB
BoxCoder'
]
mmdet3d/core/bbox/coders/box_coder.py
→
mmdet3d/core/bbox/coders/
delta_xywh_b
box_coder.py
View file @
4f1a5e52
import
numpy
as
np
import
torch
import
torch
from
mmdet.core.bbox
import
BaseBBoxCoder
from
mmdet.core.bbox.builder
import
BBOX_CODERS
class
Residual3DBoxCoder
(
object
):
def
__init__
(
self
,
code_size
=
7
,
mean
=
None
,
std
=
None
):
@
BBOX_CODERS
.
register_module
()
super
().
__init__
()
class
DeltaXYZWLHRBBoxCoder
(
BaseBBoxCoder
):
self
.
code_size
=
code_size
self
.
mean
=
mean
self
.
std
=
std
@
staticmethod
def
encode_np
(
boxes
,
anchors
):
"""
:param boxes: (N, 7) x, y, z, w, l, h, r
:param anchors: (N, 7)
:return:
"""
# need to convert boxes to z-center format
xa
,
ya
,
za
,
wa
,
la
,
ha
,
ra
=
np
.
split
(
anchors
,
7
,
axis
=-
1
)
xg
,
yg
,
zg
,
wg
,
lg
,
hg
,
rg
=
np
.
split
(
boxes
,
7
,
axis
=-
1
)
zg
=
zg
+
hg
/
2
za
=
za
+
ha
/
2
diagonal
=
np
.
sqrt
(
la
**
2
+
wa
**
2
)
# 4.3
xt
=
(
xg
-
xa
)
/
diagonal
yt
=
(
yg
-
ya
)
/
diagonal
zt
=
(
zg
-
za
)
/
ha
# 1.6
lt
=
np
.
log
(
lg
/
la
)
wt
=
np
.
log
(
wg
/
wa
)
ht
=
np
.
log
(
hg
/
ha
)
rt
=
rg
-
ra
return
np
.
concatenate
([
xt
,
yt
,
zt
,
wt
,
lt
,
ht
,
rt
],
axis
=-
1
)
@
staticmethod
def
decode_np
(
box_encodings
,
anchors
):
"""
:param box_encodings: (N, 7) x, y, z, w, l, h, r
:param anchors: (N, 7)
:return:
"""
# need to convert box_encodings to z-bottom format
xa
,
ya
,
za
,
wa
,
la
,
ha
,
ra
=
np
.
split
(
anchors
,
7
,
axis
=-
1
)
xt
,
yt
,
zt
,
wt
,
lt
,
ht
,
rt
=
np
.
split
(
box_encodings
,
7
,
axis
=-
1
)
za
=
za
+
ha
/
2
def
__init__
(
self
,
code_size
=
7
):
diagonal
=
np
.
sqrt
(
la
**
2
+
wa
**
2
)
super
(
DeltaXYZWLHRBBoxCoder
,
self
).
__init__
()
xg
=
xt
*
diagonal
+
xa
self
.
code_size
=
code_size
yg
=
yt
*
diagonal
+
ya
zg
=
zt
*
ha
+
za
lg
=
np
.
exp
(
lt
)
*
la
wg
=
np
.
exp
(
wt
)
*
wa
hg
=
np
.
exp
(
ht
)
*
ha
rg
=
rt
+
ra
zg
=
zg
-
hg
/
2
return
np
.
concatenate
([
xg
,
yg
,
zg
,
wg
,
lg
,
hg
,
rg
],
axis
=-
1
)
@
staticmethod
@
staticmethod
def
encode
_torch
(
anchors
,
boxes
,
means
,
stds
):
def
encode
(
anchors
,
boxes
):
"""
"""
:param boxes: (N, 7+n) x, y, z, w, l, h, r, velo*
:param boxes: (N, 7+n) x, y, z, w, l, h, r, velo*
:param anchors: (N, 7+n)
:param anchors: (N, 7+n)
...
@@ -85,7 +40,7 @@ class Residual3DBoxCoder(object):
...
@@ -85,7 +40,7 @@ class Residual3DBoxCoder(object):
return
torch
.
cat
([
xt
,
yt
,
zt
,
wt
,
lt
,
ht
,
rt
,
*
cts
],
dim
=-
1
)
return
torch
.
cat
([
xt
,
yt
,
zt
,
wt
,
lt
,
ht
,
rt
,
*
cts
],
dim
=-
1
)
@
staticmethod
@
staticmethod
def
decode
_torch
(
anchors
,
box_encodings
,
means
,
stds
):
def
decode
(
anchors
,
box_encodings
):
"""
"""
:param box_encodings: (N, 7 + n) x, y, z, w, l, h, r
:param box_encodings: (N, 7 + n) x, y, z, w, l, h, r
:param anchors: (N, 7)
:param anchors: (N, 7)
...
...
mmdet3d/core/bbox/iou_calculators/iou3d_calculator.py
View file @
4f1a5e52
import
torch
from
mmdet3d.ops.iou3d
import
boxes_iou3d_gpu
from
mmdet3d.ops.iou3d
import
boxes_iou3d_gpu
from
mmdet.core.bbox
import
bbox_overlaps
from
mmdet.core.bbox
import
bbox_overlaps
from
mmdet.core.bbox.iou_calculators.
registry
import
IOU_CALCULATORS
from
mmdet.core.bbox.iou_calculators.
builder
import
IOU_CALCULATORS
from
..
import
box_torch_ops
from
..
import
box_torch_ops
@
IOU_CALCULATORS
.
register_module
@
IOU_CALCULATORS
.
register_module
()
class
BboxOverlapsNearest3D
(
object
):
class
BboxOverlapsNearest3D
(
object
):
"""Nearest 3D IoU Calculator"""
"""Nearest 3D IoU Calculator"""
...
@@ -18,7 +20,7 @@ class BboxOverlapsNearest3D(object):
...
@@ -18,7 +20,7 @@ class BboxOverlapsNearest3D(object):
return
repr_str
return
repr_str
@
IOU_CALCULATORS
.
register_module
@
IOU_CALCULATORS
.
register_module
()
class
BboxOverlaps3D
(
object
):
class
BboxOverlaps3D
(
object
):
"""3D IoU Calculator"""
"""3D IoU Calculator"""
...
@@ -33,18 +35,22 @@ class BboxOverlaps3D(object):
...
@@ -33,18 +35,22 @@ class BboxOverlaps3D(object):
def
bbox_overlaps_nearest_3d
(
bboxes1
,
bboxes2
,
mode
=
'iou'
,
is_aligned
=
False
):
def
bbox_overlaps_nearest_3d
(
bboxes1
,
bboxes2
,
mode
=
'iou'
,
is_aligned
=
False
):
'''
"""Calculate nearest 3D IoU
:param bboxes1: Tensor, shape (N, 7) [x, y, z, h, w, l, ry]?
:param bboxes2: Tensor, shape (M, 7) [x, y, z, h, w, l, ry]?
Args:
:param mode: mode (str): "iou" (intersection over union) or iof
bboxes1: Tensor, shape (N, 7+N) [x, y, z, h, w, l, ry, v]
bboxes2: Tensor, shape (M, 7+N) [x, y, z, h, w, l, ry, v]
mode: mode (str): "iou" (intersection over union) or iof
(intersection over foreground).
(intersection over foreground).
:return: iou: (M, N) not support aligned mode currently
rbboxes: [N, 5(x, y, xdim, ydim, rad)] rotated bboxes
Return:
'''
iou: (M, N) not support aligned mode currently
rbboxes1_bev
=
bboxes1
.
index_select
(
"""
dim
=-
1
,
index
=
bboxes1
.
new_tensor
([
0
,
1
,
3
,
4
,
6
]).
long
())
assert
bboxes1
.
size
(
-
1
)
>=
7
rbboxes2_bev
=
bboxes2
.
index_select
(
assert
bboxes2
.
size
(
-
1
)
>=
7
dim
=-
1
,
index
=
bboxes1
.
new_tensor
([
0
,
1
,
3
,
4
,
6
]).
long
())
column_index1
=
bboxes1
.
new_tensor
([
0
,
1
,
3
,
4
,
6
],
dtype
=
torch
.
long
)
rbboxes1_bev
=
bboxes1
.
index_select
(
dim
=-
1
,
index
=
column_index1
)
rbboxes2_bev
=
bboxes2
.
index_select
(
dim
=-
1
,
index
=
column_index1
)
# Change the bboxes to bev
# Change the bboxes to bev
# box conversion and iou calculation in torch version on CUDA
# box conversion and iou calculation in torch version on CUDA
...
@@ -57,14 +63,18 @@ def bbox_overlaps_nearest_3d(bboxes1, bboxes2, mode='iou', is_aligned=False):
...
@@ -57,14 +63,18 @@ def bbox_overlaps_nearest_3d(bboxes1, bboxes2, mode='iou', is_aligned=False):
def
bbox_overlaps_3d
(
bboxes1
,
bboxes2
,
mode
=
'iou'
):
def
bbox_overlaps_3d
(
bboxes1
,
bboxes2
,
mode
=
'iou'
):
'''
"""Calculate 3D IoU using cuda implementation
:param bboxes1: Tensor, shape (N, 7) [x, y, z, h, w, l, ry]
Args:
:param bboxes2: Tensor, shape (M, 7) [x, y, z, h, w, l, ry]
bboxes1: Tensor, shape (N, 7) [x, y, z, h, w, l, ry]
:param mode: mode (str): "iou" (intersection over union) or
bboxes2: Tensor, shape (M, 7) [x, y, z, h, w, l, ry]
mode: mode (str): "iou" (intersection over union) or
iof (intersection over foreground).
iof (intersection over foreground).
:return: iou: (M, N) not support aligned mode currently
'''
Return:
iou: (M, N) not support aligned mode currently
"""
# TODO: check the input dimension meanings,
# TODO: check the input dimension meanings,
# this is inconsistent with that in bbox_overlaps_nearest_3d
# this is inconsistent with that in bbox_overlaps_nearest_3d
assert
bboxes1
.
size
(
-
1
)
==
bboxes2
.
size
(
-
1
)
==
7
return
boxes_iou3d_gpu
(
bboxes1
,
bboxes2
,
mode
)
return
boxes_iou3d_gpu
(
bboxes1
,
bboxes2
,
mode
)
mmdet3d/core/evaluation/kitti_utils/eval.py
View file @
4f1a5e52
...
@@ -681,7 +681,6 @@ def kitti_eval(gt_annos,
...
@@ -681,7 +681,6 @@ def kitti_eval(gt_annos,
# mAP threshold array: [num_minoverlap, metric, class]
# mAP threshold array: [num_minoverlap, metric, class]
# mAP result: [num_class, num_diff, num_minoverlap]
# mAP result: [num_class, num_diff, num_minoverlap]
curcls_name
=
class_to_name
[
curcls
]
curcls_name
=
class_to_name
[
curcls
]
ret_dict
[
curcls_name
]
=
{}
for
i
in
range
(
min_overlaps
.
shape
[
0
]):
for
i
in
range
(
min_overlaps
.
shape
[
0
]):
# prepare results for print
# prepare results for print
result
+=
(
'{} AP@{:.2f}, {:.2f}, {:.2f}:
\n
'
.
format
(
result
+=
(
'{} AP@{:.2f}, {:.2f}, {:.2f}:
\n
'
.
format
(
...
@@ -702,18 +701,17 @@ def kitti_eval(gt_annos,
...
@@ -702,18 +701,17 @@ def kitti_eval(gt_annos,
# prepare results for logger
# prepare results for logger
for
idx
in
range
(
3
):
for
idx
in
range
(
3
):
postfix
=
'{}_{}'
.
format
(
difficulty
[
idx
],
min_overlaps
[
i
,
idx
,
if
i
==
0
:
j
])
postfix
=
f
'
{
difficulty
[
idx
]
}
_strict'
else
:
postfix
=
f
'
{
difficulty
[
idx
]
}
_loose'
prefix
=
f
'KITTI/
{
curcls_name
}
'
if
mAP3d
is
not
None
:
if
mAP3d
is
not
None
:
ret_dict
[
curcls_name
][
'3D_{}'
.
format
(
postfix
)]
=
mAP3d
[
j
,
ret_dict
[
f
'
{
prefix
}
_3D_
{
postfix
}
'
]
=
mAP3d
[
j
,
idx
,
i
]
idx
,
i
]
if
mAPbev
is
not
None
:
if
mAPbev
is
not
None
:
ret_dict
[
curcls_name
][
'BEV_{}'
.
format
(
postfix
)]
=
mAPbev
[
ret_dict
[
f
'
{
prefix
}
_BEV_
{
postfix
}
'
]
=
mAPbev
[
j
,
idx
,
i
]
j
,
idx
,
i
]
if
mAPbbox
is
not
None
:
if
mAPbbox
is
not
None
:
ret_dict
[
curcls_name
][
'2D_{}'
.
format
(
postfix
)]
=
mAPbbox
[
ret_dict
[
f
'
{
prefix
}
_2D_
{
postfix
}
'
]
=
mAPbbox
[
j
,
idx
,
i
]
j
,
idx
,
i
]
# calculate mAP over all classes if there are multiple classes
# calculate mAP over all classes if there are multiple classes
if
len
(
current_classes
)
>
1
:
if
len
(
current_classes
)
>
1
:
...
@@ -735,14 +733,14 @@ def kitti_eval(gt_annos,
...
@@ -735,14 +733,14 @@ def kitti_eval(gt_annos,
# prepare results for logger
# prepare results for logger
ret_dict
[
'Overall'
]
=
dict
()
ret_dict
[
'Overall'
]
=
dict
()
for
idx
in
range
(
3
):
for
idx
in
range
(
3
):
postfix
=
'{
}'
.
format
(
difficulty
[
idx
]
)
postfix
=
f
'
{
difficulty
[
idx
]
}
'
if
mAP3d
is
not
None
:
if
mAP3d
is
not
None
:
ret_dict
[
'Overall'
][
'3D_{}'
.
format
(
postfix
)
]
=
mAP3d
[
idx
,
0
]
ret_dict
[
f
'KITTI/Overall_3D_
{
postfix
}
'
]
=
mAP3d
[
idx
,
0
]
if
mAPbev
is
not
None
:
if
mAPbev
is
not
None
:
ret_dict
[
'
Overall
'
][
'
BEV_{
}'
.
format
(
postfix
)
]
=
mAPbev
[
idx
,
0
]
ret_dict
[
f
'KITTI/
Overall
_
BEV_
{
postfix
}
'
]
=
mAPbev
[
idx
,
0
]
if
mAPbbox
is
not
None
:
if
mAPbbox
is
not
None
:
ret_dict
[
'Overall'
][
'2D_{}'
.
format
(
postfix
)
]
=
mAPbbox
[
idx
,
0
]
ret_dict
[
f
'KITTI/Overall_2D_
{
postfix
}
'
]
=
mAPbbox
[
idx
,
0
]
print
(
result
)
return
result
,
ret_dict
return
result
,
ret_dict
...
...
mmdet3d/core/optimizer/cocktail_constructor.py
View file @
4f1a5e52
from
mmcv.utils
import
build_from_cfg
from
mmcv.utils
import
build_from_cfg
from
mmdet3d.utils
import
get_root_logger
from
mmdet.core.optimizer
import
OPTIMIZER_BUILDERS
,
OPTIMIZERS
from
mmdet.core.optimizer
import
OPTIMIZER_BUILDERS
,
OPTIMIZERS
from
mmdet.utils
import
get_root_logger
from
.cocktail_optimizer
import
CocktailOptimizer
from
.cocktail_optimizer
import
CocktailOptimizer
@
OPTIMIZER_BUILDERS
.
register_module
@
OPTIMIZER_BUILDERS
.
register_module
()
class
CocktailOptimizerConstructor
(
object
):
class
CocktailOptimizerConstructor
(
object
):
"""Special constructor for cocktail optimizers.
"""Special constructor for cocktail optimizers.
...
...
mmdet3d/core/optimizer/cocktail_optimizer.py
View file @
4f1a5e52
...
@@ -3,7 +3,7 @@ from torch.optim import Optimizer
...
@@ -3,7 +3,7 @@ from torch.optim import Optimizer
from
mmdet.core.optimizer
import
OPTIMIZERS
from
mmdet.core.optimizer
import
OPTIMIZERS
@
OPTIMIZERS
.
register_module
@
OPTIMIZERS
.
register_module
()
class
CocktailOptimizer
(
Optimizer
):
class
CocktailOptimizer
(
Optimizer
):
"""Cocktail Optimizer that contains multiple optimizers
"""Cocktail Optimizer that contains multiple optimizers
...
...
mmdet3d/datasets/__init__.py
View file @
4f1a5e52
from
mmdet.datasets.
registry
import
DATASETS
from
mmdet.datasets.
builder
import
DATASETS
from
.builder
import
build_dataset
from
.builder
import
build_dataset
from
.dataset_wrappers
import
RepeatFactorDataset
from
.dataset_wrappers
import
RepeatFactorDataset
from
.kitti2d_dataset
import
Kitti2DDataset
from
.kitti2d_dataset
import
Kitti2DDataset
from
.kitti_dataset
import
KittiDataset
from
.kitti_dataset
import
KittiDataset
from
.loader
import
DistributedGroupSampler
,
GroupSampler
,
build_dataloader
from
.loader
import
DistributedGroupSampler
,
GroupSampler
,
build_dataloader
from
.nuscenes2d_dataset
import
NuScenes2DDataset
from
.nuscenes_dataset
import
NuScenesDataset
from
.nuscenes_dataset
import
NuScenesDataset
from
.pipelines
import
(
GlobalRotScale
,
ObjectNoise
,
ObjectRangeFilter
,
ObjectSample
,
PointShuffle
,
PointsRangeFilter
,
RandomFlip3D
)
__all__
=
[
__all__
=
[
'KittiDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'KittiDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'build_dataloader'
,
'RepeatFactorDataset'
,
'DATASETS'
,
'build_dataset'
,
'build_dataloader'
,
'RepeatFactorDataset'
,
'DATASETS'
,
'build_dataset'
,
'CocoDataset'
,
'Kitti2DDataset'
,
'NuScenesDataset'
,
'NuScenes2DDataset'
'CocoDataset'
,
'Kitti2DDataset'
,
'NuScenesDataset'
,
'ObjectSample'
,
'RandomFlip3D'
,
'ObjectNoise'
,
'GlobalRotScale'
,
'PointShuffle'
,
'ObjectRangeFilter'
,
'PointsRangeFilter'
,
'Collect3D'
]
]
mmdet3d/datasets/builder.py
View file @
4f1a5e52
import
copy
import
copy
from
mmcv.utils
import
build_from_cfg
from
mmdet.datasets
import
DATASETS
,
ConcatDataset
,
RepeatDataset
from
mmdet.datasets
import
DATASETS
,
ConcatDataset
,
RepeatDataset
from
mmdet.utils
import
build_from_cfg
from
.dataset_wrappers
import
RepeatFactorDataset
from
.dataset_wrappers
import
RepeatFactorDataset
...
...
mmdet3d/datasets/dataset_wrappers.py
View file @
4f1a5e52
...
@@ -7,7 +7,7 @@ from mmdet.datasets import DATASETS
...
@@ -7,7 +7,7 @@ from mmdet.datasets import DATASETS
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
@
DATASETS
.
register_module
@
DATASETS
.
register_module
()
class
RepeatFactorDataset
(
object
):
class
RepeatFactorDataset
(
object
):
"""A wrapper of repeated dataset with repeat factor.
"""A wrapper of repeated dataset with repeat factor.
...
...
mmdet3d/datasets/kitti2d_dataset.py
View file @
4f1a5e52
...
@@ -4,7 +4,7 @@ import numpy as np
...
@@ -4,7 +4,7 @@ import numpy as np
from
mmdet.datasets
import
DATASETS
,
CustomDataset
from
mmdet.datasets
import
DATASETS
,
CustomDataset
@
DATASETS
.
register_module
@
DATASETS
.
register_module
()
class
Kitti2DDataset
(
CustomDataset
):
class
Kitti2DDataset
(
CustomDataset
):
CLASSES
=
(
'car'
,
'pedestrian'
,
'cyclist'
)
CLASSES
=
(
'car'
,
'pedestrian'
,
'cyclist'
)
...
...
mmdet3d/datasets/kitti_dataset.py
View file @
4f1a5e52
import
copy
import
copy
import
os
import
os
import
pickle
import
os.path
as
osp
import
tempfile
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.utils.data
as
torch_data
import
torch.utils.data
as
torch_data
from
mmcv.utils
import
print_log
from
mmdet.datasets
import
DATASETS
from
mmdet.datasets
import
DATASETS
from
mmdet.datasets.pipelines
import
Compose
from
..core.bbox
import
box_np_ops
from
..core.bbox
import
box_np_ops
from
.pipelines
import
Compose
from
.utils
import
remove_dontcare
from
.utils
import
remove_dontcare
@
DATASETS
.
register_module
@
DATASETS
.
register_module
()
class
KittiDataset
(
torch_data
.
Dataset
):
class
KittiDataset
(
torch_data
.
Dataset
):
CLASSES
=
(
'car'
,
'pedestrian'
,
'cyclist'
)
CLASSES
=
(
'car'
,
'pedestrian'
,
'cyclist'
)
...
@@ -43,8 +45,7 @@ class KittiDataset(torch_data.Dataset):
...
@@ -43,8 +45,7 @@ class KittiDataset(torch_data.Dataset):
self
.
pcd_limit_range
=
[
0
,
-
40
,
-
3
,
70.4
,
40
,
0.0
]
self
.
pcd_limit_range
=
[
0
,
-
40
,
-
3
,
70.4
,
40
,
0.0
]
self
.
ann_file
=
ann_file
self
.
ann_file
=
ann_file
with
open
(
ann_file
,
'rb'
)
as
f
:
self
.
kitti_infos
=
mmcv
.
load
(
ann_file
)
self
.
kitti_infos
=
pickle
.
load
(
f
)
# set group flag for the sampler
# set group flag for the sampler
if
not
self
.
test_mode
:
if
not
self
.
test_mode
:
...
@@ -262,37 +263,76 @@ class KittiDataset(torch_data.Dataset):
...
@@ -262,37 +263,76 @@ class KittiDataset(torch_data.Dataset):
inds
=
np
.
array
(
inds
,
dtype
=
np
.
int64
)
inds
=
np
.
array
(
inds
,
dtype
=
np
.
int64
)
return
inds
return
inds
def
reformat_bbox
(
self
,
outputs
,
out
=
None
):
def
format_results
(
self
,
outputs
,
pklfile_prefix
=
None
,
submission_prefix
=
None
):
if
pklfile_prefix
is
None
:
tmp_dir
=
tempfile
.
TemporaryDirectory
()
pklfile_prefix
=
osp
.
join
(
tmp_dir
.
name
,
'results'
)
else
:
tmp_dir
=
None
if
not
isinstance
(
outputs
[
0
][
0
],
dict
):
if
not
isinstance
(
outputs
[
0
][
0
],
dict
):
sample_idx
=
[
sample_idx
=
[
info
[
'image'
][
'image_idx'
]
for
info
in
self
.
kitti_infos
info
[
'image'
][
'image_idx'
]
for
info
in
self
.
kitti_infos
]
]
result_files
=
self
.
bbox2result_kitti2d
(
outputs
,
self
.
class_names
,
result_files
=
self
.
bbox2result_kitti2d
(
outputs
,
self
.
class_names
,
sample_idx
,
out
)
sample_idx
,
pklfile_prefix
,
submission_prefix
)
else
:
else
:
result_files
=
self
.
bbox2result_kitti
(
outputs
,
self
.
class_names
,
result_files
=
self
.
bbox2result_kitti
(
outputs
,
self
.
class_names
,
out
)
pklfile_prefix
,
return
result_files
submission_prefix
)
return
result_files
,
tmp_dir
def
evaluate
(
self
,
results
,
metric
=
None
,
logger
=
None
,
pklfile_prefix
=
None
,
submission_prefix
=
None
,
result_names
=
[
'pts_bbox'
]):
"""Evaluation in KITTI protocol.
def
evaluate
(
self
,
result_files
,
eval_types
=
None
):
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
pklfile_prefix (str | None): The prefix of pkl files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
submission_prefix (str | None): The prefix of submission datas.
If not specified, the submission data will not be generated.
Returns:
dict[str: float]
"""
result_files
,
tmp_dir
=
self
.
format_results
(
results
,
pklfile_prefix
)
from
mmdet3d.core.evaluation
import
kitti_eval
from
mmdet3d.core.evaluation
import
kitti_eval
gt_annos
=
[
info
[
'annos'
]
for
info
in
self
.
kitti_infos
]
gt_annos
=
[
info
[
'annos'
]
for
info
in
self
.
kitti_infos
]
if
eval_types
==
'img_bbox'
:
if
metric
==
'img_bbox'
:
ap_result_str
,
ap_dict
=
kitti_eval
(
ap_result_str
,
ap_dict
=
kitti_eval
(
gt_annos
,
result_files
,
self
.
class_names
,
eval_types
=
[
'bbox'
])
gt_annos
,
result_files
,
self
.
class_names
,
eval_types
=
[
'bbox'
])
else
:
else
:
ap_result_str
,
ap_dict
=
kitti_eval
(
gt_annos
,
result_files
,
ap_result_str
,
ap_dict
=
kitti_eval
(
gt_annos
,
result_files
,
self
.
class_names
)
self
.
class_names
)
return
ap_result_str
,
ap_dict
print_log
(
'
\n
'
+
ap_result_str
,
logger
=
logger
)
if
tmp_dir
is
not
None
:
def
bbox2result_kitti
(
self
,
net_outputs
,
class_names
,
out
=
None
):
tmp_dir
.
cleanup
()
if
out
:
return
ap_dict
output_dir
=
out
[:
-
4
]
if
out
.
endswith
((
'.pkl'
,
'.pickle'
))
else
out
result_dir
=
output_dir
+
'/data'
def
bbox2result_kitti
(
self
,
mmcv
.
mkdir_or_exist
(
result_dir
)
net_outputs
,
class_names
,
pklfile_prefix
=
None
,
submission_prefix
=
None
):
if
submission_prefix
is
not
None
:
mmcv
.
mkdir_or_exist
(
submission_prefix
)
det_annos
=
[]
det_annos
=
[]
print
(
'Converting prediction to KITTI format'
)
print
(
'
\n
Converting prediction to KITTI format'
)
for
idx
,
pred_dicts
in
enumerate
(
for
idx
,
pred_dicts
in
enumerate
(
mmcv
.
track_iter_progress
(
net_outputs
)):
mmcv
.
track_iter_progress
(
net_outputs
)):
annos
=
[]
annos
=
[]
...
@@ -346,9 +386,9 @@ class KittiDataset(torch_data.Dataset):
...
@@ -346,9 +386,9 @@ class KittiDataset(torch_data.Dataset):
anno
=
{
k
:
np
.
stack
(
v
)
for
k
,
v
in
anno
.
items
()}
anno
=
{
k
:
np
.
stack
(
v
)
for
k
,
v
in
anno
.
items
()}
annos
.
append
(
anno
)
annos
.
append
(
anno
)
if
out
:
if
submission_prefix
is
not
None
:
cur
_det
_file
=
result_dir
+
'/%06d.txt'
%
sample_idx
cur
r
_file
=
f
'
{
submission_prefix
}
/
{
sample_idx
:
06
d
}
.txt'
with
open
(
cur
_det
_file
,
'w'
)
as
f
:
with
open
(
cur
r
_file
,
'w'
)
as
f
:
bbox
=
anno
[
'bbox'
]
bbox
=
anno
[
'bbox'
]
loc
=
anno
[
'location'
]
loc
=
anno
[
'location'
]
dims
=
anno
[
'dimensions'
]
# lhw -> hwl
dims
=
anno
[
'dimensions'
]
# lhw -> hwl
...
@@ -386,9 +426,9 @@ class KittiDataset(torch_data.Dataset):
...
@@ -386,9 +426,9 @@ class KittiDataset(torch_data.Dataset):
det_annos
+=
annos
det_annos
+=
annos
if
out
:
if
pklfile_prefix
is
not
None
:
if
not
out
.
endswith
((
'.pkl'
,
'.pickle'
)):
if
not
pklfile_prefix
.
endswith
((
'.pkl'
,
'.pickle'
)):
out
=
'{
}.pkl'
.
format
(
out
)
out
=
f
'
{
pklfile_prefix
}
.pkl'
mmcv
.
dump
(
det_annos
,
out
)
mmcv
.
dump
(
det_annos
,
out
)
print
(
'Result is saved to %s'
%
out
)
print
(
'Result is saved to %s'
%
out
)
...
@@ -398,7 +438,8 @@ class KittiDataset(torch_data.Dataset):
...
@@ -398,7 +438,8 @@ class KittiDataset(torch_data.Dataset):
net_outputs
,
net_outputs
,
class_names
,
class_names
,
sample_ids
,
sample_ids
,
out
=
None
):
pklfile_prefix
=
None
,
submission_prefix
=
None
):
"""Convert results to kitti format for evaluation and test submission
"""Convert results to kitti format for evaluation and test submission
Args:
Args:
...
@@ -406,6 +447,8 @@ class KittiDataset(torch_data.Dataset):
...
@@ -406,6 +447,8 @@ class KittiDataset(torch_data.Dataset):
class_nanes (List[String]): A list of class names
class_nanes (List[String]): A list of class names
sample_idx (List[Int]): A list of samples' index,
sample_idx (List[Int]): A list of samples' index,
should have the same length as net_outputs.
should have the same length as net_outputs.
pklfile_prefix (str | None): The prefix of pkl file.
submission_prefix (str | None): The prefix of submission file.
Return:
Return:
List([dict]): A list of dict have the kitti format
List([dict]): A list of dict have the kitti format
...
@@ -469,17 +512,20 @@ class KittiDataset(torch_data.Dataset):
...
@@ -469,17 +512,20 @@ class KittiDataset(torch_data.Dataset):
[
sample_idx
]
*
num_example
,
dtype
=
np
.
int64
)
[
sample_idx
]
*
num_example
,
dtype
=
np
.
int64
)
det_annos
+=
annos
det_annos
+=
annos
if
out
:
if
pklfile_prefix
is
not
None
:
# save file in pkl format
pklfile_path
=
(
pklfile_prefix
[:
-
4
]
if
pklfile_prefix
.
endswith
(
(
'.pkl'
,
'.pickle'
))
else
pklfile_prefix
)
mmcv
.
dump
(
det_annos
,
pklfile_path
)
if
submission_prefix
is
not
None
:
# save file in submission format
# save file in submission format
output_dir
=
out
[:
-
4
]
if
out
.
endswith
((
'.pkl'
,
'.pickle'
))
else
out
mmcv
.
mkdir_or_exist
(
submission_prefix
)
result_dir
=
output_dir
+
'/data'
print
(
f
'Saving KITTI submission to
{
submission_prefix
}
'
)
mmcv
.
mkdir_or_exist
(
result_dir
)
out
=
'{}.pkl'
.
format
(
result_dir
)
mmcv
.
dump
(
det_annos
,
out
)
print
(
'Result is saved to {}'
.
format
(
out
))
for
i
,
anno
in
enumerate
(
det_annos
):
for
i
,
anno
in
enumerate
(
det_annos
):
sample_idx
=
sample_ids
[
i
]
sample_idx
=
sample_ids
[
i
]
cur_det_file
=
result_dir
+
'/%06d.txt'
%
sample_idx
cur_det_file
=
f
'
{
submission_prefix
}
/
{
sample_idx
:
06
d
}
.txt'
with
open
(
cur_det_file
,
'w'
)
as
f
:
with
open
(
cur_det_file
,
'w'
)
as
f
:
bbox
=
anno
[
'bbox'
]
bbox
=
anno
[
'bbox'
]
loc
=
anno
[
'location'
]
loc
=
anno
[
'location'
]
...
@@ -497,7 +543,7 @@ class KittiDataset(torch_data.Dataset):
...
@@ -497,7 +543,7 @@ class KittiDataset(torch_data.Dataset):
anno
[
'score'
][
idx
]),
anno
[
'score'
][
idx
]),
file
=
f
,
file
=
f
,
)
)
print
(
'Result is saved to {}'
.
format
(
result_dir
))
print
(
'Result is saved to {}'
.
format
(
submission_prefix
))
return
det_annos
return
det_annos
...
...
mmdet3d/datasets/nuscenes2d_dataset.py
deleted
100644 → 0
View file @
c2c0f3d8
from
pycocotools.coco
import
COCO
from
mmdet3d.core.evaluation.coco_utils
import
getImgIds
from
mmdet.datasets
import
DATASETS
,
CocoDataset
@
DATASETS
.
register_module
class
NuScenes2DDataset
(
CocoDataset
):
CLASSES
=
(
'car'
,
'truck'
,
'trailer'
,
'bus'
,
'construction_vehicle'
,
'bicycle'
,
'motorcycle'
,
'pedestrian'
,
'traffic_cone'
,
'barrier'
)
def
load_annotations
(
self
,
ann_file
):
if
not
self
.
class_names
:
self
.
class_names
=
self
.
CLASSES
self
.
coco
=
COCO
(
ann_file
)
# send class_names into the get id
# in case we only need to train on several classes
# by default self.class_names = CLASSES
self
.
cat_ids
=
self
.
coco
.
getCatIds
(
catNms
=
self
.
class_names
)
self
.
cat2label
=
{
cat_id
:
i
# + 1 rm +1 here thus the 0-79 are fg, 80 is bg
for
i
,
cat_id
in
enumerate
(
self
.
cat_ids
)
}
# send cat ids to the get img id
# in case we only need to train on several classes
if
len
(
self
.
cat_ids
)
<
len
(
self
.
CLASSES
):
self
.
img_ids
=
getImgIds
(
self
.
coco
,
catIds
=
self
.
cat_ids
)
else
:
self
.
img_ids
=
self
.
coco
.
getImgIds
()
img_infos
=
[]
for
i
in
self
.
img_ids
:
info
=
self
.
coco
.
loadImgs
([
i
])[
0
]
info
[
'filename'
]
=
info
[
'file_name'
]
img_infos
.
append
(
info
)
return
img_infos
mmdet3d/datasets/nuscenes_dataset.py
View file @
4f1a5e52
...
@@ -9,11 +9,11 @@ import torch.utils.data as torch_data
...
@@ -9,11 +9,11 @@ import torch.utils.data as torch_data
from
nuscenes.utils.data_classes
import
Box
as
NuScenesBox
from
nuscenes.utils.data_classes
import
Box
as
NuScenesBox
from
mmdet.datasets
import
DATASETS
from
mmdet.datasets
import
DATASETS
from
mmdet.datasets.pipelines
import
Compose
from
..core.bbox
import
box_np_ops
from
..core.bbox
import
box_np_ops
from
.pipelines
import
Compose
@
DATASETS
.
register_module
@
DATASETS
.
register_module
()
class
NuScenesDataset
(
torch_data
.
Dataset
):
class
NuScenesDataset
(
torch_data
.
Dataset
):
NumPointFeatures
=
4
# xyz, timestamp. set 4 to use kitti pretrain
NumPointFeatures
=
4
# xyz, timestamp. set 4 to use kitti pretrain
NameMapping
=
{
NameMapping
=
{
...
...
mmdet3d/datasets/pipelines/__init__.py
View file @
4f1a5e52
from
mmdet.datasets.pipelines
import
Compose
from
.dbsampler
import
DataBaseSampler
,
MMDataBaseSampler
from
.formating
import
DefaultFormatBundle
,
DefaultFormatBundle3D
from
.loading
import
LoadMultiViewImageFromFiles
,
LoadPointsFromFile
from
.train_aug
import
(
GlobalRotScale
,
ObjectNoise
,
ObjectRangeFilter
,
from
.train_aug
import
(
GlobalRotScale
,
ObjectNoise
,
ObjectRangeFilter
,
ObjectSample
,
PointShuffle
,
PointsRangeFilter
,
ObjectSample
,
PointShuffle
,
PointsRangeFilter
,
RandomFlip3D
)
RandomFlip3D
)
__all__
=
[
__all__
=
[
'ObjectSample'
,
'RandomFlip3D'
,
'ObjectNoise'
,
'GlobalRotScale'
,
'ObjectSample'
,
'RandomFlip3D'
,
'ObjectNoise'
,
'GlobalRotScale'
,
'PointShuffle'
,
'ObjectRangeFilter'
,
'PointsRangeFilter'
,
'Collect3D'
'PointShuffle'
,
'ObjectRangeFilter'
,
'PointsRangeFilter'
,
'Collect3D'
,
'Compose'
,
'LoadMultiViewImageFromFiles'
,
'LoadPointsFromFile'
,
'DefaultFormatBundle'
,
'DefaultFormatBundle3D'
,
'DataBaseSampler'
,
'MMDataBaseSampler'
]
]
mmdet3d/datasets/pipelines/dbsampler.py
View file @
4f1a5e52
...
@@ -52,7 +52,7 @@ class BatchSampler:
...
@@ -52,7 +52,7 @@ class BatchSampler:
return
[
self
.
_sampled_list
[
i
]
for
i
in
indices
]
return
[
self
.
_sampled_list
[
i
]
for
i
in
indices
]
@
OBJECTSAMPLERS
.
register_module
@
OBJECTSAMPLERS
.
register_module
()
class
DataBaseSampler
(
object
):
class
DataBaseSampler
(
object
):
def
__init__
(
self
,
info_path
,
root_path
,
rate
,
prepare
,
object_rot_range
,
def
__init__
(
self
,
info_path
,
root_path
,
rate
,
prepare
,
object_rot_range
,
...
@@ -68,7 +68,7 @@ class DataBaseSampler(object):
...
@@ -68,7 +68,7 @@ class DataBaseSampler(object):
db_infos
=
pickle
.
load
(
f
)
db_infos
=
pickle
.
load
(
f
)
# filter database infos
# filter database infos
from
mmdet3d.
api
s
import
get_root_logger
from
mmdet3d.
util
s
import
get_root_logger
logger
=
get_root_logger
()
logger
=
get_root_logger
()
for
k
,
v
in
db_infos
.
items
():
for
k
,
v
in
db_infos
.
items
():
logger
.
info
(
f
'load
{
len
(
v
)
}
{
k
}
database infos'
)
logger
.
info
(
f
'load
{
len
(
v
)
}
{
k
}
database infos'
)
...
@@ -255,7 +255,7 @@ class DataBaseSampler(object):
...
@@ -255,7 +255,7 @@ class DataBaseSampler(object):
return
valid_samples
return
valid_samples
@
OBJECTSAMPLERS
.
register_module
@
OBJECTSAMPLERS
.
register_module
()
class
MMDataBaseSampler
(
DataBaseSampler
):
class
MMDataBaseSampler
(
DataBaseSampler
):
def
__init__
(
self
,
def
__init__
(
self
,
...
...
mmdet3d/datasets/pipelines/formating.py
View file @
4f1a5e52
import
numpy
as
np
import
numpy
as
np
from
mmcv.parallel
import
DataContainer
as
DC
from
mmcv.parallel
import
DataContainer
as
DC
from
mmdet.datasets.builder
import
PIPELINES
from
mmdet.datasets.pipelines
import
to_tensor
from
mmdet.datasets.pipelines
import
to_tensor
from
mmdet.datasets.registry
import
PIPELINES
PIPELINES
.
_module_dict
.
pop
(
'DefaultFormatBundle'
)
PIPELINES
.
_module_dict
.
pop
(
'DefaultFormatBundle'
)
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
DefaultFormatBundle
(
object
):
class
DefaultFormatBundle
(
object
):
"""Default formatting bundle.
"""Default formatting bundle.
...
@@ -59,7 +59,7 @@ class DefaultFormatBundle(object):
...
@@ -59,7 +59,7 @@ class DefaultFormatBundle(object):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
Collect3D
(
object
):
class
Collect3D
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -90,7 +90,7 @@ class Collect3D(object):
...
@@ -90,7 +90,7 @@ class Collect3D(object):
self
.
keys
,
self
.
meta_keys
)
self
.
keys
,
self
.
meta_keys
)
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
DefaultFormatBundle3D
(
DefaultFormatBundle
):
class
DefaultFormatBundle3D
(
DefaultFormatBundle
):
"""Default formatting bundle.
"""Default formatting bundle.
...
...
mmdet3d/datasets/pipelines/indoor_sample.py
0 → 100644
View file @
4f1a5e52
import
numpy
as
np
from
mmdet.datasets.builder
import
PIPELINES
@
PIPELINES
.
register_module
()
class
PointSample
(
object
):
"""Point Sample.
Sampling data to a certain number.
Args:
name (str): Name of the dataset.
num_points (int): Number of points to be sampled.
"""
def
__init__
(
self
,
num_points
):
self
.
num_points
=
num_points
def
points_random_sampling
(
self
,
points
,
num_samples
,
replace
=
None
,
return_choices
=
False
):
"""Points Random Sampling.
Sample points to a certain number.
Args:
points (ndarray): 3D Points.
num_samples (int): Number of samples to be sampled.
replace (bool): Whether the sample is with or without replacement.
return_choices (bool): Whether return choice.
Returns:
points (ndarray): 3D Points.
choices (ndarray): The generated random samples
"""
if
replace
is
None
:
replace
=
(
points
.
shape
[
0
]
<
num_samples
)
choices
=
np
.
random
.
choice
(
points
.
shape
[
0
],
num_samples
,
replace
=
replace
)
if
return_choices
:
return
points
[
choices
],
choices
else
:
return
points
[
choices
]
def
__call__
(
self
,
results
):
points
=
results
.
get
(
'points'
,
None
)
points
,
choices
=
self
.
points_random_sampling
(
points
,
self
.
num_points
,
return_choices
=
True
)
pts_instance_mask
=
results
.
get
(
'pts_instance_mask'
,
None
)
pts_semantic_mask
=
results
.
get
(
'pts_semantic_mask'
,
None
)
results
[
'points'
]
=
points
if
pts_instance_mask
is
not
None
and
pts_semantic_mask
is
not
None
:
pts_instance_mask
=
pts_instance_mask
[
choices
]
pts_semantic_mask
=
pts_semantic_mask
[
choices
]
results
[
'pts_instance_mask'
]
=
pts_instance_mask
results
[
'pts_semantic_mask'
]
=
pts_semantic_mask
return
results
def
__repr__
(
self
):
repr_str
=
self
.
__class__
.
__name__
repr_str
+=
'(num_points={})'
.
format
(
self
.
num_points
)
return
repr_str
mmdet3d/datasets/pipelines/loading.py
View file @
4f1a5e52
...
@@ -3,10 +3,10 @@ import os.path as osp
...
@@ -3,10 +3,10 @@ import os.path as osp
import
mmcv
import
mmcv
import
numpy
as
np
import
numpy
as
np
from
mmdet.datasets.
registry
import
PIPELINES
from
mmdet.datasets.
builder
import
PIPELINES
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
LoadPointsFromFile
(
object
):
class
LoadPointsFromFile
(
object
):
def
__init__
(
self
,
points_dim
=
4
,
with_reflectivity
=
True
):
def
__init__
(
self
,
points_dim
=
4
,
with_reflectivity
=
True
):
...
@@ -31,7 +31,7 @@ class LoadPointsFromFile(object):
...
@@ -31,7 +31,7 @@ class LoadPointsFromFile(object):
return
repr_str
return
repr_str
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
LoadMultiViewImageFromFiles
(
object
):
class
LoadMultiViewImageFromFiles
(
object
):
""" Load multi channel images from a list of separate channel files.
""" Load multi channel images from a list of separate channel files.
Expects results['filename'] to be a list of filenames
Expects results['filename'] to be a list of filenames
...
...
mmdet3d/datasets/pipelines/train_aug.py
View file @
4f1a5e52
import
mmcv
import
numpy
as
np
import
numpy
as
np
from
mmcv.utils
import
build_from_cfg
from
mmdet3d.core.bbox
import
box_np_ops
from
mmdet3d.core.bbox
import
box_np_ops
from
mmdet
3d.utils
import
build_from_cfg
from
mmdet
.datasets.builder
import
PIPELINES
from
mmdet.datasets.pipelines
import
RandomFlip
from
mmdet.datasets.pipelines
import
RandomFlip
from
mmdet.datasets.registry
import
PIPELINES
from
..registry
import
OBJECTSAMPLERS
from
..registry
import
OBJECTSAMPLERS
from
.data_augment_utils
import
noise_per_object_v3_
from
.data_augment_utils
import
noise_per_object_v3_
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
RandomFlip3D
(
RandomFlip
):
class
RandomFlip3D
(
RandomFlip
):
"""Flip the points & bbox.
"""Flip the points & bbox.
...
@@ -34,7 +35,42 @@ class RandomFlip3D(RandomFlip):
...
@@ -34,7 +35,42 @@ class RandomFlip3D(RandomFlip):
return
gt_bboxes_3d
,
points
return
gt_bboxes_3d
,
points
def
__call__
(
self
,
input_dict
):
def
__call__
(
self
,
input_dict
):
super
(
RandomFlip3D
,
self
).
__call__
(
input_dict
)
# filp 2D image and its annotations
if
'flip'
not
in
input_dict
:
flip
=
True
if
np
.
random
.
rand
()
<
self
.
flip_ratio
else
False
input_dict
[
'flip'
]
=
flip
if
'flip_direction'
not
in
input_dict
:
input_dict
[
'flip_direction'
]
=
self
.
direction
if
input_dict
[
'flip'
]:
# flip image
if
'img'
in
input_dict
:
if
isinstance
(
input_dict
[
'img'
],
list
):
input_dict
[
'img'
]
=
[
mmcv
.
imflip
(
img
,
direction
=
input_dict
[
'flip_direction'
])
for
img
in
input_dict
[
'img'
]
]
else
:
input_dict
[
'img'
]
=
mmcv
.
imflip
(
input_dict
[
'img'
],
direction
=
input_dict
[
'flip_direction'
])
# flip bboxes
for
key
in
input_dict
.
get
(
'bbox_fields'
,
[]):
input_dict
[
key
]
=
self
.
bbox_flip
(
input_dict
[
key
],
input_dict
[
'img_shape'
],
input_dict
[
'flip_direction'
])
# flip masks
for
key
in
input_dict
.
get
(
'mask_fields'
,
[]):
input_dict
[
key
]
=
[
mmcv
.
imflip
(
mask
,
direction
=
input_dict
[
'flip_direction'
])
for
mask
in
input_dict
[
key
]
]
# flip segs
for
key
in
input_dict
.
get
(
'seg_fields'
,
[]):
input_dict
[
key
]
=
mmcv
.
imflip
(
input_dict
[
key
],
direction
=
input_dict
[
'flip_direction'
])
if
self
.
sync_2d
:
if
self
.
sync_2d
:
input_dict
[
'pcd_flip'
]
=
input_dict
[
'flip'
]
input_dict
[
'pcd_flip'
]
=
input_dict
[
'flip'
]
else
:
else
:
...
@@ -50,8 +86,12 @@ class RandomFlip3D(RandomFlip):
...
@@ -50,8 +86,12 @@ class RandomFlip3D(RandomFlip):
input_dict
[
'points'
]
=
points
input_dict
[
'points'
]
=
points
return
input_dict
return
input_dict
def
__repr__
(
self
):
return
self
.
__class__
.
__name__
+
'(flip_ratio={}, sync_2d={})'
.
format
(
self
.
flip_ratio
,
self
.
sync_2d
)
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
ObjectSample
(
object
):
class
ObjectSample
(
object
):
def
__init__
(
self
,
db_sampler
,
sample_2d
=
False
):
def
__init__
(
self
,
db_sampler
,
sample_2d
=
False
):
...
@@ -128,7 +168,7 @@ class ObjectSample(object):
...
@@ -128,7 +168,7 @@ class ObjectSample(object):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
ObjectNoise
(
object
):
class
ObjectNoise
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -167,7 +207,7 @@ class ObjectNoise(object):
...
@@ -167,7 +207,7 @@ class ObjectNoise(object):
return
repr_str
return
repr_str
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
GlobalRotScale
(
object
):
class
GlobalRotScale
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -241,7 +281,7 @@ class GlobalRotScale(object):
...
@@ -241,7 +281,7 @@ class GlobalRotScale(object):
return
repr_str
return
repr_str
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
PointShuffle
(
object
):
class
PointShuffle
(
object
):
def
__call__
(
self
,
input_dict
):
def
__call__
(
self
,
input_dict
):
...
@@ -252,7 +292,7 @@ class PointShuffle(object):
...
@@ -252,7 +292,7 @@ class PointShuffle(object):
return
self
.
__class__
.
__name__
return
self
.
__class__
.
__name__
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
ObjectRangeFilter
(
object
):
class
ObjectRangeFilter
(
object
):
def
__init__
(
self
,
point_cloud_range
):
def
__init__
(
self
,
point_cloud_range
):
...
@@ -304,7 +344,7 @@ class ObjectRangeFilter(object):
...
@@ -304,7 +344,7 @@ class ObjectRangeFilter(object):
return
repr_str
return
repr_str
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
PointsRangeFilter
(
object
):
class
PointsRangeFilter
(
object
):
def
__init__
(
self
,
point_cloud_range
):
def
__init__
(
self
,
point_cloud_range
):
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment