Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
8b47a12b
Commit
8b47a12b
authored
Oct 08, 2018
by
Kai Chen
Browse files
minor updates for train/test scripts
parent
f8dab59d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
7 deletions
+8
-7
tools/test.py
tools/test.py
+3
-4
tools/train.py
tools/train.py
+5
-3
No files found.
tools/test.py
View file @
8b47a12b
...
@@ -44,17 +44,16 @@ def parse_args():
...
@@ -44,17 +44,16 @@ def parse_args():
'--eval'
,
'--eval'
,
type
=
str
,
type
=
str
,
nargs
=
'+'
,
nargs
=
'+'
,
choices
=
[
'proposal'
,
'bbox'
,
'segm'
,
'keypoints'
],
choices
=
[
'proposal'
,
'proposal_fast'
,
'bbox'
,
'segm'
,
'keypoints'
],
help
=
'eval types'
)
help
=
'eval types'
)
parser
.
add_argument
(
'--show'
,
action
=
'store_true'
,
help
=
'show results'
)
parser
.
add_argument
(
'--show'
,
action
=
'store_true'
,
help
=
'show results'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
args
=
parse_args
()
def
main
():
def
main
():
args
=
parse_args
()
cfg
=
mmcv
.
Config
.
fromfile
(
args
.
config
)
cfg
=
mmcv
.
Config
.
fromfile
(
args
.
config
)
cfg
.
model
.
pretrained
=
None
cfg
.
model
.
pretrained
=
None
cfg
.
data
.
test
.
test_mode
=
True
cfg
.
data
.
test
.
test_mode
=
True
...
...
tools/train.py
View file @
8b47a12b
...
@@ -2,6 +2,7 @@ from __future__ import division
...
@@ -2,6 +2,7 @@ from __future__ import division
import
argparse
import
argparse
import
logging
import
logging
import
random
from
collections
import
OrderedDict
from
collections
import
OrderedDict
import
numpy
as
np
import
numpy
as
np
...
@@ -55,6 +56,7 @@ def get_logger(log_level):
...
@@ -55,6 +56,7 @@ def get_logger(log_level):
def
set_random_seed
(
seed
):
def
set_random_seed
(
seed
):
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
...
@@ -89,7 +91,7 @@ def main():
...
@@ -89,7 +91,7 @@ def main():
if
args
.
work_dir
is
not
None
:
if
args
.
work_dir
is
not
None
:
cfg
.
work_dir
=
args
.
work_dir
cfg
.
work_dir
=
args
.
work_dir
cfg
.
gpus
=
args
.
gpus
cfg
.
gpus
=
args
.
gpus
#
add
mmdet version
to
checkpoint as meta data
#
save
mmdet version
in
checkpoint as meta data
cfg
.
checkpoint_config
.
meta
=
dict
(
cfg
.
checkpoint_config
.
meta
=
dict
(
mmdet_version
=
__version__
,
config
=
cfg
.
text
)
mmdet_version
=
__version__
,
config
=
cfg
.
text
)
...
@@ -103,13 +105,13 @@ def main():
...
@@ -103,13 +105,13 @@ def main():
# init distributed environment if necessary
# init distributed environment if necessary
if
args
.
launcher
==
'none'
:
if
args
.
launcher
==
'none'
:
dist
=
False
dist
=
False
logger
.
info
(
'
Disabled
distributed training.'
)
logger
.
info
(
'
Non-
distributed training.'
)
else
:
else
:
dist
=
True
dist
=
True
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
init_dist
(
args
.
launcher
,
**
cfg
.
dist_params
)
if
torch
.
distributed
.
get_rank
()
!=
0
:
if
torch
.
distributed
.
get_rank
()
!=
0
:
logger
.
setLevel
(
'ERROR'
)
logger
.
setLevel
(
'ERROR'
)
logger
.
info
(
'
Enabled d
istributed training.'
)
logger
.
info
(
'
D
istributed training.'
)
# prepare data loaders
# prepare data loaders
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
train_dataset
=
obj_from_dict
(
cfg
.
data
.
train
,
datasets
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment