Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SOLOv2-pytorch
Commits
16a6f7da
Unverified
Commit
16a6f7da
authored
Dec 25, 2019
by
Kai Chen
Committed by
GitHub
Dec 25, 2019
Browse files
add some docstring (#1869)
parent
629b9ff2
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
9 deletions
+63
-9
mmdet/apis/train.py
mmdet/apis/train.py
+34
-7
mmdet/datasets/loader/build_loader.py
mmdet/datasets/loader/build_loader.py
+22
-0
tools/train.py
tools/train.py
+7
-2
No files found.
mmdet/apis/train.py
View file @
16a6f7da
...
@@ -17,13 +17,6 @@ from mmdet.datasets import DATASETS, build_dataloader
...
@@ -17,13 +17,6 @@ from mmdet.datasets import DATASETS, build_dataloader
from
mmdet.models
import
RPN
from
mmdet.models
import
RPN
def set_random_seed(seed, deterministic=False):
    """Seed all random number generators used in training.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set ``torch.backends.cudnn.deterministic``
            to True and ``torch.backends.cudnn.benchmark`` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all CUDA devices; a no-op when no GPU is present.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        # Trade speed for reproducibility: force deterministic conv
        # algorithms and disable the benchmark-based autotuner.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def
get_root_logger
(
log_file
=
None
,
log_level
=
logging
.
INFO
):
def
get_root_logger
(
log_file
=
None
,
log_level
=
logging
.
INFO
):
logger
=
logging
.
getLogger
(
'mmdet'
)
logger
=
logging
.
getLogger
(
'mmdet'
)
# if the logger has been initialized, just return it
# if the logger has been initialized, just return it
...
@@ -45,6 +38,25 @@ def get_root_logger(log_file=None, log_level=logging.INFO):
...
@@ -45,6 +38,25 @@ def get_root_logger(log_file=None, log_level=logging.INFO):
return
logger
return
logger
def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set ``torch.backends.cudnn.deterministic``
            to True and ``torch.backends.cudnn.benchmark`` to False.
            Default: False.
    """
    # Feed the same seed to every RNG a training run may draw from.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed,
                     torch.cuda.manual_seed_all):
        seed_rng(seed)
    if not deterministic:
        return
    # Reproducible (but potentially slower) CUDNN behavior.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def
parse_losses
(
losses
):
def
parse_losses
(
losses
):
log_vars
=
OrderedDict
()
log_vars
=
OrderedDict
()
for
loss_name
,
loss_value
in
losses
.
items
():
for
loss_name
,
loss_value
in
losses
.
items
():
...
@@ -70,6 +82,21 @@ def parse_losses(losses):
...
@@ -70,6 +82,21 @@ def parse_losses(losses):
def
batch_processor
(
model
,
data
,
train_mode
):
def
batch_processor
(
model
,
data
,
train_mode
):
"""Process a data batch.
This method is required as an argument of Runner, which defines how to
process a data batch and obtain proper outputs. The first 3 arguments of
batch_processor are fixed.
Args:
model (nn.Module): A PyTorch model.
data (dict): The data batch in a dict.
train_mode (bool): Training mode or not. It may be useless for some
models.
Returns:
dict: A dict containing losses and log vars.
"""
losses
=
model
(
**
data
)
losses
=
model
(
**
data
)
loss
,
log_vars
=
parse_losses
(
losses
)
loss
,
log_vars
=
parse_losses
(
losses
)
...
...
mmdet/datasets/loader/build_loader.py
View file @
16a6f7da
...
@@ -21,8 +21,30 @@ def build_dataloader(dataset,
...
@@ -21,8 +21,30 @@ def build_dataloader(dataset,
dist
=
True
,
dist
=
True
,
shuffle
=
True
,
shuffle
=
True
,
**
kwargs
):
**
kwargs
):
"""Build PyTorch DataLoader.
In distributed training, each GPU/process has a dataloader.
In non-distributed training, there is only one dataloader for all GPUs.
Args:
dataset (Dataset): A PyTorch dataset.
imgs_per_gpu (int): Number of images on each GPU, i.e., batch size of
each GPU.
workers_per_gpu (int): How many subprocesses to use for data loading
for each GPU.
num_gpus (int): Number of GPUs. Only used in non-distributed training.
dist (bool): Distributed training/test or not. Default: True.
shuffle (bool): Whether to shuffle the data at every epoch.
Default: True.
kwargs: any keyword argument to be used to initialize DataLoader
Returns:
DataLoader: A PyTorch dataloader.
"""
if
dist
:
if
dist
:
rank
,
world_size
=
get_dist_info
()
rank
,
world_size
=
get_dist_info
()
# DistributedGroupSampler will definitely shuffle the data to satisfy
# that images on each GPU are in the same group
if
shuffle
:
if
shuffle
:
sampler
=
DistributedGroupSampler
(
dataset
,
imgs_per_gpu
,
sampler
=
DistributedGroupSampler
(
dataset
,
imgs_per_gpu
,
world_size
,
rank
)
world_size
,
rank
)
...
...
tools/train.py
View file @
16a6f7da
...
@@ -32,6 +32,10 @@ def parse_args():
...
@@ -32,6 +32,10 @@ def parse_args():
help
=
'number of gpus to use '
help
=
'number of gpus to use '
'(only applicable to non-distributed training)'
)
'(only applicable to non-distributed training)'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
None
,
help
=
'random seed'
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
None
,
help
=
'random seed'
)
parser
.
add_argument
(
'--deterministic'
,
action
=
'store_true'
,
help
=
'whether to set deterministic options for CUDNN backend.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--launcher'
,
'--launcher'
,
choices
=
[
'none'
,
'pytorch'
,
'slurm'
,
'mpi'
],
choices
=
[
'none'
,
'pytorch'
,
'slurm'
,
'mpi'
],
...
@@ -88,8 +92,9 @@ def main():
...
@@ -88,8 +92,9 @@ def main():
# set random seeds
# set random seeds
if
args
.
seed
is
not
None
:
if
args
.
seed
is
not
None
:
logger
.
info
(
'Set random seed to {}'
.
format
(
args
.
seed
))
logger
.
info
(
'Set random seed to {}, deterministic: {}'
.
format
(
set_random_seed
(
args
.
seed
)
args
.
seed
,
args
.
deterministic
))
set_random_seed
(
args
.
seed
,
deterministic
=
args
.
deterministic
)
model
=
build_detector
(
model
=
build_detector
(
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
cfg
.
model
,
train_cfg
=
cfg
.
train_cfg
,
test_cfg
=
cfg
.
test_cfg
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment