Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
7cbdbc78
Commit
7cbdbc78
authored
Nov 28, 2018
by
wangg12
Browse files
move the function to datasets.utils
parent
7906bd20
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
25 deletions
+25
-25
mmdet/datasets/__init__.py
mmdet/datasets/__init__.py
+2
-2
mmdet/datasets/utils.py
mmdet/datasets/utils.py
+22
-2
tools/train.py
tools/train.py
+1
-21
No files found.
mmdet/datasets/__init__.py
View file @
7cbdbc78
from
.custom
import
CustomDataset
from
.custom
import
CustomDataset
from
.coco
import
CocoDataset
from
.coco
import
CocoDataset
from
.loader
import
GroupSampler
,
DistributedGroupSampler
,
build_dataloader
from
.loader
import
GroupSampler
,
DistributedGroupSampler
,
build_dataloader
from
.utils
import
to_tensor
,
random_scale
,
show_ann
from
.utils
import
to_tensor
,
random_scale
,
show_ann
,
get_dataset
from
.concat_dataset
import
ConcatDataset
from
.concat_dataset
import
ConcatDataset
__all__
=
[
__all__
=
[
'CustomDataset'
,
'CocoDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'ConcatDataset'
,
'CustomDataset'
,
'CocoDataset'
,
'GroupSampler'
,
'DistributedGroupSampler'
,
'ConcatDataset'
,
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
'build_dataloader'
,
'to_tensor'
,
'random_scale'
,
'show_ann'
,
'get_dataset'
]
]
mmdet/datasets/utils.py
View file @
7cbdbc78
from
collections
import
Sequence
from
collections
import
Sequence
import
copy
import
mmcv
import
mmcv
from
mmcv.runner
import
obj_from_dict
import
torch
import
torch
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
numpy
as
np
from
.concat_dataset
import
ConcatDataset
from
..
import
datasets
def
to_tensor
(
data
):
def
to_tensor
(
data
):
"""Convert objects of various python types to :obj:`torch.Tensor`.
"""Convert objects of various python types to :obj:`torch.Tensor`.
...
@@ -67,3 +69,21 @@ def show_ann(coco, img, ann_info):
...
@@ -67,3 +69,21 @@ def show_ann(coco, img, ann_info):
plt
.
axis
(
'off'
)
plt
.
axis
(
'off'
)
coco
.
showAnns
(
ann_info
)
coco
.
showAnns
(
ann_info
)
plt
.
show
()
plt
.
show
()
def
get_dataset
(
data_cfg
):
if
isinstance
(
data_cfg
[
'ann_file'
],
list
)
or
isinstance
(
data_cfg
[
'ann_file'
],
tuple
):
ann_files
=
data_cfg
[
'ann_file'
]
dsets
=
[]
for
ann_file
in
ann_files
:
data_info
=
copy
.
deepcopy
(
data_cfg
)
data_info
[
'ann_file'
]
=
ann_file
dset
=
obj_from_dict
(
data_info
,
datasets
)
dsets
.
append
(
dset
)
if
len
(
dsets
)
>
1
:
dset
=
ConcatDataset
(
dsets
)
else
:
dset
=
dsets
[
0
]
else
:
dset
=
obj_from_dict
(
data_cfg
,
datasets
)
return
dset
\ No newline at end of file
tools/train.py
View file @
7cbdbc78
from
__future__
import
division
from
__future__
import
division
import
argparse
import
argparse
import
copy
from
mmcv
import
Config
from
mmcv
import
Config
from
mmcv.runner
import
obj_from_dict
from
mmcv.runner
import
obj_from_dict
from
mmdet
import
datasets
,
__version__
from
mmdet
import
datasets
,
__version__
from
mmdet.datasets
import
ConcatDataset
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
set_random_seed
)
set_random_seed
)
from
mmdet.models
import
build_detector
from
mmdet.models
import
build_detector
...
@@ -38,24 +36,6 @@ def parse_args():
...
@@ -38,24 +36,6 @@ def parse_args():
return
args
return
args
def
get_train_dataset
(
cfg
):
if
isinstance
(
cfg
.
data
.
train
[
'ann_file'
],
list
)
or
isinstance
(
cfg
.
data
.
train
[
'ann_file'
],
tuple
):
ann_files
=
cfg
.
data
.
train
[
'ann_file'
]
train_datasets
=
[]
for
ann_file
in
ann_files
:
data_info
=
copy
.
deepcopy
(
cfg
.
data
.
train
)
data_info
[
'ann_file'
]
=
ann_file
train_dset
=
obj_from_dict
(
data_info
,
datasets
)
train_datasets
.
append
(
train_dset
)
if
len
(
train_datasets
)
>
1
:
train_dataset
=
ConcatDataset
(
train_datasets
)
else
:
train_dataset
=
train_datasets
[
0
]
else
:
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
return
train_dataset
def
main
():
def
main
():
args
=
parse_args
()
args
=
parse_args
()
...
@@ -87,7 +67,7 @@ def main():
...
@@ -87,7 +67,7 @@ def main():
model
=
build_detector
(
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
train_dataset
=
get_train_dataset
(
cfg
)
train_dataset
=
datasets
.
get_dataset
(
cfg
.
data
.
train
)
train_detector
(
train_detector
(
model
,
model
,
train_dataset
,
train_dataset
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment