OpenDAS / fairscale · Commits

Commit 4402c410 (unverified)
Authored Dec 15, 2020 by Benjamin Lefaudeux; committed by GitHub on Dec 15, 2020

[cleanup] ShardedDDP - inline gatekeeper (#248)
Parent: f74afebb

Showing 1 changed file with 9 additions and 11 deletions (+9 -11).
fairscale/nn/data_parallel/sharded_ddp.py (view file @ 4402c410)

@@ -15,7 +15,6 @@ from typing import Any, Callable, Generator, List, Tuple, Union
 import torch
 from torch import nn
-from torch.autograd import Variable
 import torch.distributed as dist
 from torch.nn import Parameter
@@ -177,11 +176,6 @@ class ShardedDataParallel(nn.Module):
         Either way a delayed action is necessary and is passed as a callback.
         """

-        def gatekeeper() -> None:
-            # Make sure that all the asynchronous calls have concluded before moving on. Consume the futures
-            # and execute the delayed actions (release gradients, unroll the buckets)
-            Variable._execution_engine.queue_callback(optimizer._consume_work_handles)
-
         def reduce_direct(*_: Any) -> None:
             # Skip gradient reduction, do not alter status flags
             if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
@@ -206,9 +200,11 @@ class ShardedDataParallel(nn.Module):
                     )
                 )

-                # If all the reduce operations have been called, add the gatekeeper
+                # If all the reduce operations have been called,
+                # make sure that all the asynchronous calls have concluded before moving on
+                # and execute the delayed actions (release gradients, unroll the buckets)
                 if len(optimizer.work_handles) == optimizer._max_work_handles:
-                    gatekeeper()
+                    optimizer._consume_work_handles()

         # Bucket, update status, and possibly unroll the results
         def reduce_bucket(*_: Any) -> None:
@@ -240,9 +236,11 @@ class ShardedDataParallel(nn.Module):
                     )
                 )

-                # If all the reduce operations have been called, add the gatekeeper
+                # If all the reduce operations have been called,
+                # make sure that all the asynchronous calls have concluded before moving on
+                # and execute the delayed actions (release gradients, unroll the buckets)
                 if len(optimizer.work_handles) == optimizer._max_work_handles:
-                    gatekeeper()
+                    optimizer._consume_work_handles()

         return reduce_bucket if should_bucket else reduce_direct
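For context, the change removes an indirection: previously, once the last asynchronous gradient reduce had been queued, a `gatekeeper` closure registered `optimizer._consume_work_handles` as an end-of-backward callback via `Variable._execution_engine.queue_callback`; after this commit the work handles are consumed inline at that same point. The sketch below is a minimal illustration of the two patterns using a toy work-handle list. `_ToyOptimizer`, `on_grad_ready`, `launch_reduce`, and the callable handles are hypothetical stand-ins, not fairscale's actual optimizer or `Workhandle` types.

```python
from typing import Callable, List

from torch.autograd import Variable


class _ToyOptimizer:
    """Hypothetical stand-in for the sharded optimizer's work-handle bookkeeping."""

    def __init__(self, max_work_handles: int) -> None:
        self.work_handles: List[Callable[[], None]] = []  # pending async reduces
        self._max_work_handles = max_work_handles

    def _consume_work_handles(self) -> None:
        # Flush every pending handle: in real ShardedDDP this waits on the async
        # reduce and runs its delayed action (release gradients, unroll the
        # buckets); here each handle is just a callable. Then reset the list.
        for handle in self.work_handles:
            handle()
        self.work_handles.clear()


optimizer = _ToyOptimizer(max_work_handles=2)


def gatekeeper() -> None:
    # Old pattern (removed by this commit): defer the flush to the end of the
    # backward pass by queueing it on the autograd execution engine. Only
    # meaningful while a backward pass is actually running.
    Variable._execution_engine.queue_callback(optimizer._consume_work_handles)


def on_grad_ready(launch_reduce: Callable[[], None]) -> None:
    # New pattern: flush inline as soon as the last expected reduce is queued.
    optimizer.work_handles.append(launch_reduce)
    if len(optimizer.work_handles) == optimizer._max_work_handles:
        optimizer._consume_work_handles()


if __name__ == "__main__":
    # Two "gradients" become ready; the second one triggers the inline flush.
    on_grad_ready(lambda: print("reduce #1 done"))
    on_grad_ready(lambda: print("reduce #2 done"))
```

As the diff shows, inlining drops both the `gatekeeper` closure and the `torch.autograd.Variable` import: the handles are now consumed from within the last gradient hook itself rather than from an end-of-backward autograd callback.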