fairscale · Commit 05ce7971 (unverified)
Authored Mar 08, 2021 by Myle Ott; committed by GitHub on Mar 08, 2021
[fix] FSDP: fix MoE corner case (fixes #467) (#501)
parent 02405740

Showing 2 changed files with 33 additions and 3 deletions:

  fairscale/nn/data_parallel/fully_sharded_data_parallel.py  +1 -2
  tests/nn/data_parallel/test_fsdp.py  +32 -1
fairscale/nn/data_parallel/fully_sharded_data_parallel.py

@@ -1127,7 +1127,7 @@ class FullyShardedDataParallel(nn.Module):
         if params is None:
             params = self.params
         self.has_full_params = False
-        current_stream = torch.cuda.current_stream()
+        self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["all_gather"]):
             for p in params:
                 if not p._is_sharded:  # e.g., world_size == 1
@@ -1140,7 +1140,6 @@ class FullyShardedDataParallel(nn.Module):
                 # unshard parameters, we should reuse the original Tensor
                 # Storage object and unshard it in-place. For now, just resize
                 # the Storage to 0 to save memory.
-                p._full_param_padded.record_stream(current_stream)
                 free_storage_(p._full_param_padded)

     @torch.no_grad()
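The one-line change above replaces the per-tensor record_stream bookkeeping with a single wait_stream call: the "all_gather" stream now waits for everything already queued on the current stream before the full parameters are freed, which closes the race that the delayed MoE expert exposed. Below is a minimal sketch of that synchronization pattern written against plain PyTorch rather than fairscale's internals; the function and variable names are illustrative, not part of FSDP, and a CUDA device is assumed.

import torch

def free_on_side_stream(buf: torch.Tensor, side_stream: torch.cuda.Stream) -> None:
    # Make the side stream wait for all work already queued on the current
    # stream, so releasing the storage cannot overlap with kernels that
    # still read from `buf`.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        # Mirror the idea behind fairscale's free_storage_ helper: shrink the
        # underlying Storage to zero elements instead of dropping the Tensor.
        buf.storage().resize_(0)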
tests/nn/data_parallel/test_fsdp.py

@@ -287,6 +287,20 @@ class TestComparisonToPyTorchDDP(DistributedTest):
         )
         spawn_and_init(test_fn)

+    @parameterized.expand([[{"checkpoint_act": False}], [{"checkpoint_act": True}]], name_func=rename_test)
+    def test_mixture_of_experts_with_delay_before_free(self, moe_config):
+        fsdp_config = {"mixed_precision": True}
+        test_fn = functools.partial(
+            self._test_identical_outputs,
+            functools.partial(MixtureOfExperts, delay_before_free_ms=250, **moe_config),
+            fsdp_config,
+            # MixtureOfExperts implements custom reduce logic, so the reference
+            # behavior should use that logic instead of PyTorch DDP.
+            ref_ddp_fn=self._dummy_ddp_fn,
+            norm_type=None,
+        )
+        spawn_and_init(test_fn)
+
     def test_mixture_of_experts_grad_clip_breaks(self):
         config = {"mixed_precision": True}
         test_fn = functools.partial(
@@ -760,9 +774,10 @@ class DummyDDP(nn.Module):
 class MixtureOfExperts(NestedWrappedModule):
-    def __init__(self, group, wrapper_config, checkpoint_act=False):
+    def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_free_ms=0):
         super().__init__(group, wrapper_config)
         self.group = group
+        self.delay_before_free_ms = delay_before_free_ms

         # "expert" params are different on each rank
         torch.manual_seed(42 + group.rank())
@@ -787,6 +802,22 @@ class MixtureOfExperts(NestedWrappedModule):
         self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8))

+    def forward(self, x):
+        if self.delay_before_free_ms > 0:
+            expert = self.module[2]
+            if isinstance(expert, FullyShardedDataParallel):
+                orig_free_full_params = self.module[2]._free_full_params
+
+                def _free_full_params_with_delay(*args):
+                    torch.cuda._sleep(int(self.delay_before_free_ms * get_cycles_per_ms()))
+                    return orig_free_full_params(*args)
+
+                assert hasattr(expert, "_free_full_params")
+                with mock.patch.object(expert, "_free_full_params", _free_full_params_with_delay):
+                    return self.module(x)
+
+        return self.module(x)
+
     def run_backward(self, loss):
         loss.backward()
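The new test reproduces the corner case by queueing a GPU-side busy wait before the expert's _free_full_params runs, so a missing stream dependency shows up as a numerical mismatch against the reference run. A small sketch of that delay-injection pattern in isolation follows; it assumes a CUDA device, and cycles_per_ms stands in for the calibration done by fairscale's get_cycles_per_ms test helper. Names here are illustrative, not fairscale API.

import unittest.mock as mock
import torch

def patch_with_gpu_delay(obj, method_name: str, delay_ms: int, cycles_per_ms: float):
    # Return a context manager that wraps `method_name` on `obj` so a busy-wait
    # kernel of roughly `delay_ms` milliseconds is queued on the current CUDA
    # stream before the original method executes.
    orig = getattr(obj, method_name)

    def delayed(*args, **kwargs):
        torch.cuda._sleep(int(delay_ms * cycles_per_ms))  # private PyTorch spin kernel, as used in the test
        return orig(*args, **kwargs)

    return mock.patch.object(obj, method_name, delayed)

Used around an FSDP-wrapped expert's _free_full_params (as the test does), the delay gives any unsynchronized consumer of the freed storage time to race ahead, which is exactly the failure mode the wait_stream fix prevents.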