Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
142cfdcc
Unverified
Commit
142cfdcc
authored
Mar 17, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Mar 17, 2021
Browse files
[refactor] removing dead or faulty code (#530)
parent
98223763
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
20 deletions
+1
-20
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+1
-20
No files found.
fairscale/nn/data_parallel/sharded_ddp.py
View file @
142cfdcc
...
@@ -13,13 +13,12 @@ import contextlib
...
@@ -13,13 +13,12 @@ import contextlib
import
functools
import
functools
from
itertools
import
chain
from
itertools
import
chain
import
logging
import
logging
from
typing
import
Any
,
Callable
,
Deque
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Deque
,
Dict
,
Generator
,
List
,
Optional
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.autograd
import
Variable
from
torch.autograd
import
Variable
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.nn
import
Parameter
from
fairscale.optim
import
OSS
from
fairscale.optim
import
OSS
from
fairscale.optim.utils
import
Bucket
,
Workhandle
from
fairscale.optim.utils
import
Bucket
,
Workhandle
...
@@ -367,33 +366,15 @@ class ShardedDataParallel(nn.Module):
...
@@ -367,33 +366,15 @@ class ShardedDataParallel(nn.Module):
self
.
_grad_to_be_reduced
=
[
True
for
_
in
self
.
_grad_to_be_reduced
]
self
.
_grad_to_be_reduced
=
[
True
for
_
in
self
.
_grad_to_be_reduced
]
self
.
_bucket_flush_callback_set
=
False
self
.
_bucket_flush_callback_set
=
False
# Do not reset the buckets
if
self
.
use_buckets
:
if
self
.
use_buckets
:
assert
self
.
_bucket_list
is
not
None
assert
self
.
_bucket_list
is
not
None
for
bucket
in
self
.
_bucket_list
:
for
bucket
in
self
.
_bucket_list
:
assert
(
self
.
accumulate_grads_flipped
or
not
self
.
training
or
self
.
should_accumulate_grads
or
bucket
.
sent
),
(
"A bucket failed to be sent, cannot continue as results would be wrong. "
+
"You can trye de-activating ShardedDDP buckets -set `reduce_buffer_size` to 0-"
+
"Please submit a GitHub issue, this should not happen"
)
bucket
.
reset
()
bucket
.
reset
()
if
not
self
.
should_accumulate_grads
:
if
not
self
.
should_accumulate_grads
:
self
.
accumulate_grads_flipped
=
False
self
.
accumulate_grads_flipped
=
False
def
_find_rank
(
self
,
param
:
Parameter
)
->
Tuple
[
OSS
,
int
]:
""" Look up where this parameter belongs to """
for
optim
in
self
.
sharded_optimizers
:
if
param
in
optim
.
param_to_rank
.
keys
():
return
optim
,
optim
.
param_to_rank
[
param
]
assert
False
,
"This parameter is not present in an optimizer, this should not happen"
return
(
None
,
-
1
)
def
_get_reduce_fn
(
self
,
index
:
int
,
param
:
torch
.
Tensor
,
dst_rank
:
int
)
->
Callable
:
def
_get_reduce_fn
(
self
,
index
:
int
,
param
:
torch
.
Tensor
,
dst_rank
:
int
)
->
Callable
:
"""
"""
Two possible backward hooks for a given parameter: either directly reduce to the appropriate rank,
Two possible backward hooks for a given parameter: either directly reduce to the appropriate rank,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment