OpenDAS / apex

Commit 3a08d827, authored May 01, 2018 by Christian Sarofeen
Parent: dd05dbdb

Distributed refactor.
Changes: 1 changed file with 34 additions and 39 deletions

apex/parallel/distributed.py (+34 -39)
@@ -4,6 +4,7 @@ import torch.distributed as dist
 from torch.nn.modules import Module
 from torch.autograd import Variable
 def flat_dist_call(tensors, call, extra_args=None):
     flat_dist_call.warn_on_half = True
     buckets = {}
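Only the signature and the first statements of flat_dist_call are visible in this hunk. Its job, as used by the callers further down, is to take a list of tensors (here, gradients), run a single collective over them, and write the results back. Below is a minimal sketch of that flatten-then-reduce pattern, assuming PyTorch's torch._utils flatten helpers; it illustrates the technique only and is not the function body from the commit.

# Hypothetical flatten-then-reduce helper; the "_sketch" name is not from the commit.
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def flat_dist_call_sketch(tensors, call, extra_args=None):
    # Group tensors by type so each collective runs over one contiguous buffer.
    buckets = {}
    for tensor in tensors:
        buckets.setdefault(tensor.type(), []).append(tensor)

    for bucket in buckets.values():
        coalesced = _flatten_dense_tensors(bucket)
        if extra_args is not None:
            call(coalesced, *extra_args)          # e.g. dist.all_reduce(coalesced, group)
        else:
            call(coalesced)                       # e.g. dist.all_reduce(coalesced)
        # Scatter the reduced values back into the original tensors.
        for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
            buf.copy_(synced)

Flattening per type keeps each all_reduce to one contiguous buffer, which is why the call sites below only ever hand this function a plain list of .grad.data tensors.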
@@ -46,12 +47,12 @@ class DistributedDataParallel(Module):
     Args:
         module: Network definition to be run in multi-gpu/distributed mode.
-        message_size (Default = 10000000): Minimum number of elements in a communication bucket.
+        message_size (Default = 100e6): Minimum number of elements in a communication bucket.
     """
-    def __init__(self, module, message_size=10000000):
+    def __init__(self, module, message_size=100000000):
         super(DistributedDataParallel, self).__init__()
         self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
@@ -94,18 +95,18 @@ class DistributedDataParallel(Module):
                 return
             self.needs_reduction = False
-            ready = []
-            for i in range(len(self.param_state)):
-                if self.param_state[i] == 1:
-                    param = self.param_refs[self.record[i]]
-                    if param.grad is not None:
-                        ready.append(param.grad.data)
-            if(len(ready) > 0):
+            grads = []
+            for i in range(self.ready_end, len(self.param_state)):
+                param = self.param_list[self.record[i]]
+                if param.grad is not None:
+                    grads.append(param.grad.data)
+            grads = [param.grad.data for param in self.ready_params] + grads
+            if(len(grads) > 0):
                 orig_stream = torch.cuda.current_stream()
                 with torch.cuda.stream(self.reduction_stream):
                     self.reduction_stream.wait_stream(orig_stream)
-                    flat_dist_call(ready, dist.all_reduce)
+                    flat_dist_call(grads, dist.all_reduce)
                 torch.cuda.current_stream().wait_stream(self.reduction_stream)
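This hunk only swaps the list handed to flat_dist_call (ready becomes grads, now seeded from self.ready_params); the surrounding stream choreography is unchanged. Isolated from the class, that choreography is the standard produce-on-one-stream, reduce-on-a-side-stream pattern sketched below; the function and variable names are illustrative.

# Minimal sketch of the side-stream handshake used in the hunk above.
import torch

reduction_stream = torch.cuda.Stream()

def reduce_on_side_stream(grads, reduce_fn):
    orig_stream = torch.cuda.current_stream()
    with torch.cuda.stream(reduction_stream):
        # Do not let the side stream read grads before the producing stream's
        # already-enqueued work (the backward kernels) has completed.
        reduction_stream.wait_stream(orig_stream)
        reduce_fn(grads)
    # Work queued later on the original stream must wait for the reduction,
    # mirroring the final wait_stream call in the diff.
    torch.cuda.current_stream().wait_stream(reduction_stream)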
@@ -129,30 +130,25 @@ class DistributedDataParallel(Module):
     def comm_ready_buckets(self):
-        ready = []
-        counter = 0
-        while counter < len(self.param_state) and self.param_state[counter] == 2:
-            counter += 1
-        while counter < len(self.param_state) and self.param_state[counter] == 1:
-            ready.append(counter)
-            counter += 1
-        if not ready:
+        if self.param_state[self.ready_end] == 0:
+            return
+        while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
+            self.ready_params.append(self.param_refs[self.record[self.ready_end]])
+            self.ready_numel += self.ready_params[-1].numel()
+            self.ready_end += 1
+        if self.ready_numel < self.message_size:
             return
-        grads = []
-        for ind in ready:
-            param_ind = self.record[ind]
-            if self.param_list[param_ind].grad is not None:
-                grads.append(self.param_list[param_ind].grad.data)
+        grads = [param.grad.data for param in self.ready_params]
         bucket = []
         bucket_inds = []
         while grads:
             bucket.append(grads.pop(0))
-            bucket_inds.append(ready.pop(0))
             cumm_size = 0
             for ten in bucket:
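The rewritten comm_ready_buckets drops the full rescan of param_state (the counter loops on the left) in favour of rolling bookkeeping: ready_end walks forward over parameters whose gradients have arrived (state 1), the corresponding parameters are appended to ready_params, and ready_numel accumulates their element counts so the method can return early until message_size elements are pending. The toy class below models just that bookkeeping; the class name, the bounds guard, the single-bucket reduction, and the window reset at the end are illustrative simplifications (the real method goes on to split grads into size-limited buckets via the cumm_size loop, and the visible hunks do not show where ready_start and ready_numel are advanced after a reduction).

# Toy model of the rolling-window bucketing introduced above; not the apex class.
# param_state codes as used in the diff: 0 = grad not ready, 1 = ready, 2 = reduced.
import torch

class BucketWindow:
    def __init__(self, params, message_size=100000000):
        self.param_refs = params
        self.param_state = [0] * len(params)
        self.record = list(range(len(params)))   # call-order remap; identity in this toy
        self.message_size = message_size
        self.ready_start = 0
        self.ready_end = 0
        self.ready_params = []
        self.ready_numel = 0

    def mark_ready(self, i):
        self.param_state[i] = 1

    def comm_ready_buckets(self):
        # Nothing new at the front of the window yet (bounds guard added for the toy).
        if self.ready_end >= len(self.param_state) or self.param_state[self.ready_end] == 0:
            return None
        # Extend the window over contiguously ready parameters.
        while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
            p = self.param_refs[self.record[self.ready_end]]
            self.ready_params.append(p)
            self.ready_numel += p.numel()
            self.ready_end += 1
        # Hold the collective back until enough elements are pending.
        if self.ready_numel < self.message_size:
            return None
        bucket = [p.grad.data for p in self.ready_params]
        # ... all_reduce the bucket here, then mark it reduced and slide the window.
        for i in range(self.ready_start, self.ready_start + len(bucket)):
            self.param_state[i] = 2
            self.ready_params.pop(0)
        self.ready_start += len(bucket)
        self.ready_numel = 0
        return bucket

# Tiny demo: three 4-element parameters whose gradients become ready one at a time.
params = [torch.nn.Parameter(torch.zeros(4)) for _ in range(3)]
for p in params:
    p.grad = torch.ones_like(p)
win = BucketWindow(params, message_size=8)
for i in range(3):
    win.mark_ready(i)
    print(i, "reduced" if win.comm_ready_buckets() is not None else "held back")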
@@ -168,17 +164,11 @@ class DistributedDataParallel(Module):
             with torch.cuda.stream(self.reduction_stream):
                 flat_dist_call(bucket, dist.all_reduce)
-            for ind in bucket_inds:
-                self.param_state[ind] = 2
+            for i in range(self.ready_start, self.ready_start + len(bucket)):
+                self.param_state[i] = 2
+                self.ready_params.pop(0)
     def forward(self, *inputs, **kwargs):
-        """
-        Forward function for DDP.
-        Args:
-            inputs: inputs that match the module's passed in for initialization.
-            kwargs: kwargs that match the module's passed in for initialization.
-        """
         param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
@@ -196,4 +186,9 @@ class DistributedDataParallel(Module):
         self.param_refs = param_list
         self.needs_reduction = True
+        self.ready_start = 0
+        self.ready_end = 0
+        self.ready_params = []
+        self.ready_numel = 0
         return self.module(*inputs, **kwargs)
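The final hunk resets the window bookkeeping (ready_start, ready_end, ready_params, ready_numel) at the top of every forward pass, so each iteration rebuilds its buckets from scratch as gradients arrive during backward. Put together, a training step with this wrapper looks roughly like the sketch below; the process-group setup, model, data, and optimizer are assumed stand-ins, not part of the commit, and the import path follows apex's README.

# Illustrative end-to-end step with the refactored wrapper; everything except
# DistributedDataParallel itself is an assumed stand-in.
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel

dist.init_process_group(backend="nccl")      # assumes rank/world_size come from env://
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(128, 10).cuda()
# Pass message_size explicitly to keep the old 1e7-element threshold; the new
# default after this commit is 1e8 elements.
model = DistributedDataParallel(model, message_size=10000000)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

data = torch.randn(32, 128).cuda()
target = torch.randint(0, 10, (32,)).cuda()

for _ in range(10):
    output = model(data)                     # forward() resets the ready_* bookkeeping
    loss = torch.nn.functional.cross_entropy(output, target)
    loss.backward()                          # grads are bucketed and all_reduced as they become ready
    optimizer.step()
    optimizer.zero_grad()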