Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
2a7030a5
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "941ac9c3d9aab9c36fc33c58dac1980442928082"
Commit
2a7030a5
authored
Jul 17, 2022
by
VVsssssk
Committed by
ChaimZhu
Jul 20, 2022
Browse files
[Refactor]Support classes balance dataset
parent
c66197c7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
16 deletions
+37
-16
mmdet3d/datasets/__init__.py
mmdet3d/datasets/__init__.py
+14
-13
mmdet3d/datasets/dataset_wrappers.py
mmdet3d/datasets/dataset_wrappers.py
+6
-3
mmdet3d/datasets/det3d_dataset.py
mmdet3d/datasets/det3d_dataset.py
+17
-0
No files found.
mmdet3d/datasets/__init__.py
View file @
2a7030a5
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
from
.builder
import
DATASETS
,
PIPELINES
,
build_dataset
from
.builder
import
DATASETS
,
PIPELINES
,
build_dataset
from
.dataset_wrappers
import
CBGSDataset
from
.det3d_dataset
import
Det3DDataset
from
.det3d_dataset
import
Det3DDataset
from
.kitti_dataset
import
KittiDataset
from
.kitti_dataset
import
KittiDataset
from
.kitti_mono_dataset
import
KittiMonoDataset
from
.kitti_mono_dataset
import
KittiMonoDataset
...
@@ -28,17 +29,17 @@ from .utils import get_loading_pipeline
...
@@ -28,17 +29,17 @@ from .utils import get_loading_pipeline
from
.waymo_dataset
import
WaymoDataset
from
.waymo_dataset
import
WaymoDataset
__all__
=
[
__all__
=
[
'KittiDataset'
,
'KittiMonoDataset'
,
'DATASETS'
,
'
build_d
ataset'
,
'KittiDataset'
,
'KittiMonoDataset'
,
'DATASETS'
,
'
CBGSD
ataset'
,
'NuScenesDataset'
,
'NuScenesMonoDataset'
,
'LyftDataset'
,
'ObjectSample'
,
'build_dataset'
,
'NuScenesDataset'
,
'NuScenesMonoDataset'
,
'LyftDataset'
,
'RandomFlip3D'
,
'ObjectNoise'
,
'GlobalRotScaleTrans'
,
'PointShuffle'
,
'ObjectSample'
,
'RandomFlip3D'
,
'ObjectNoise'
,
'GlobalRotScaleTrans'
,
'ObjectRangeFilter'
,
'PointsRangeFilter'
,
'LoadPointsFromFile'
,
'PointShuffle'
,
'ObjectRangeFilter'
,
'PointsRangeFilter'
,
'S3DISSegDataset'
,
'S3DISDataset'
,
'NormalizePointsColor'
,
'LoadPointsFromFile'
,
'S3DISSegDataset'
,
'S3DISDataset'
,
'IndoorPatchPointSample'
,
'IndoorPointSample'
,
'PointSample'
,
'NormalizePointsColor'
,
'IndoorPatchPointSample'
,
'IndoorPointSample'
,
'LoadAnnotations3D'
,
'GlobalAlignment'
,
'SUNRGBDDataset'
,
'ScanNetDataset'
,
'PointSample'
,
'LoadAnnotations3D'
,
'GlobalAlignment'
,
'SUNRGBDDataset'
,
'ScanNetSegDataset'
,
'ScanNetInstanceSegDataset'
,
'SemanticKITTIDataset'
,
'ScanNetDataset'
,
'ScanNetSegDataset'
,
'ScanNetInstanceSegDataset'
,
'
Det3D
Dataset'
,
'
Seg
3DDataset'
,
'
LoadPointsFromMultiSweeps
'
,
'
SemanticKITTI
Dataset'
,
'
Det
3DDataset'
,
'
Seg3DDataset
'
,
'WaymoDataset'
,
'BackgroundPointsFilter'
,
'VoxelBasedPointSampler'
,
'LoadPointsFromMultiSweeps'
,
'WaymoDataset'
,
'BackgroundPointsFilter'
,
'get_loading_pipeline'
,
'RandomDropPointsColor'
,
'RandomJitterPoints'
,
'VoxelBasedPointSampler'
,
'get_loading_pipeline'
,
'RandomDropPointsColor'
,
'ObjectNameFilter'
,
'AffineResize'
,
'RandomShiftScale'
,
'RandomJitterPoints'
,
'ObjectNameFilter'
,
'AffineResize'
,
'LoadPointsFromDict'
,
'PIPELINES'
'RandomShiftScale'
,
'LoadPointsFromDict'
,
'PIPELINES'
]
]
mmdet3d/datasets/dataset_wrappers.py
View file @
2a7030a5
...
@@ -17,8 +17,8 @@ class CBGSDataset(object):
...
@@ -17,8 +17,8 @@ class CBGSDataset(object):
"""
"""
def
__init__
(
self
,
dataset
):
def
__init__
(
self
,
dataset
):
self
.
dataset
=
dataset
self
.
dataset
=
DATASETS
.
build
(
dataset
)
self
.
CLASSES
=
dataset
.
CLASSES
self
.
CLASSES
=
self
.
dataset
.
metainfo
[
'
CLASSES
'
]
self
.
cat2id
=
{
name
:
i
for
i
,
name
in
enumerate
(
self
.
CLASSES
)}
self
.
cat2id
=
{
name
:
i
for
i
,
name
in
enumerate
(
self
.
CLASSES
)}
self
.
sample_indices
=
self
.
_get_sample_indices
()
self
.
sample_indices
=
self
.
_get_sample_indices
()
# self.dataset.data_infos = self.data_infos
# self.dataset.data_infos = self.data_infos
...
@@ -40,7 +40,10 @@ class CBGSDataset(object):
...
@@ -40,7 +40,10 @@ class CBGSDataset(object):
for
idx
in
range
(
len
(
self
.
dataset
)):
for
idx
in
range
(
len
(
self
.
dataset
)):
sample_cat_ids
=
self
.
dataset
.
get_cat_ids
(
idx
)
sample_cat_ids
=
self
.
dataset
.
get_cat_ids
(
idx
)
for
cat_id
in
sample_cat_ids
:
for
cat_id
in
sample_cat_ids
:
class_sample_idxs
[
cat_id
].
append
(
idx
)
if
cat_id
!=
-
1
:
# Filter categories that do not need to care.
# -1 indicate dontcare in MMDet3d.
class_sample_idxs
[
cat_id
].
append
(
idx
)
duplicated_samples
=
sum
(
duplicated_samples
=
sum
(
[
len
(
v
)
for
_
,
v
in
class_sample_idxs
.
items
()])
[
len
(
v
)
for
_
,
v
in
class_sample_idxs
.
items
()])
class_distribution
=
{
class_distribution
=
{
...
...
mmdet3d/datasets/det3d_dataset.py
View file @
2a7030a5
...
@@ -294,3 +294,20 @@ class Det3DDataset(BaseDataset):
...
@@ -294,3 +294,20 @@ class Det3DDataset(BaseDataset):
example
[
'data_sample'
].
gt_instances_3d
.
labels_3d
)
==
0
:
example
[
'data_sample'
].
gt_instances_3d
.
labels_3d
)
==
0
:
return
None
return
None
return
example
return
example
def
get_cat_ids
(
self
,
idx
:
int
)
->
List
[
int
]:
"""Get category ids by index. Dataset wrapped by ClassBalancedDataset
must implement this method.
The ``CBGSDataset`` or ``ClassBalancedDataset``requires a subclass
which implements this method.
Args:
idx (int): The index of data.
Returns:
set[int]: All categories in the sample of specified index.
"""
info
=
self
.
get_data_info
(
idx
)
gt_labels
=
info
[
'ann_info'
][
'gt_labels_3d'
].
tolist
()
return
set
(
gt_labels
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment