Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
7906bd20
"vscode:/vscode.git/clone" did not exist on "18f48b73cb15ebaf33c5ad625ccbad54f561c7ae"
Commit
7906bd20
authored
Nov 28, 2018
by
wangg12
Browse files
support training on dataset with multiple ann_files
parent
64e310d5
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
2 deletions
+53
-2
mmdet/datasets/__init__.py
mmdet/datasets/__init__.py
+2
-1
mmdet/datasets/concat_dataset.py
mmdet/datasets/concat_dataset.py
+30
-0
tools/train.py
tools/train.py
+21
-1
No files found.
mmdet/datasets/__init__.py
View file @
7906bd20
...
@@ -2,8 +2,9 @@ from .custom import CustomDataset
...
@@ -2,8 +2,9 @@ from .custom import CustomDataset
from
.coco
import
CocoDataset
from
.coco
import
CocoDataset
from
.loader
import
GroupSampler
,
DistributedGroupSampler
,
build_dataloader
from
.loader
import
GroupSampler
,
DistributedGroupSampler
,
build_dataloader
from
.utils
import
to_tensor
,
random_scale
,
show_ann
from
.utils
import
to_tensor
,
random_scale
,
show_ann
from
.concat_dataset
import
ConcatDataset
__all__
=
[
__all__
=
[
'CustomDataset'
,
'CocoDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'CustomDataset'
,
'CocoDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'ConcatDataset'
,
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
]
]
mmdet/datasets/concat_dataset.py
0 → 100644
View file @
7906bd20
import
bisect
import
numpy
as
np
from
torch.utils.data.dataset
import
ConcatDataset
as
_ConcatDataset
class
ConcatDataset
(
_ConcatDataset
):
"""
Same as torch.utils.data.dataset.ConcatDataset, but
concat the group flag for image aspect ratio.
"""
def
__init__
(
self
,
datasets
):
"""
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super
(
ConcatDataset
,
self
).
__init__
(
datasets
)
if
hasattr
(
datasets
[
0
],
'flag'
):
flags
=
[]
for
i
in
range
(
0
,
len
(
datasets
)):
flags
.
append
(
datasets
[
i
].
flag
)
self
.
flag
=
np
.
concatenate
(
flags
)
def
get_idxs
(
self
,
idx
):
dataset_idx
=
bisect
.
bisect_right
(
self
.
cumulative_sizes
,
idx
)
if
dataset_idx
==
0
:
sample_idx
=
idx
else
:
sample_idx
=
idx
-
self
.
cumulative_sizes
[
dataset_idx
-
1
]
return
dataset_idx
,
sample_idx
tools/train.py
View file @
7906bd20
from
__future__
import
division
from
__future__
import
division
import
argparse
import
argparse
import
copy
from
mmcv
import
Config
from
mmcv
import
Config
from
mmcv.runner
import
obj_from_dict
from
mmcv.runner
import
obj_from_dict
from
mmdet
import
datasets
,
__version__
from
mmdet
import
datasets
,
__version__
from
mmdet.datasets
import
ConcatDataset
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
set_random_seed
)
set_random_seed
)
from
mmdet.models
import
build_detector
from
mmdet.models
import
build_detector
...
@@ -36,6 +38,24 @@ def parse_args():
...
@@ -36,6 +38,24 @@ def parse_args():
return
args
return
args
def
get_train_dataset
(
cfg
):
if
isinstance
(
cfg
.
data
.
train
[
'ann_file'
],
list
)
or
isinstance
(
cfg
.
data
.
train
[
'ann_file'
],
tuple
):
ann_files
=
cfg
.
data
.
train
[
'ann_file'
]
train_datasets
=
[]
for
ann_file
in
ann_files
:
data_info
=
copy
.
deepcopy
(
cfg
.
data
.
train
)
data_info
[
'ann_file'
]
=
ann_file
train_dset
=
obj_from_dict
(
data_info
,
datasets
)
train_datasets
.
append
(
train_dset
)
if
len
(
train_datasets
)
>
1
:
train_dataset
=
ConcatDataset
(
train_datasets
)
else
:
train_dataset
=
train_datasets
[
0
]
else
:
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
return
train_dataset
def
main
():
def
main
():
args
=
parse_args
()
args
=
parse_args
()
...
@@ -67,7 +87,7 @@ def main():
...
@@ -67,7 +87,7 @@ def main():
model
=
build_detector
(
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
dataset
s
)
train_dataset
=
get_
train
_
dataset
(
cfg
)
train_detector
(
train_detector
(
model
,
model
,
train_dataset
,
train_dataset
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment