Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
c23c4208
Commit
c23c4208
authored
Jul 06, 2020
by
Shaoshuai Shi
Browse files
support class-balanced sampling for nuScenes dataset
parent
ea6cf247
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
3 deletions
+41
-3
pcdet/datasets/nuscenes/nuscenes_dataset.py
pcdet/datasets/nuscenes/nuscenes_dataset.py
+36
-0
tools/cfgs/dataset_configs/nuscenes_dataset.yaml
tools/cfgs/dataset_configs/nuscenes_dataset.yaml
+2
-0
tools/cfgs/nuscenes_models/second_multihead.yaml
tools/cfgs/nuscenes_models/second_multihead.yaml
+3
-3
No files found.
pcdet/datasets/nuscenes/nuscenes_dataset.py
View file @
c23c4208
...
...
@@ -16,6 +16,8 @@ class NuScenesDataset(DatasetTemplate):
)
self
.
infos
=
[]
self
.
include_nuscenes_data
(
self
.
mode
)
if
self
.
training
and
self
.
dataset_cfg
.
get
(
'BALANCED_RESAMPLING'
,
False
):
self
.
infos
=
self
.
balanced_infos_resampling
(
self
.
infos
)
def
include_nuscenes_data
(
self
,
mode
):
self
.
logger
.
info
(
'Loading NuScenes dataset'
)
...
...
@@ -32,6 +34,40 @@ class NuScenesDataset(DatasetTemplate):
self
.
infos
.
extend
(
nuscenes_infos
)
self
.
logger
.
info
(
'Total samples for NuScenes dataset: %d'
%
(
len
(
nuscenes_infos
)))
def balanced_infos_resampling(self, infos):
    """
    Class-balanced sampling of nuScenes dataset from https://arxiv.org/abs/1908.09492

    Every info is filed under each class of ``self.class_names`` that occurs
    in its ``'gt_names'`` (an info holding several classes lands in several
    buckets). Each non-empty bucket is then resampled with
    ``np.random.choice`` so that it contributes roughly
    ``total_bucket_entries / len(self.class_names)`` infos to the result.

    Classes without a single matching info are skipped instead of raising
    ``ZeroDivisionError``; if no info matches any class the result is empty.

    Args:
        infos: iterable of info dicts; each must provide ``'gt_names'``.

    Returns:
        list: the resampled infos (may contain the same info several times).
    """
    # Bucket the infos per class; set() so an info is counted once per class.
    cls_infos = {name: [] for name in self.class_names}
    for info in infos:
        for name in set(info['gt_names']):
            if name in cls_infos:
                cls_infos[name].append(info)

    # Total bucket entries; infos with several classes are counted repeatedly.
    duplicated_samples = sum(len(v) for v in cls_infos.values())

    sampled_infos = []
    if duplicated_samples == 0:
        # Nothing to balance (no infos, no classes, or no overlap between
        # them); bail out before any of the divisions below can hit zero.
        self.logger.info('Total samples after balanced resampling: %s' % (len(sampled_infos)))
        return sampled_infos

    # Target share of every class in the resampled set.
    frac = 1.0 / len(self.class_names)

    for cur_cls_infos in cls_infos.values():
        if not cur_cls_infos:
            # A class that never occurs has a zero share and cannot be
            # sampled from; skipping it avoids dividing by zero.
            continue
        cls_dist = len(cur_cls_infos) / duplicated_samples
        ratio = frac / cls_dist
        sampled_infos += np.random.choice(
            cur_cls_infos, int(len(cur_cls_infos) * ratio)
        ).tolist()

    self.logger.info('Total samples after balanced resampling: %s' % (len(sampled_infos)))

    return sampled_infos
def
get_sweep
(
self
,
sweep_info
):
def
remove_ego_points
(
points
,
center_radius
=
1.0
):
mask
=
~
((
np
.
abs
(
points
[:,
0
])
<
center_radius
)
&
(
np
.
abs
(
points
[:,
1
])
<
center_radius
))
...
...
tools/cfgs/dataset_configs/nuscenes_dataset.yaml
View file @
c23c4208
...
...
@@ -17,6 +17,8 @@ INFO_PATH: {
POINT_CLOUD_RANGE
:
[
-51.2
,
-51.2
,
-5.0
,
51.2
,
51.2
,
3.0
]
BALANCED_RESAMPLING
:
True
DATA_AUGMENTOR
:
DISABLE_AUG_LIST
:
[
'
placeholder'
]
AUG_CONFIG_LIST
:
...
...
tools/cfgs/nuscenes_models/second_multihead.yaml
View file @
c23c4208
...
...
@@ -221,13 +221,13 @@ MODEL:
MULTI_CLASSES_NMS
:
False
NMS_TYPE
:
nms_gpu
NMS_THRESH
:
0.2
NMS_PRE_MAXSIZE
:
4096
NMS_POST_MAXSIZE
:
250
NMS_PRE_MAXSIZE
:
1000
NMS_POST_MAXSIZE
:
100
OPTIMIZATION
:
OPTIMIZER
:
adam_onecycle
LR
:
0.003
LR
:
0.003
WEIGHT_DECAY
:
0.01
MOMENTUM
:
0.9
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment