OpenDAS / apex · Commits · 3a08d827

Commit 3a08d827, authored May 01, 2018 by Christian Sarofeen

Distributed refactor.

Parent: dd05dbdb
Showing 1 changed file with 34 additions and 39 deletions.

apex/parallel/distributed.py  (+34 / -39)
@@ -4,6 +4,7 @@ import torch.distributed as dist
 from torch.nn.modules import Module
 from torch.autograd import Variable
 
 def flat_dist_call(tensors, call, extra_args=None):
     flat_dist_call.warn_on_half = True
     buckets = {}
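The rest of flat_dist_call is outside this hunk; judging by the signature and the buckets dict, it groups the tensors, applies the collective to each group as one flat buffer, and copies the results back. A minimal sketch of that coalescing pattern, assuming the flatten helpers in torch._utils and an already-initialized process group (an illustration, not the code in this file):

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def flat_all_reduce(tensors):
    """Bucket tensors by dtype, flatten each bucket into one contiguous
    buffer, all_reduce the buffer, then copy the reduced values back.

    Illustrative stand-in for flat_dist_call, not the apex implementation.
    """
    buckets = {}
    for t in tensors:
        buckets.setdefault(t.type(), []).append(t)

    for bucket in buckets.values():
        flat = _flatten_dense_tensors(bucket)
        dist.all_reduce(flat)  # one collective per dtype bucket
        for t, reduced in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
            t.copy_(reduced)
```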
@@ -41,17 +42,17 @@ class DistributedDataParallel(Module):
     and that torch.set_device has been used to set the device. Parameters are broadcasted
     to the other processes on initialization of DistributedDataParallel, and will be
     allreduced in buckets durring the backward pass.
 
     See https://github.com/csarofeen/examples/tree/apex/distributed for detailed usage.
 
     Args:
         module: Network definition to be run in multi-gpu/distributed mode.
-        message_size (Default = 10000000): Minimum number of elements in a communication bucket.
+        message_size (Default = 100e6): Minimum number of elements in a communication bucket.
 
     """
 
-    def __init__(self, module, message_size=10000000):
+    def __init__(self, module, message_size=100000000):
         super(DistributedDataParallel, self).__init__()
         self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
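The default bucket size moves from 1e7 to 1e8 gradient elements here, and the docstring is updated to match. For context, a minimal usage sketch of the knob; the import path, process-group setup, and model below are assumptions for illustration, not code from this commit:

```python
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel  # assumed import path

# Assumes rank/world size/master address are provided by the launcher's
# environment (e.g. env:// initialization).
dist.init_process_group(backend="nccl")
torch.cuda.set_device(0)

model = torch.nn.Linear(4096, 4096).cuda()

# message_size is the minimum number of gradient elements per
# communication bucket; after this commit the default is 100e6.
model = DistributedDataParallel(model, message_size=int(100e6))
```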
@@ -93,19 +94,19 @@ class DistributedDataParallel(Module):
             if not self.needs_reduction:
                 return
             self.needs_reduction = False
 
-            ready = []
-            for i in range(len(self.param_state)):
-                if self.param_state[i] == 1:
-                    param = self.param_refs[self.record[i]]
-                    if param.grad is not None:
-                        ready.append(param.grad.data)
+            grads = []
+            for i in range(self.ready_end, len(self.param_state)):
+                param = self.param_list[self.record[i]]
+                if param.grad is not None:
+                    grads.append(param.grad.data)
+            grads = [param.grad.data for param in self.ready_params] + grads
 
-            if(len(ready) > 0):
+            if(len(grads) > 0):
                 orig_stream = torch.cuda.current_stream()
                 with torch.cuda.stream(self.reduction_stream):
                     self.reduction_stream.wait_stream(orig_stream)
-                    flat_dist_call(ready, dist.all_reduce)
+                    flat_dist_call(grads, dist.all_reduce)
 
             torch.cuda.current_stream().wait_stream(self.reduction_stream)
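Both the old and the refactored hook launch the flattened all_reduce on a dedicated reduction stream and fence it against the compute stream on both ends. Pulled out of the class, the synchronization pattern looks roughly like this (standalone sketch; the gradient list and side stream are supplied by the caller, and the default process group is assumed to be initialized):

```python
import torch
import torch.distributed as dist

def allreduce_on_side_stream(grads, reduction_stream):
    """Run all_reduce on a side CUDA stream so it can overlap with compute."""
    orig_stream = torch.cuda.current_stream()
    with torch.cuda.stream(reduction_stream):
        # Don't start reducing before the producing stream has finished
        # writing the gradients.
        reduction_stream.wait_stream(orig_stream)
        for grad in grads:
            dist.all_reduce(grad)
    # Work queued later on the original stream must wait for the reductions.
    orig_stream.wait_stream(reduction_stream)
```

In the class itself, flat_dist_call(grads, dist.all_reduce) plays the role of the loop, coalescing the tensors before issuing the collective.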
@@ -129,30 +130,25 @@ class DistributedDataParallel(Module):
     def comm_ready_buckets(self):
-        ready = []
-        counter = 0
+        if self.param_state[self.ready_end] == 0:
+            return
 
-        while counter < len(self.param_state) and self.param_state[counter] == 2:
-            counter += 1
-
-        while counter < len(self.param_state) and self.param_state[counter] == 1:
-            ready.append(counter)
-            counter += 1
+        while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
+            self.ready_params.append(self.param_refs[self.record[self.ready_end]])
+            self.ready_numel += self.ready_params[-1].numel()
+            self.ready_end += 1
 
-        if not ready:
-            return
-
-        grads = []
-        for ind in ready:
-            param_ind = self.record[ind]
-            if self.param_list[param_ind].grad is not None:
-                grads.append(self.param_list[param_ind].grad.data)
+        if self.ready_numel < self.message_size:
+            return
+
+        grads = [param.grad.data for param in self.ready_params]
 
         bucket = []
         bucket_inds = []
         while grads:
             bucket.append(grads.pop(0))
-            bucket_inds.append(ready.pop(0))
 
             cumm_size = 0
             for ten in bucket:
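The refactored comm_ready_buckets keeps a running window of ready parameters (ready_params, ready_numel, ready_end) and defers communication until at least message_size elements have accumulated. A simplified, self-contained sketch of that accumulate-then-flush idea (the class name and flush callback are illustrative, not the API of this file):

```python
from typing import Callable, List
import torch

class GradBucketer:
    """Accumulate 'ready' gradients until they reach message_size elements,
    then hand the whole batch to a flush callback (e.g. a flat all_reduce).

    Illustrative stand-in for the bookkeeping in comm_ready_buckets.
    """

    def __init__(self, message_size: int,
                 flush: Callable[[List[torch.Tensor]], None]):
        self.message_size = message_size
        self.flush = flush
        self.ready: List[torch.Tensor] = []
        self.ready_numel = 0

    def mark_ready(self, grad: torch.Tensor) -> None:
        self.ready.append(grad)
        self.ready_numel += grad.numel()
        # Defer communication until the bucket is big enough to amortize
        # the per-collective launch latency.
        if self.ready_numel >= self.message_size:
            self.flush(self.ready)
            self.ready = []
            self.ready_numel = 0
```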
@@ -168,17 +164,11 @@ class DistributedDataParallel(Module):
             with torch.cuda.stream(self.reduction_stream):
                 flat_dist_call(bucket, dist.all_reduce)
 
-            for ind in bucket_inds:
-                self.param_state[ind] = 2
+            for i in range(self.ready_start, self.ready_start + len(bucket)):
+                self.param_state[i] = 2
+                self.ready_params.pop(0)
 
     def forward(self, *inputs, **kwargs):
-        """
-        Forward function for DDP.
-
-        Args:
-        inputs: inputs that match the module's passed in for initialization.
-        kwargs: kwargs that match the module's passed in for initialization.
-        """
         param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
@@ -195,5 +185,10 @@ class DistributedDataParallel(Module):
         self.param_state = [0 for i in range(len(param_list))]
         self.param_refs = param_list
         self.needs_reduction = True
 
+        self.ready_start = 0
+        self.ready_end = 0
+        self.ready_params = []
+        self.ready_numel = 0
+
         return self.module(*inputs, **kwargs)
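The new per-iteration counters (ready_start, ready_end, ready_params, ready_numel) are reset in forward before the wrapped module runs, and are then presumably advanced by the backward-time bookkeeping in comm_ready_buckets. The hook wiring itself is outside this diff; as background, a per-parameter gradient-ready callback can be attached in plain PyTorch roughly like this (a generic sketch, not the registration code used in this file):

```python
import torch

def attach_ready_hooks(module: torch.nn.Module, on_ready) -> None:
    """Call on_ready(param) once that parameter's gradient has been produced.

    Generic sketch using Tensor.register_hook; apex's actual registration
    logic lives elsewhere in distributed.py and may differ.
    """
    for param in module.parameters():
        if not param.requires_grad:
            continue

        def make_hook(p):
            def hook(grad):
                # grad is the freshly computed gradient for p; signal that
                # p is now ready for bucketing/communication.
                on_ready(p)
                return grad
            return hook

        param.register_hook(make_hook(param))
```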