OpenDAS / MMCV · Commits · a500a646

Commit a500a646 (unverified), authored Feb 12, 2020 by Cao Yuhang, committed by GitHub on Feb 12, 2020

Use official ddp to implement MMDDP (#184)

parent 89606a14

Showing 2 changed files with 56 additions and 46 deletions:

    mmcv/parallel/distributed.py              +2  -46
    mmcv/parallel/distributed_deprecated.py  +54   -0
mmcv/parallel/distributed.py  (+2, -46)
 # Copyright (c) Open-MMLab. All rights reserved.
 import torch
-import torch.distributed as dist
-import torch.nn as nn
-from torch._utils import (_flatten_dense_tensors, _take_tensors,
-                          _unflatten_dense_tensors)
+from torch.nn.parallel import DistributedDataParallel
 
 from .scatter_gather import scatter_kwargs
 
 
-class MMDistributedDataParallel(nn.Module):
-
-    def __init__(self,
-                 module,
-                 dim=0,
-                 broadcast_buffers=True,
-                 bucket_cap_mb=25):
-        super(MMDistributedDataParallel, self).__init__()
-        self.module = module
-        self.dim = dim
-        self.broadcast_buffers = broadcast_buffers
-
-        self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
-        self._sync_params()
-
-    def _dist_broadcast_coalesced(self, tensors, buffer_size):
-        for tensors in _take_tensors(tensors, buffer_size):
-            flat_tensors = _flatten_dense_tensors(tensors)
-            dist.broadcast(flat_tensors, 0)
-            for tensor, synced in zip(
-                    tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
-                tensor.copy_(synced)
-
-    def _sync_params(self):
-        module_states = list(self.module.state_dict().values())
-        if len(module_states) > 0:
-            self._dist_broadcast_coalesced(module_states,
-                                           self.broadcast_bucket_size)
-        if self.broadcast_buffers:
-            if torch.__version__ < '1.0':
-                buffers = [b.data for b in self.module._all_buffers()]
-            else:
-                buffers = [b.data for b in self.module.buffers()]
-            if len(buffers) > 0:
-                self._dist_broadcast_coalesced(buffers,
-                                               self.broadcast_bucket_size)
+class MMDistributedDataParallel(DistributedDataParallel):
 
     def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
     def forward(self, *inputs, **kwargs):
         inputs, kwargs = self.scatter(inputs, kwargs,
                                       [torch.cuda.current_device()])
         return self.module(*inputs[0], **kwargs[0])
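After this change, MMDistributedDataParallel inherits gradient bucketing and all-reduce from the official torch.nn.parallel.DistributedDataParallel and keeps only two overrides: scatter, so MMCV's DataContainer-aware scatter_kwargs is used, and a forward that feeds the scattered inputs to the wrapped module. A minimal usage sketch follows; the toy model, NCCL backend, and LOCAL_RANK handling are assumptions about the launch environment, not part of this commit:

# Minimal usage sketch for the new DDP-based MMDistributedDataParallel.
# Assumes one GPU per process and an external launcher that sets the usual
# torch.distributed env vars (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR);
# the toy model and tensor shapes are illustrative only.
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from mmcv.parallel import MMDistributedDataParallel


def main():
    dist.init_process_group(backend='nccl')  # init_method defaults to env://
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = nn.Linear(8, 2).cuda()
    # Constructor arguments such as device_ids come from the official
    # DistributedDataParallel; only input scattering is customized by MMCV.
    ddp_model = MMDistributedDataParallel(model, device_ids=[local_rank])

    x = torch.randn(4, 8).cuda()
    loss = ddp_model(x).sum()
    loss.backward()  # gradient all-reduce is handled by the DDP machinery


if __name__ == '__main__':
    main()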
mmcv/parallel/distributed_deprecated.py  (new file, mode 100644, +54, -0)
# Copyright (c) Open-MMLab. All rights reserved.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)

from .scatter_gather import scatter_kwargs


class MMDistributedDataParallel(nn.Module):

    def __init__(self,
                 module,
                 dim=0,
                 broadcast_buffers=True,
                 bucket_cap_mb=25):
        super(MMDistributedDataParallel, self).__init__()
        self.module = module
        self.dim = dim
        self.broadcast_buffers = broadcast_buffers

        self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
        self._sync_params()

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        for tensors in _take_tensors(tensors, buffer_size):
            flat_tensors = _flatten_dense_tensors(tensors)
            dist.broadcast(flat_tensors, 0)
            for tensor, synced in zip(
                    tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                tensor.copy_(synced)

    def _sync_params(self):
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)
        if self.broadcast_buffers:
            if torch.__version__ < '1.0':
                buffers = [b.data for b in self.module._all_buffers()]
            else:
                buffers = [b.data for b in self.module.buffers()]
            if len(buffers) > 0:
                self._dist_broadcast_coalesced(buffers,
                                               self.broadcast_bucket_size)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])
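The deprecated class synchronizes parameters and buffers once at construction: tensors are grouped into buckets of at most bucket_cap_mb worth of bytes, each bucket is flattened into one contiguous tensor, broadcast from rank 0, and the synced values are copied back in place. A self-contained, single-process sketch of that bucketing round trip follows; the dist.broadcast call is commented out so no process group is needed, and the sample tensors and bucket size are illustrative:

# Standalone sketch of the flatten/broadcast/unflatten loop inside
# _dist_broadcast_coalesced, runnable in a single process because the
# actual dist.broadcast is commented out; tensor sizes are illustrative.
import torch
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)

params = [torch.randn(4, 4), torch.randn(16), torch.randn(2, 3)]
originals = [p.clone() for p in params]
buffer_size = 64  # bucket capacity in bytes (25 * 1024 * 1024 in the class)

for bucket in _take_tensors(params, buffer_size):
    # Flatten each bucket into one contiguous tensor so a whole bucket
    # costs a single broadcast instead of one call per tensor.
    flat = _flatten_dense_tensors(bucket)
    # In the deprecated class, rank 0's values are sent here:
    # dist.broadcast(flat, 0)
    for tensor, synced in zip(bucket,
                              _unflatten_dense_tensors(flat, bucket)):
        tensor.copy_(synced)  # write the (would-be synced) values back

# Without a real broadcast the values must come back unchanged, which
# verifies that the flatten/unflatten round trip is lossless.
assert all(torch.equal(p, o) for p, o in zip(params, originals))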