Unverified commit 04001e76
Authored Apr 03, 2021 by Shruti Bhosale; committed by GitHub, Apr 03, 2021
[FSDP] Add gradient predivide factor to avoid overflow/underflow with large world size (#565)
Parent: 5a3df0da

Showing 1 changed file with 13 additions and 2 deletions.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py (+13, -2)
@@ -187,6 +187,8 @@ class FullyShardedDataParallel(nn.Module):
         self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
+        self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
+        self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
         self.numel_padded_per_param: List[int] = []
         self.compute_device = compute_device
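
For context, the two new attributes are tied together: gradient_postdivide_factor is defined as world_size / gradient_predivide_factor, so whatever scaling is applied before the gradient reduction is undone afterwards and the end-to-end result is still an average over world_size, as in PyTorch DDP. A minimal sketch in exact arithmetic (the world_size = 64 value and the factor it yields are illustrative, taken from the helper added in the next hunk):

    # Sketch only: pre-divide followed by post-divide still averages over world_size.
    world_size = 64
    pre = 8                              # get_gradient_predivide_factor(64)
    post = world_size / pre              # 8.0

    grad = 3.0                           # identical local gradient on each of the 64 ranks
    reduced = (grad / pre) * world_size  # the reduce-scatter sums the pre-divided grads
    averaged = reduced / post
    assert averaged == grad              # same as summing and dividing by world_size in DDP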
@@ -252,6 +254,12 @@ class FullyShardedDataParallel(nn.Module):
         # user explicitly requests the local state dict via local_state_dict().
         self._return_full_state_dict = True
 
+    def get_gradient_predivide_factor(self, world_size: int) -> int:
+        factor = 1
+        while world_size % factor == 0 and world_size / factor > factor:
+            factor = factor * 2
+        return factor
+
     @property
     def module(self) -> nn.Module:
         return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance
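
The helper keeps doubling the factor while it still divides world_size and is still smaller than world_size / factor, so for power-of-two world sizes it returns the smallest power of two that is at least the square root of world_size; for other sizes the result need not divide world_size exactly, which is why gradient_postdivide_factor above is typed as a float. A standalone copy for experimentation (same logic as the method added in this hunk, just lifted out of the class):

    def get_gradient_predivide_factor(world_size: int) -> int:
        # Identical to the committed method, shown as a free function so it runs standalone.
        factor = 1
        while world_size % factor == 0 and world_size / factor > factor:
            factor = factor * 2
        return factor

    for ws in (1, 8, 64, 1024, 65536, 6):
        pre = get_gradient_predivide_factor(ws)
        print(ws, pre, ws / pre)
    # 1     -> 1, 1.0
    # 8     -> 4, 2.0
    # 64    -> 8, 8.0
    # 1024  -> 32, 32.0
    # 65536 -> 256, 256.0
    # 6     -> 4, 1.5   (the factor need not divide world_size, hence the float post-divide)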
@@ -1069,9 +1077,9 @@ class FullyShardedDataParallel(nn.Module):
             # Cast grad to FP32.
             param.grad.data = param.grad.data.to(param.dtype)
 
-        if self.world_size > 1:
+        if self.gradient_predivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
-            param.grad.data.div_(self.world_size)
+            param.grad.data.div_(self.gradient_predivide_factor)
 
         callback_fn = functools.partial(self._post_reduction_hook, param)
         if param._is_sharded:
@@ -1099,6 +1107,9 @@ class FullyShardedDataParallel(nn.Module):
         assert param.grad is not None
         self.assert_state(TrainingState.BACKWARD_POST)
         param.grad.data = reduced_grad
+        if self.gradient_postdivide_factor > 1:
+            # Average grad by world_size for consistency with PyTorch DDP.
+            param.grad.data.div_(self.gradient_postdivide_factor)
         # Cast grad to param's dtype (typically FP32). Note: we do this
         # before the move_grads_to_cpu step so that this entire hook remains
         # non-blocking. The downside is a bit more D2H transfer in that case.
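
Why split the division at all: with a very large world size, dividing FP16 gradients by world_size before the reduction (the behaviour the earlier hunk replaces) can flush small values to zero, while reducing them undivided risks overflow of the sum. Dividing by roughly the square root of world_size on each side of the reduction keeps intermediate values in range. A rough numeric sketch, not FSDP itself: the world size and gradient value are made up, and the reduce-scatter sum across ranks is simulated by multiplying by world_size.

    import torch

    world_size = 65536                  # hypothetical very large job
    pre = 256                           # get_gradient_predivide_factor(65536) == 256
    post = world_size / pre             # 256.0

    g = torch.tensor(1e-4, dtype=torch.float16)   # a small per-rank FP16 gradient

    # Old behaviour: divide by world_size before the reduction -> the quotient underflows to 0.
    naive = (g / world_size) * world_size
    print(naive)                                  # tensor(0., dtype=torch.float16)

    # New behaviour: divide by ~sqrt(world_size) before the reduction and again after it.
    summed = (g / pre) * world_size               # simulate the reduce-scatter sum
    averaged = summed / post
    print(averaged)                               # roughly 1e-4, up to FP16 rounding error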