Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
2a14897e
Commit
2a14897e
authored
Apr 27, 2020
by
zhangwenwei
Browse files
Fix kitti evaluation bug
parent
aec41c7f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
9 deletions
+8
-9
mmdet3d/core/bbox/box_torch_ops.py
mmdet3d/core/bbox/box_torch_ops.py
+4
-4
mmdet3d/datasets/kitti_dataset.py
mmdet3d/datasets/kitti_dataset.py
+4
-5
No files found.
mmdet3d/core/bbox/box_torch_ops.py
View file @
2a14897e
...
...
@@ -21,9 +21,9 @@ def corners_nd(dims, origin=0.5):
where x0 < x1, y0 < y1, z0 < z1
"""
ndim
=
int
(
dims
.
shape
[
1
])
corners_norm
=
np
.
stack
(
np
.
unravel_index
(
np
.
arange
(
2
**
ndim
),
[
2
]
*
ndim
),
axis
=
1
).
as
type
(
dims
.
dtype
)
corners_norm
=
torch
.
from_numpy
(
np
.
stack
(
np
.
unravel_index
(
np
.
arange
(
2
**
ndim
),
[
2
]
*
ndim
),
axis
=
1
)).
to
(
device
=
dims
.
device
,
d
type
=
dims
.
dtype
)
# now corners_norm has format: (2d) x0y0, x0y1, x1y0, x1y1
# (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1
# so need to convert to a format which is convenient to do other computing.
...
...
@@ -34,7 +34,7 @@ def corners_nd(dims, origin=0.5):
corners_norm
=
corners_norm
[[
0
,
1
,
3
,
2
]]
elif
ndim
==
3
:
corners_norm
=
corners_norm
[[
0
,
1
,
3
,
2
,
4
,
5
,
7
,
6
]]
corners_norm
=
corners_norm
-
np
.
array
(
origin
,
dtype
=
dims
.
dtype
)
corners_norm
=
corners_norm
-
dims
.
new_tensor
(
origin
)
corners
=
dims
.
reshape
([
-
1
,
1
,
ndim
])
*
corners_norm
.
reshape
(
[
1
,
2
**
ndim
,
ndim
])
return
corners
...
...
mmdet3d/datasets/kitti_dataset.py
View file @
2a14897e
...
...
@@ -44,8 +44,7 @@ class KittiDataset(torch_data.Dataset):
self
.
pcd_limit_range
=
[
0
,
-
40
,
-
3
,
70.4
,
40
,
0.0
]
self
.
ann_file
=
ann_file
with
open
(
ann_file
,
'rb'
)
as
f
:
self
.
kitti_infos
=
mmcv
.
load
(
f
)
self
.
kitti_infos
=
mmcv
.
load
(
ann_file
)
# set group flag for the sampler
if
not
self
.
test_mode
:
...
...
@@ -284,7 +283,7 @@ class KittiDataset(torch_data.Dataset):
result_files
=
self
.
bbox2result_kitti
(
outputs
,
self
.
class_names
,
pklfile_prefix
,
submission_prefix
)
return
result_files
return
result_files
,
tmp_dir
def
evaluate
(
self
,
results
,
...
...
@@ -321,7 +320,7 @@ class KittiDataset(torch_data.Dataset):
if
tmp_dir
is
not
None
:
tmp_dir
.
cleanup
()
return
ap_dict
return
ap_dict
,
tmp_dir
def
bbox2result_kitti
(
self
,
net_outputs
,
...
...
@@ -332,7 +331,7 @@ class KittiDataset(torch_data.Dataset):
mmcv
.
mkdir_or_exist
(
submission_prefix
)
det_annos
=
[]
print
(
'Converting prediction to KITTI format'
)
print
(
'
\n
Converting prediction to KITTI format'
)
for
idx
,
pred_dicts
in
enumerate
(
mmcv
.
track_iter_progress
(
net_outputs
)):
annos
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment