Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
300f7157
Commit
300f7157
authored
Sep 28, 2018
by
Kai Chen
Browse files
allow manually setting random seeds
parent
143a8372
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
0 deletions
+13
-0
tools/train.py
tools/train.py
+13
-0
No files found.
tools/train.py
View file @
300f7157
...
...
@@ -4,6 +4,7 @@ import argparse
import
logging
from
collections
import
OrderedDict
import
numpy
as
np
import
torch
from
mmcv
import
Config
from
mmcv.torchpack
import
Runner
,
obj_from_dict
...
...
@@ -53,6 +54,12 @@ def get_logger(log_level):
return
logger
def
set_random_seed
(
seed
):
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Train a detector'
)
parser
.
add_argument
(
'config'
,
help
=
'train config file path'
)
...
...
@@ -63,6 +70,7 @@ def parse_args():
help
=
'whether to add a validate phase'
)
parser
.
add_argument
(
'--gpus'
,
type
=
int
,
default
=
1
,
help
=
'number of gpus to use'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
help
=
'random seed'
)
parser
.
add_argument
(
'--launcher'
,
choices
=
[
'none'
,
'pytorch'
,
'slurm'
,
'mpi'
],
...
...
@@ -84,6 +92,11 @@ def main():
logger
=
get_logger
(
cfg
.
log_level
)
# set random seed if specified
if
args
.
seed
is
not
None
:
logger
.
info
(
'Set random seed to {}'
.
format
(
args
.
seed
))
set_random_seed
(
args
.
seed
)
# init distributed environment if necessary
if
args
.
launcher
==
'none'
:
dist
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment