Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
04001e76
Unverified
Commit
04001e76
authored
Apr 03, 2021
by
Shruti Bhosale
Committed by
GitHub
Apr 03, 2021
Browse files
[FSDP] Add gradient predivide factor to avoid overflow/underflow with large world size (#565)
parent
5a3df0da
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
2 deletions
+13
-2
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+13
-2
No files found.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
View file @
04001e76
...
...
@@ -187,6 +187,8 @@ class FullyShardedDataParallel(nn.Module):
self
.
buffer_dtype
=
buffer_dtype
or
self
.
compute_dtype
self
.
move_grads_to_cpu
=
cpu_offload
if
move_grads_to_cpu
is
None
else
move_grads_to_cpu
self
.
bucket_cap_mb
=
bucket_cap_mb
self
.
gradient_predivide_factor
:
int
=
self
.
get_gradient_predivide_factor
(
self
.
world_size
)
self
.
gradient_postdivide_factor
:
float
=
self
.
world_size
/
self
.
gradient_predivide_factor
self
.
numel_padded_per_param
:
List
[
int
]
=
[]
self
.
compute_device
=
compute_device
...
...
@@ -252,6 +254,12 @@ class FullyShardedDataParallel(nn.Module):
# user explicitly requests the local state dict via local_state_dict().
self
.
_return_full_state_dict
=
True
def
get_gradient_predivide_factor
(
self
,
world_size
:
int
)
->
int
:
factor
=
1
while
world_size
%
factor
==
0
and
world_size
/
factor
>
factor
:
factor
=
factor
*
2
return
factor
@
property
def
module
(
self
)
->
nn
.
Module
:
return
self
.
_fsdp_wrapped_module
# note: may be a FlattenParamsWrapper instance
...
...
@@ -1069,9 +1077,9 @@ class FullyShardedDataParallel(nn.Module):
# Cast grad to FP32.
param
.
grad
.
data
=
param
.
grad
.
data
.
to
(
param
.
dtype
)
if
self
.
world_size
>
1
:
if
self
.
gradient_predivide_factor
>
1
:
# Average grad by world_size for consistency with PyTorch DDP.
param
.
grad
.
data
.
div_
(
self
.
world_size
)
param
.
grad
.
data
.
div_
(
self
.
gradient_predivide_factor
)
callback_fn
=
functools
.
partial
(
self
.
_post_reduction_hook
,
param
)
if
param
.
_is_sharded
:
...
...
@@ -1099,6 +1107,9 @@ class FullyShardedDataParallel(nn.Module):
assert
param
.
grad
is
not
None
self
.
assert_state
(
TrainingState
.
BACKWARD_POST
)
param
.
grad
.
data
=
reduced_grad
if
self
.
gradient_postdivide_factor
>
1
:
# Average grad by world_size for consistency with PyTorch DDP.
param
.
grad
.
data
.
div_
(
self
.
gradient_postdivide_factor
)
# Cast grad to param's dtype (typically FP32). Note: we do this
# before the move_grads_to_cpu step so that this entire hook remains
# non-blocking. The downside is a bit more D2H transfer in that case.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment