Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
4402c410
Unverified
Commit
4402c410
authored
Dec 15, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Dec 15, 2020
Browse files
[cleanup] ShardedDDP - inline gatekeeper (#248)
parent
f74afebb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
11 deletions
+9
-11
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+9
-11
No files found.
fairscale/nn/data_parallel/sharded_ddp.py
View file @
4402c410
...
...
@@ -15,7 +15,6 @@ from typing import Any, Callable, Generator, List, Tuple, Union
import
torch
from
torch
import
nn
from
torch.autograd
import
Variable
import
torch.distributed
as
dist
from
torch.nn
import
Parameter
...
...
@@ -177,11 +176,6 @@ class ShardedDataParallel(nn.Module):
Either way a delayed action is necessary and is passed as a callback.
"""
def
gatekeeper
()
->
None
:
# Make sure that all the asynchronous calls have concluded before moving on. Consume the futures
# and execute the delayed actions (release gradients, unroll the buckets)
Variable
.
_execution_engine
.
queue_callback
(
optimizer
.
_consume_work_handles
)
def
reduce_direct
(
*
_
:
Any
)
->
None
:
# Skip gradient reduction, do not alter status flags
if
not
self
.
should_accumulate_grads
and
self
.
_grad_to_be_reduced
[
index
]:
...
...
@@ -206,9 +200,11 @@ class ShardedDataParallel(nn.Module):
)
)
# If all the reduce operations have been called, add the gatekeeper
# If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if
len
(
optimizer
.
work_handles
)
==
optimizer
.
_max_work_handles
:
gatekeeper
()
optimizer
.
_consume_work_handles
()
# Bucket, update status, and possibly unroll the results
def
reduce_bucket
(
*
_
:
Any
)
->
None
:
...
...
@@ -240,9 +236,11 @@ class ShardedDataParallel(nn.Module):
)
)
# If all the reduce operations have been called, add the gatekeeper
if
len
(
optimizer
.
work_handles
)
==
optimizer
.
_max_work_handles
:
gatekeeper
()
# If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if
len
(
optimizer
.
work_handles
)
==
optimizer
.
_max_work_handles
:
optimizer
.
_consume_work_handles
()
return
reduce_bucket
if
should_bucket
else
reduce_direct
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment