Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
9baf0a8b
"vscode:/vscode.git/clone" did not exist on "5df0cd3074fcf040c8c5948a6d7a4b691dd7af76"
Commit
9baf0a8b
authored
Nov 30, 2018
by
wangg12
Browse files
fix some problems; support multiple proposal_files and img_prefixes
parent
6b25743a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
24 deletions
+38
-24
.gitignore
.gitignore
+2
-0
mmdet/datasets/concat_dataset.py
mmdet/datasets/concat_dataset.py
+0
-9
mmdet/datasets/utils.py
mmdet/datasets/utils.py
+33
-13
tools/train.py
tools/train.py
+3
-2
No files found.
.gitignore
View file @
9baf0a8b
...
@@ -107,3 +107,5 @@ venv.bak/
...
@@ -107,3 +107,5 @@ venv.bak/
mmdet/ops/nms/*.cpp
mmdet/ops/nms/*.cpp
mmdet/version.py
mmdet/version.py
data
data
.vscode
.idea
mmdet/datasets/concat_dataset.py
View file @
9baf0a8b
import
bisect
import
numpy
as
np
import
numpy
as
np
from
torch.utils.data.dataset
import
ConcatDataset
as
_ConcatDataset
from
torch.utils.data.dataset
import
ConcatDataset
as
_ConcatDataset
...
@@ -19,11 +18,3 @@ class ConcatDataset(_ConcatDataset):
...
@@ -19,11 +18,3 @@ class ConcatDataset(_ConcatDataset):
for
i
in
range
(
0
,
len
(
datasets
)):
for
i
in
range
(
0
,
len
(
datasets
)):
flags
.
append
(
datasets
[
i
].
flag
)
flags
.
append
(
datasets
[
i
].
flag
)
self
.
flag
=
np
.
concatenate
(
flags
)
self
.
flag
=
np
.
concatenate
(
flags
)
def
get_idxs
(
self
,
idx
):
dataset_idx
=
bisect
.
bisect_right
(
self
.
cumulative_sizes
,
idx
)
if
dataset_idx
==
0
:
sample_idx
=
idx
else
:
sample_idx
=
idx
-
self
.
cumulative_sizes
[
dataset_idx
-
1
]
return
dataset_idx
,
sample_idx
mmdet/datasets/utils.py
View file @
9baf0a8b
from
collections
import
Sequence
import
copy
import
copy
from
collections
import
Sequence
import
mmcv
import
mmcv
from
mmcv.runner
import
obj_from_dict
from
mmcv.runner
import
obj_from_dict
import
torch
import
torch
...
@@ -73,19 +74,38 @@ def show_ann(coco, img, ann_info):
...
@@ -73,19 +74,38 @@ def show_ann(coco, img, ann_info):
def
get_dataset
(
data_cfg
):
def
get_dataset
(
data_cfg
):
if
isinstance
(
data_cfg
[
'ann_file'
],
list
)
or
\
if
isinstance
(
data_cfg
[
'ann_file'
],
(
list
,
tuple
)):
isinstance
(
data_cfg
[
'ann_file'
],
tuple
):
ann_files
=
data_cfg
[
'ann_file'
]
ann_files
=
data_cfg
[
'ann_file'
]
num_dset
=
len
(
ann_files
)
else
:
ann_files
=
[
data_cfg
[
'ann_file'
]]
num_dset
=
1
if
'proposal_file'
in
data_cfg
.
keys
():
if
isinstance
(
data_cfg
[
'proposal_file'
],
(
list
,
tuple
)):
proposal_files
=
data_cfg
[
'proposal_file'
]
else
:
proposal_files
=
[
data_cfg
[
'proposal_file'
]]
else
:
proposal_files
=
[
None
]
*
num_dset
assert
len
(
proposal_files
)
==
num_dset
if
isinstance
(
data_cfg
[
'img_prefix'
],
(
list
,
tuple
)):
img_prefixes
=
data_cfg
[
'img_prefix'
]
else
:
img_prefixes
=
[
data_cfg
[
'img_prefix'
]]
*
num_dset
assert
len
(
img_prefixes
)
==
num_dset
dsets
=
[]
dsets
=
[]
for
ann_file
in
an
n_files
:
for
i
in
r
an
ge
(
num_dset
)
:
data_info
=
copy
.
deepcopy
(
data_cfg
)
data_info
=
copy
.
deepcopy
(
data_cfg
)
data_info
[
'ann_file'
]
=
ann_file
data_info
[
'ann_file'
]
=
ann_files
[
i
]
data_info
[
'proposal_file'
]
=
proposal_files
[
i
]
data_info
[
'img_prefix'
]
=
img_prefixes
[
i
]
dset
=
obj_from_dict
(
data_info
,
datasets
)
dset
=
obj_from_dict
(
data_info
,
datasets
)
dsets
.
append
(
dset
)
dsets
.
append
(
dset
)
if
len
(
dsets
)
>
1
:
if
len
(
dsets
)
>
1
:
dset
=
ConcatDataset
(
dsets
)
dset
=
ConcatDataset
(
dsets
)
else
:
else
:
dset
=
dsets
[
0
]
dset
=
dsets
[
0
]
else
:
dset
=
obj_from_dict
(
data_cfg
,
datasets
)
return
dset
return
dset
tools/train.py
View file @
9baf0a8b
...
@@ -3,7 +3,8 @@ from __future__ import division
...
@@ -3,7 +3,8 @@ from __future__ import division
import
argparse
import
argparse
from
mmcv
import
Config
from
mmcv
import
Config
from
mmdet
import
datasets
,
__version__
from
mmdet
import
__version__
from
mmdet.datasets
import
get_dataset
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
from
mmdet.apis
import
(
train_detector
,
init_dist
,
get_root_logger
,
set_random_seed
)
set_random_seed
)
from
mmdet.models
import
build_detector
from
mmdet.models
import
build_detector
...
@@ -66,7 +67,7 @@ def main():
...
@@ -66,7 +67,7 @@ def main():
model
=
build_detector
(
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
train_dataset
=
datasets
.
get_dataset
(
cfg
.
data
.
train
)
train_dataset
=
get_dataset
(
cfg
.
data
.
train
)
train_detector
(
train_detector
(
model
,
model
,
train_dataset
,
train_dataset
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment