ModelZoo / SOLOv2-pytorch · Commit 904d875a

modify distributed training api and use coalesced all_reduce

Authored Sep 24, 2018 by Kai Chen
Parent: 15e9d026
Showing 6 changed files with 93 additions and 56 deletions (+93 -56)
mmdet/core/utils/dist_utils.py         +60 -22
mmdet/datasets/loader/build_loader.py  +2  -2
tools/configs/r50_fpn_frcnn_1x.py      +1  -2
tools/configs/r50_fpn_maskrcnn_1x.py   +1  -2
tools/configs/r50_fpn_rpn_1x.py        +1  -2
tools/train.py                         +28 -26
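Taken together, the diffs below make two related changes: init_dist() stops taking explicit world_size / rank / master_ip / port arguments and instead dispatches on a launcher name, reading rendezvous details from the environment, and the per-parameter average_gradients() / broadcast_params() helpers give way to a single reduce_grads() built on a coalesced all_reduce. A small runnable summary of the renamed public surface, copied from the two __all__ lists in the first file:

# Public names exported by mmdet/core/utils/dist_utils.py before and after
# this commit (copied from the two __all__ lists in the diff below).
OLD_ALL = ['init_dist', 'average_gradients', 'broadcast_params',
           'DistOptimizerHook', 'DistSamplerSeedHook']
NEW_ALL = ['init_dist', 'reduce_grads', 'DistOptimizerHook',
           'DistSamplerSeedHook']

print('removed:', sorted(set(OLD_ALL) - set(NEW_ALL)))  # average_gradients, broadcast_params
print('added:  ', sorted(set(NEW_ALL) - set(OLD_ALL)))  # reduce_grads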
mmdet/core/utils/dist_utils.py  (+60 -22)

 import os
+from collections import OrderedDict
+
 import torch
 import torch.multiprocessing as mp
 import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.nn.utils import clip_grad

 from mmcv.torchpack import Hook, OptimizerHook

 __all__ = [
-    'init_dist', 'average_gradients', 'broadcast_params', 'DistOptimizerHook',
-    'DistSamplerSeedHook'
+    'init_dist', 'reduce_grads', 'DistOptimizerHook', 'DistSamplerSeedHook'
 ]


-def init_dist(world_size,
-              rank,
-              backend='gloo',
-              master_ip='127.0.0.1',
-              port=29500):
+def init_dist(launcher, backend='nccl', **kwargs):
     if mp.get_start_method(allow_none=True) is None:
         mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'mpi':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_pytorch(backend, **kwargs)
+    else:
+        raise ValueError('Invalid launcher type: {}'.format(launcher))
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['RANK'])
     num_gpus = torch.cuda.device_count()
     torch.cuda.set_device(rank % num_gpus)
-    os.environ['MASTER_ADDR'] = master_ip
-    os.environ['MASTER_PORT'] = str(port)
-    if backend == 'nccl':
-        dist.init_process_group(backend='nccl')
-    else:
-        dist.init_process_group(
-            backend='gloo', rank=rank, world_size=world_size)
+    dist.init_process_group(backend=backend, **kwargs)


-def average_gradients(model):
-    for param in model.parameters():
-        if param.requires_grad and not (param.grad is None):
-            dist.all_reduce(param.grad.data)
+def _init_dist_mpi(backend, **kwargs):
+    raise NotImplementedError


-def broadcast_params(model):
-    for p in model.state_dict().values():
-        dist.broadcast(p, 0)
+def _init_dist_slurm(backend, **kwargs):
+    raise NotImplementedError
+
+
+# modified from https://github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py#L9
+def coalesce_all_reduce(tensors):
+    buckets = OrderedDict()
+    for tensor in tensors:
+        tp = tensor.type()
+        if tp not in buckets:
+            buckets[tp] = []
+        buckets[tp].append(tensor)
+
+    for tp in buckets:
+        bucket = buckets[tp]
+        coalesced = _flatten_dense_tensors(bucket)
+        dist.all_reduce(coalesced)
+        coalesced /= dist.get_world_size()
+        for buf, synced in zip(bucket,
+                               _unflatten_dense_tensors(coalesced, bucket)):
+            buf.copy_(synced)
+
+
+def reduce_grads(model, coalesce=True):
+    grads = [
+        param.grad.data for param in model.parameters()
+        if param.requires_grad and param.grad is not None
+    ]
+    if coalesce:
+        coalesce_all_reduce(grads)
+    else:
+        for tensor in grads:
+            dist.all_reduce(tensor)


 class DistOptimizerHook(OptimizerHook):

+    def __init__(self, grad_clip=None, coalesce=True):
+        self.grad_clip = grad_clip
+        self.coalesce = coalesce
+
     def after_train_iter(self, runner):
         runner.optimizer.zero_grad()
         runner.outputs['loss'].backward()
-        average_gradients(runner.model)
+        reduce_grads(runner.model, self.coalesce)
         if self.grad_clip is not None:
             clip_grad.clip_grad_norm_(
                 filter(lambda p: p.requires_grad, runner.model.parameters()),
...
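The heart of the change is coalesce_all_reduce(): instead of issuing one dist.all_reduce per gradient tensor, same-typed gradients are flattened into a single buffer, reduced in one call, divided by the world size, and copied back in place. The per-type buckets exist because _flatten_dense_tensors concatenates tensors into one contiguous buffer, which only works for tensors of matching type. Below is a minimal, self-contained sketch of that pattern; the single-process 'gloo' group and the toy tensors are only there so it runs outside a real training job and are not part of the commit.

# Minimal single-process sketch of the coalesced all-reduce pattern used by
# coalesce_all_reduce(): flatten, reduce once, average, scatter back.
import os

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29501')
dist.init_process_group('gloo', rank=0, world_size=1)  # 1-process group, only so this runs standalone

grads = [torch.ones(3), torch.full((2, 2), 4.0)]    # stand-ins for param.grad.data

coalesced = _flatten_dense_tensors(grads)           # one contiguous buffer
dist.all_reduce(coalesced)                          # a single communication call
coalesced /= dist.get_world_size()                  # average across ranks
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)                               # write the averaged values back in place

print(grads)  # unchanged here, since world_size == 1
dist.destroy_process_group()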
mmdet/datasets/loader/build_loader.py  (+2 -2)

 from functools import partial

+from mmcv.torchpack import get_dist_info
 from torch.utils.data import DataLoader

 from .collate import collate
...
@@ -11,10 +12,9 @@ def build_dataloader(dataset,
                      workers_per_gpu,
                      num_gpus,
                      dist=True,
-                     world_size=1,
-                     rank=0,
                      **kwargs):
     if dist:
+        rank, world_size = get_dist_info()
         sampler = DistributedGroupSampler(dataset, imgs_per_gpu, world_size,
                                           rank)
         batch_size = imgs_per_gpu
...
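With this change build_dataloader no longer needs world_size and rank threaded through its signature; when dist=True it asks mmcv.torchpack.get_dist_info() for them. The snippet below is a hypothetical stand-in for that helper, shown only to make the new calling convention concrete; mmcv's actual implementation may differ.

# Hypothetical equivalent of mmcv.torchpack.get_dist_info(), used here purely
# to illustrate what build_dataloader now expects back: (rank, world_size),
# with a (0, 1) fallback when no process group has been initialized.
import torch.distributed as dist


def get_dist_info_sketch():
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1


rank, world_size = get_dist_info_sketch()
print(rank, world_size)  # 0 1 when run outside a distributed job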
tools/configs/r50_fpn_frcnn_1x.py  (+1 -2)

...
@@ -121,8 +121,7 @@ log_config = dict(
 # yapf:enable
 # runtime settings
 total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='nccl')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_faster_rcnn_r50_1x'
 load_from = None
...
tools/configs/r50_fpn_maskrcnn_1x.py  (+1 -2)

...
@@ -134,8 +134,7 @@ log_config = dict(
 # yapf:enable
 # runtime settings
 total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='nccl', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='nccl')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_mask_rcnn_r50_1x'
 load_from = None
...
tools/configs/r50_fpn_rpn_1x.py  (+1 -2)

...
@@ -100,8 +100,7 @@ log_config = dict(
 # yapf:enable
 # runtime settings
 total_epochs = 12
-device_ids = range(8)
-dist_params = dict(backend='gloo', port='29500', master_ip='127.0.0.1')
+dist_params = dict(backend='gloo')
 log_level = 'INFO'
 work_dir = './work_dirs/fpn_rpn_r50_1x'
 load_from = None
...
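All three configs drop device_ids and the explicit master_ip/port from dist_params: after this commit the rendezvous address is expected to come from the environment, which a launcher such as torch.distributed.launch exports, and _init_dist_pytorch reads the process rank from RANK. A minimal, single-process sketch of that environment contract follows; the variables are set by hand and 'gloo' is used only so it runs without GPUs.

# Single-process sketch of the environment that the new _init_dist_pytorch()
# path relies on. In real use torch.distributed.launch (or another launcher)
# exports these; they are set by hand here only so the sketch is runnable.
import os

import torch.distributed as dist

os.environ.setdefault('RANK', '0')                  # read by _init_dist_pytorch
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')   # replaces master_ip in the old configs
os.environ.setdefault('MASTER_PORT', '29500')       # replaces port in the old configs

dist.init_process_group(backend='gloo')             # rendezvous info comes from the env
print('rank {} of {}'.format(dist.get_rank(), dist.get_world_size()))
dist.destroy_process_group()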
tools/train.py  (+28 -26)

...
@@ -39,9 +39,7 @@ def batch_processor(model, data, train_mode):
     loss, log_vars = parse_losses(losses)

     outputs = dict(
-        loss=loss / args.world_size,
+        loss=loss,
         log_vars=log_vars,
         num_samples=len(data['img'].data))

     return outputs
...
@@ -54,61 +52,65 @@ def parse_args():
         action='store_true',
         help='whether to add a validate phase')
     parser.add_argument(
-        '--dist', action='store_true', help='use distributed training or not')
-    parser.add_argument(
-        '--world-size', default=1, type=int)
+        '--gpus', type=int, default=1, help='number of gpus to use')
     parser.add_argument(
-        '--rank', default=0, type=int)
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     return args


-args = parse_args()
-
-
 def main():
-    # get config from file
+    args = parse_args()
     cfg = Config.fromfile(args.config)
-    cfg.update(world_size=args.world_size, rank=args.rank)
+    cfg.update(gpus=args.gpus)

     # init distributed environment if necessary
-    if args.dist:
-        print('Enable distributed training.')
-        init_dist(args.world_size, args.rank, **cfg.dist_params)
-    else:
+    if args.launcher == 'none':
+        dist = False
         print('Disabled distributed training.')
+    else:
+        dist = True
+        print('Enabled distributed training.')
+        init_dist(args.launcher, **cfg.dist_args)

     # prepare data loaders
     train_dataset = obj_from_dict(cfg.data.train, datasets)
     data_loaders = [
         build_dataloader(
             train_dataset, cfg.data.imgs_per_gpu,
             cfg.data.workers_per_gpu,
-            len(cfg.device_ids), args.dist, cfg.world_size, cfg.rank)
+            cfg.gpus, dist)
     ]
     if args.validate:
         val_dataset = obj_from_dict(cfg.data.val, datasets)
         data_loaders.append(
             build_dataloader(
                 val_dataset, cfg.data.imgs_per_gpu,
                 cfg.data.workers_per_gpu,
-                len(cfg.device_ids), args.dist, cfg.world_size, cfg.rank))
+                cfg.gpus, dist))

     # build model
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    if args.dist:
+    if dist:
         model = MMDistributedDataParallel(
-            model, device_ids=[cfg.rank], broadcast_buffers=False).cuda()
+            model, device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False).cuda()
     else:
-        model = MMDataParallel(model, device_ids=cfg.device_ids).cuda()
+        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

     # build runner
     runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                     cfg.log_level)
     # register hooks
     optimizer_config = DistOptimizerHook(
-        **cfg.optimizer_config) if args.dist else cfg.optimizer_config
+        **cfg.optimizer_config) if dist else cfg.optimizer_config
     runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config)
-    if args.dist:
+    if dist:
         runner.register_hook(DistSamplerSeedHook())
     if cfg.resume_from:
...
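The training entry point now selects distributed mode via --launcher instead of a --dist flag plus a hand-fed world size and rank; --local_rank exists because torch.distributed.launch passes that argument to every worker it spawns. The argparse sketch below reconstructs the new CLI surface from the hunk above so it can be tried standalone; the positional config argument, the parser description, and the '--validate' flag name (inferred from args.validate) are assumptions, since those lines sit outside the changed hunk.

# Standalone reconstruction of the CLI added in this commit. The '--gpus',
# '--launcher' and '--local_rank' options are verbatim from the diff; the
# positional 'config' argument, the description, and the '--validate' flag
# name are assumed, as they are not shown in the changed hunk.
import argparse

parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')          # assumed
parser.add_argument('--validate', action='store_true',
                    help='whether to add a validate phase')
parser.add_argument('--gpus', type=int, default=1,
                    help='number of gpus to use')
parser.add_argument('--launcher',
                    choices=['none', 'pytorch', 'slurm', 'mpi'],
                    default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)

# e.g. roughly what a worker spawned by torch.distributed.launch would receive:
args = parser.parse_args(
    ['tools/configs/r50_fpn_frcnn_1x.py', '--launcher', 'pytorch', '--local_rank', '0'])
print(args.launcher, args.gpus, args.local_rank)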