Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
9510c3a7
Unverified
Commit
9510c3a7
authored
Nov 30, 2018
by
Kai Chen
Committed by
GitHub
Nov 30, 2018
Browse files
Merge pull request #127 from wangg12/master
support training on dataset with multiple ann_files
parents
ae4646fa
9baf0a8b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
71 additions
and
5 deletions
+71
-5
.gitignore
.gitignore
+2
-0
mmdet/datasets/__init__.py
mmdet/datasets/__init__.py
+4
-2
mmdet/datasets/concat_dataset.py
mmdet/datasets/concat_dataset.py
+20
-0
mmdet/datasets/utils.py
mmdet/datasets/utils.py
+42
-0
tools/train.py
tools/train.py
+3
-3
No files found.
.gitignore
View file @
9510c3a7
...
@@ -107,3 +107,5 @@ venv.bak/
...
@@ -107,3 +107,5 @@ venv.bak/
mmdet/ops/nms/*.cpp
mmdet/ops/nms/*.cpp
mmdet/version.py
mmdet/version.py
data
data
.vscode
.idea
mmdet/datasets/__init__.py
View file @
9510c3a7
from
.custom
import
CustomDataset
from
.custom
import
CustomDataset
from
.coco
import
CocoDataset
from
.coco
import
CocoDataset
from
.loader
import
GroupSampler
,
DistributedGroupSampler
,
build_dataloader
from
.loader
import
GroupSampler
,
DistributedGroupSampler
,
build_dataloader
from
.utils
import
to_tensor
,
random_scale
,
show_ann
from
.utils
import
to_tensor
,
random_scale
,
show_ann
,
get_dataset
from
.concat_dataset
import
ConcatDataset
__all__
=
[
__all__
=
[
'CustomDataset'
,
'CocoDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'CustomDataset'
,
'CocoDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
'ConcatDataset'
,
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
,
'get_dataset'
]
]
mmdet/datasets/concat_dataset.py
0 → 100644
View file @
9510c3a7
import
numpy
as
np
from
torch.utils.data.dataset
import
ConcatDataset
as
_ConcatDataset
class
ConcatDataset
(
_ConcatDataset
):
"""
Same as torch.utils.data.dataset.ConcatDataset, but
concat the group flag for image aspect ratio.
"""
def
__init__
(
self
,
datasets
):
"""
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super
(
ConcatDataset
,
self
).
__init__
(
datasets
)
if
hasattr
(
datasets
[
0
],
'flag'
):
flags
=
[]
for
i
in
range
(
0
,
len
(
datasets
)):
flags
.
append
(
datasets
[
i
].
flag
)
self
.
flag
=
np
.
concatenate
(
flags
)
mmdet/datasets/utils.py
View file @
9510c3a7
import
copy
from
collections
import
Sequence
from
collections
import
Sequence
import
mmcv
import
mmcv
from
mmcv.runner
import
obj_from_dict
import
torch
import
torch
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
numpy
as
np
from
.concat_dataset
import
ConcatDataset
from
..
import
datasets
def
to_tensor
(
data
):
def
to_tensor
(
data
):
...
@@ -67,3 +71,41 @@ def show_ann(coco, img, ann_info):
...
@@ -67,3 +71,41 @@ def show_ann(coco, img, ann_info):
plt
.
axis
(
'off'
)
plt
.
axis
(
'off'
)
coco
.
showAnns
(
ann_info
)
coco
.
showAnns
(
ann_info
)
plt
.
show
()
plt
.
show
()
def
get_dataset
(
data_cfg
):
if
isinstance
(
data_cfg
[
'ann_file'
],
(
list
,
tuple
)):
ann_files
=
data_cfg
[
'ann_file'
]
num_dset
=
len
(
ann_files
)
else
:
ann_files
=
[
data_cfg
[
'ann_file'
]]
num_dset
=
1
if
'proposal_file'
in
data_cfg
.
keys
():
if
isinstance
(
data_cfg
[
'proposal_file'
],
(
list
,
tuple
)):
proposal_files
=
data_cfg
[
'proposal_file'
]
else
:
proposal_files
=
[
data_cfg
[
'proposal_file'
]]
else
:
proposal_files
=
[
None
]
*
num_dset
assert
len
(
proposal_files
)
==
num_dset
if
isinstance
(
data_cfg
[
'img_prefix'
],
(
list
,
tuple
)):
img_prefixes
=
data_cfg
[
'img_prefix'
]
else
:
img_prefixes
=
[
data_cfg
[
'img_prefix'
]]
*
num_dset
assert
len
(
img_prefixes
)
==
num_dset
dsets
=
[]
for
i
in
range
(
num_dset
):
data_info
=
copy
.
deepcopy
(
data_cfg
)
data_info
[
'ann_file'
]
=
ann_files
[
i
]
data_info
[
'proposal_file'
]
=
proposal_files
[
i
]
data_info
[
'img_prefix'
]
=
img_prefixes
[
i
]
dset
=
obj_from_dict
(
data_info
,
datasets
)
dsets
.
append
(
dset
)
if
len
(
dsets
)
>
1
:
dset
=
ConcatDataset
(
dsets
)
else
:
dset
=
dsets
[
0
]
return
dset
tools/train.py
View file @
9510c3a7
...
@@ -2,9 +2,9 @@ from __future__ import division
...
@@ -2,9 +2,9 @@ from __future__ import division
import
argparse
import
argparse
from
mmcv
import
Config
from
mmcv
import
Config
from
mmcv.runner
import
obj_from_dict
from
mmdet
import
datasets
,
__version__
from
mmdet
import
__version__
from
mmdet.datasets
import
get_dataset
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
set_random_seed
)
set_random_seed
)
from
mmdet.models
import
build_detector
from
mmdet.models
import
build_detector
...
@@ -67,7 +67,7 @@ def main():
...
@@ -67,7 +67,7 @@ def main():
model
=
build_detector
(
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
train_dataset
=
obj_from_dic
t
(
cfg
.
data
.
train
,
datasets
)
train_dataset
=
get_datase
t
(
cfg
.
data
.
train
)
train_detector
(
train_detector
(
model
,
model
,
train_dataset
,
train_dataset
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment