Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
b6dc98cf
Unverified
Commit
b6dc98cf
authored
Feb 26, 2021
by
Myle Ott
Committed by
GitHub
Feb 26, 2021
Browse files
[fix] fix FSDP state_dict/load_state_dict for nested wrapped instances (#440)
parent
93d115c6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
186 additions
and
48 deletions
+186
-48
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+122
-41
fairscale/utils/testing.py
fairscale/utils/testing.py
+1
-1
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp.py
+63
-6
No files found.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
View file @
b6dc98cf
...
...
@@ -29,6 +29,7 @@ from fairscale.utils.containers import (
)
from
fairscale.utils.parallel
import
chunk_and_pad
,
validate_process_group
from
fairscale.utils.reduce_scatter_bucketer
import
ReduceScatterBucketer
from
fairscale.utils.state_dict
import
replace_by_prefix_
if
TYPE_CHECKING
:
from
collections
import
OrderedDict
# noqa: F401
...
...
@@ -172,11 +173,11 @@ class FullyShardedDataParallel(nn.Module):
params
=
list
(
p
for
p
in
module
.
parameters
()
if
not
hasattr
(
p
,
"_is_sharded"
))
if
self
.
flatten_parameters
and
len
(
params
)
>
0
:
self
.
module
:
nn
.
Module
=
FlattenParamsWrapper
(
module
,
param_list
=
params
)
self
.
_fsdp_wrapped_
module
:
nn
.
Module
=
FlattenParamsWrapper
(
module
,
param_list
=
params
)
del
module
# free original module in case it helps garbage collection
self
.
params
=
[
self
.
module
.
flat_param
]
self
.
params
=
[
self
.
_fsdp_wrapped_
module
.
flat_param
]
else
:
self
.
module
=
module
self
.
_fsdp_wrapped_
module
=
module
self
.
params
=
params
# Shard module parameters in place
...
...
@@ -192,8 +193,23 @@ class FullyShardedDataParallel(nn.Module):
# pass. This will be False when inside the no_sync context manager.
self
.
require_backward_grad_sync
:
bool
=
True
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self
.
training_state
=
TrainingState
.
IDLE
# Register hook after state_dict() to remove the "_fsdp_wrapped_module."
# prefix and before load_state_dict() to add it back.
self
.
_register_state_dict_hook
(
_post_state_dict_hook
)
self
.
_register_load_state_dict_pre_hook
(
_pre_load_state_dict_hook
)
# Flag to indicate whether state_dict() should automatically summon the
# full params. This defaults to True, but may be set to False if the
# user explicitly requests the local state dict via local_state_dict().
self
.
_return_full_state_dict
=
True
@property
def module(self) -> nn.Module:
    """Return the module wrapped by this FSDP instance.

    The value is read from ``_fsdp_wrapped_module``; note that it may be a
    ``FlattenParamsWrapper`` instance rather than the module originally
    passed in.
    """
    wrapped: nn.Module = self._fsdp_wrapped_module
    return wrapped
@
torch
.
no_grad
()
def
_all_buffers_to
(
self
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
None
:
"""Move all buffers to the specified device and dtype, recursively."""
...
...
@@ -235,6 +251,7 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
self
.
_lazy_init
()
assert
self
.
_is_root
,
"clip_grad_norm should only be called on the root (parent) instance"
assert
self
.
training_state
==
TrainingState
.
IDLE
...
...
@@ -374,7 +391,7 @@ class FullyShardedDataParallel(nn.Module):
self
.
_reset_lazy_init
()
# TODO (Min): figuring out how to do typing for this overloaded function.
def
state_dict
(
self
,
*
args
,
**
kwargs
)
:
# type: ignore
def
state_dict
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
"OrderedDict[str, torch.Tensor]"
:
# type: ignore
"""
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
...
...
@@ -384,16 +401,28 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
with
self
.
summon_full_params
()
:
if
self
.
mixed_precision
:
# Buffers dtype stays consistent with parameters.
self
.
_all_buffers_to
(
dtype
=
torch
.
float32
)
state_dict
=
self
.
module
.
state_dict
(
*
args
,
**
kwargs
)
# We copy the state_dict since full param will be freed after
# we exit the summon_full_params() context.
for
key
in
state_dict
.
keys
():
state_dict
[
key
]
=
state_dict
[
key
].
clone
()
if
self
.
_return_full_state_dict
:
if
self
.
training_state
!=
TrainingState
.
SUMMON_FULL_PARAMS
:
with
self
.
summon_full_params
():
state_dict
=
super
().
state_dict
(
*
args
,
**
kwargs
)
else
:
torch
.
cuda
.
synchronize
()
self
.
_lazy_init
()
state_dict
=
super
().
state_dict
(
*
args
,
**
kwargs
)
else
:
torch
.
cuda
.
synchronize
()
self
.
_lazy_init
()
if
self
.
flatten_parameters
:
assert
isinstance
(
self
.
module
,
FlattenParamsWrapper
)
state_dict
=
self
.
module
.
flat_state_dict
(
*
args
,
**
kwargs
)
else
:
state_dict
=
super
().
state_dict
(
*
args
,
**
kwargs
)
if
self
.
mixed_precision
:
# In case we are in mixed precision, restore buffers back to fp16.
self
.
_all_buffers_to
(
dtype
=
self
.
compute_dtype
)
return
state_dict
...
...
@@ -405,12 +434,21 @@ class FullyShardedDataParallel(nn.Module):
so the resulting state_dict can only be loaded after the Module has been
wrapped with FullyShardedDataParallel.
"""
torch
.
cuda
.
synchronize
()
self
.
_lazy_init
()
if
self
.
flatten_parameters
:
return
self
.
module
.
flat_state_dict
(
*
args
,
**
kwargs
)
# type: ignore
else
:
return
self
.
module
.
state_dict
(
*
args
,
**
kwargs
)
with
contextlib
.
ExitStack
()
as
stack
:
# Tell any nested FSDP instances not to auto summon full params.
for
module
in
self
.
modules
():
# includes self
if
isinstance
(
module
,
FullyShardedDataParallel
):
stack
.
enter_context
(
module
.
_no_return_full_state_dict
())
return
self
.
state_dict
(
*
args
,
**
kwargs
)
@contextlib.contextmanager
def _no_return_full_state_dict(self) -> Generator:
    """Temporarily set ``_return_full_state_dict`` to ``False``.

    The value of the flag at entry is remembered and written back on exit,
    including when the managed body raises.
    """
    previous = self._return_full_state_dict
    self._return_full_state_dict = False
    try:
        yield
    finally:
        # Restore whatever was set before, not unconditionally True, so that
        # nested uses of this context manager unwind correctly.
        self._return_full_state_dict = previous
def
load_state_dict
(
self
,
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
strict
:
bool
=
True
...
...
@@ -421,16 +459,25 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
with
self
.
summon_full_params
():
output
=
self
.
module
.
load_state_dict
(
state_dict
,
strict
)
return
output
if
self
.
_return_full_state_dict
:
with
self
.
summon_full_params
():
return
self
.
module
.
load_state_dict
(
state_dict
,
strict
)
else
:
torch
.
cuda
.
synchronize
()
self
.
_lazy_init
()
return
self
.
module
.
load_state_dict
(
state_dict
,
strict
)
def load_local_state_dict(
    self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
    """Load a local (sharded) state_dict.

    Every ``FullyShardedDataParallel`` instance reachable through
    ``self.modules()`` (which includes ``self``) is switched into
    "no full state dict" mode via ``_no_return_full_state_dict()`` for the
    duration of the call, and the actual loading is delegated to
    ``self.load_state_dict``.

    Args:
        state_dict: the local state to load.
        strict: forwarded unchanged to ``load_state_dict``.

    Returns:
        Whatever ``self.load_state_dict(state_dict, strict)`` returns.
    """
    # NOTE: there must be no early ``return self.module.load_state_dict(...)``
    # ahead of this block; it would make the nested-instance handling below
    # unreachable.
    with contextlib.ExitStack() as stack:
        # Tell any nested FSDP instances not to auto summon full params.
        for module in self.modules():  # includes self
            if isinstance(module, FullyShardedDataParallel):
                stack.enter_context(module._no_return_full_state_dict())
        output = self.load_state_dict(state_dict, strict)
    return output
@
contextlib
.
contextmanager
def
no_sync
(
self
)
->
Generator
:
...
...
@@ -457,30 +504,44 @@ class FullyShardedDataParallel(nn.Module):
m
.
require_backward_grad_sync
=
old_flag
@contextlib.contextmanager
def summon_full_params(self, recurse: bool = True) -> Generator:
    """
    A context manager to expose full params for the current FSDP instance.
    Can be useful *after* forward/backward for a model to get the params for
    additional processing or checking.

    By default this will recursively summon all params for nested FSDP
    instances; this can be disabled by setting ``recurse=False``.

    .. note:: This can be used on inner FSDPs.

    .. note:: This can *not* be used within a forward or backward pass. Nor
        can forward and backward be started from within this context.

    Args:
        recurse: when True, enter ``summon_full_params(recurse=False)`` on
            every ``FullyShardedDataParallel`` found in ``self.modules()``;
            when False, only this instance is handled.
    """
    if recurse:
        with contextlib.ExitStack() as stack:
            # Summon all params for every FSDP instance under this one.
            for module in self.modules():
                if isinstance(module, FullyShardedDataParallel):
                    stack.enter_context(module.summon_full_params(recurse=False))
            # yield to the caller, with full params in all nested instances
            yield
        # exiting from the ExitStack will re-shard params
        return

    torch.cuda.synchronize()
    self._lazy_init()
    self.assert_state(TrainingState.IDLE)
    # Set the state so that we assert when trying to go into
    # forward/backward.
    self.training_state = TrainingState.SUMMON_FULL_PARAMS
    self._rebuild_full_params()
    try:
        yield
    finally:
        # Always undo the rebuild and return to IDLE, even if the caller's
        # body raised.
        self._free_full_params()
        self._use_fp32_param_shard()
        self.training_state = TrainingState.IDLE
def
_reset_lazy_init
(
self
)
->
None
:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
...
...
@@ -1011,3 +1072,23 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
return
assert
data
.
storage
().
size
()
==
0
data
.
storage
().
resize_
(
size
.
numel
())
def _post_state_dict_hook(
    module: nn.Module, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
) -> "OrderedDict[str, torch.Tensor]":
    """State-dict hook that strips the "_fsdp_wrapped_module." key segment.

    Args:
        module: the instance the hook is registered on; its
            ``training_state`` is inspected.
        state_dict: the dict being populated; modified in place.
        prefix: key prefix under which ``module``'s entries live.
        *args: ignored.

    Returns:
        The same ``state_dict`` object that was passed in.
    """
    if module.training_state == TrainingState.SUMMON_FULL_PARAMS:
        # We copy the state_dict since full param will be freed after
        # we exit the summon_full_params() context. Only entries under this
        # module's prefix are cloned: the dict is shared with enclosing
        # modules, and cloning unrelated entries here would just repeat work
        # for every nested instance.
        own_keys = [key for key in state_dict if key.startswith(prefix)]
        for key in own_keys:
            state_dict[key] = state_dict[key].clone()

    # Remove "_fsdp_wrapped_module." prefix
    replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
    return state_dict
def _pre_load_state_dict_hook(
    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
) -> None:
    """Pre-load hook that adds the "_fsdp_wrapped_module." key segment back.

    Delegates to ``replace_by_prefix_`` to map keys starting with ``prefix``
    onto ``prefix + "_fsdp_wrapped_module."``. Nothing is returned, so the
    helper is presumably expected to rewrite ``state_dict`` in place.

    Args:
        state_dict: the state about to be loaded.
        prefix: key prefix of the module being loaded.
        *args: ignored.
    """
    wrapped_prefix = prefix + "_fsdp_wrapped_module."
    replace_by_prefix_(state_dict, prefix, wrapped_prefix)
fairscale/utils/testing.py
View file @
b6dc98cf
...
...
@@ -415,7 +415,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
shape_dtype_device_match
=
a
.
size
()
==
b
.
size
()
and
a
.
dtype
==
b
.
dtype
and
a
.
device
==
b
.
device
assert
shape_dtype_device_match
return
True
except
AssertionError
as
e
:
except
(
AssertionError
,
RuntimeError
)
as
e
:
if
raise_exception
:
raise
e
else
:
...
...
tests/nn/data_parallel/test_fsdp.py
View file @
b6dc98cf
...
...
@@ -434,11 +434,9 @@ class TestSaveLoadStateDict(DistributedTest):
ddp_model
.
state_dict
()
ddp_model
.
state_dict
()
# second call
@
parameterized
.
expand
([[
False
],
[
True
]],
name_func
=
rename_test
)
def
test_state_dict_after_forward_mixed_precision
(
self
,
mixed_precision
):
test_fn
=
functools
.
partial
(
self
.
_test_module_state_dict
,
{
"flatten_parameters"
:
False
,
"mixed_precision"
:
mixed_precision
}
)
@
parameterized
.
expand
(
CONFIG_OPTIONS
,
name_func
=
rename_test
)
def
test_state_dict_after_forward
(
self
,
config
):
test_fn
=
functools
.
partial
(
self
.
_test_module_state_dict
,
config
)
spawn_and_init
(
test_fn
)
@
parameterized
.
expand
([[
False
],
[
True
]],
name_func
=
rename_test
)
...
...
@@ -474,6 +472,62 @@ class TestSaveLoadStateDict(DistributedTest):
except
Exception
:
pass
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model(self, config):
    """Run ``_test_nested_wrapped_model`` via ``spawn_and_init`` for one config.

    Configs with ``mixed_precision`` enabled are silently skipped.
    """
    # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
    if not config["mixed_precision"]:
        spawn_and_init(functools.partial(self._test_nested_wrapped_model, config=config))
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_model_local_state_dict(self, config):
    """Run ``_test_nested_wrapped_model_local_state_dict`` via ``spawn_and_init``.

    Configs with ``mixed_precision`` enabled are silently skipped.
    """
    # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
    if not config["mixed_precision"]:
        spawn_and_init(functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config))
@classmethod
def _test_nested_wrapped_model(cls, rank, group, config=None):
    """Compare a nested-FSDP full state dict against a DDP reference.

    A ``NestedWrappedModule`` built without a wrapping config is trained for
    two steps under ``DistributedDataParallel`` and its state dict is
    recorded. A second one built with ``config`` and wrapped in
    ``FullyShardedDataParallel`` is trained the same way, then its state
    dict is saved, loaded back, and saved again. The final state dict must
    have the same keys as the reference, with matching values.
    """
    # Get reference state dict without any nested FSDP instances.
    reference_model = nn.parallel.DistributedDataParallel(
        NestedWrappedModule(group, None).cuda(), device_ids=[rank], output_device=rank, process_group=group
    )
    cls._train_for_several_steps(reference_model, 2, autocast=config["mixed_precision"])
    ref_state_dict = {name: tensor.clone() for name, tensor in reference_model.module.state_dict().items()}

    # Create a nested FSDP-wrapped instance.
    fsdp_model = FullyShardedDataParallel(NestedWrappedModule(group, config), group, **config).cuda()
    del reference_model  # only the cloned reference tensors are needed from here on
    cls._train_for_several_steps(fsdp_model, 2, autocast=config["mixed_precision"])

    # Round-trip state dict save/load/save.
    saved = {name: tensor.clone() for name, tensor in fsdp_model.state_dict().items()}
    fsdp_model.load_state_dict(saved)
    state_dict = fsdp_model.state_dict()

    assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
    for key, expected in ref_state_dict.items():
        assert objects_are_equal(
            expected, state_dict[key], raise_exception=False
        ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
@classmethod
def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
    """Round-trip the local state dict of a nested FSDP model.

    After two training steps the model's ``local_state_dict()`` is cloned,
    fed back through ``load_local_state_dict()``, and read again; the keys
    and values must be unchanged. ``rank`` and ``local`` are not used in
    the body.
    """
    # Create a nested FSDP-wrapped instance.
    fsdp_model = FullyShardedDataParallel(NestedWrappedModule(group, config), group, **config).cuda()
    cls._train_for_several_steps(fsdp_model, 2, autocast=config["mixed_precision"])

    # Round trip state dict save/load/save.
    ref_state_dict = {name: tensor.clone() for name, tensor in fsdp_model.local_state_dict().items()}
    fsdp_model.load_local_state_dict(ref_state_dict)
    state_dict = fsdp_model.local_state_dict()

    assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
    for key, expected in ref_state_dict.items():
        assert objects_are_equal(
            expected, state_dict[key], raise_exception=False
        ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
class
TestHooks
(
DistributedTest
):
# Feel free to modify these tests as the implementation changes.
...
...
@@ -689,7 +743,10 @@ class NestedWrappedModule(nn.Module):
torch
.
manual_seed
(
0
)
# keep everything deterministic
self
.
module
=
nn
.
Sequential
(
nn
.
Linear
(
8
,
4
),
_maybe_wrap
(
nn
.
Linear
(
4
,
16
)),
_maybe_wrap
(
nn
.
Linear
(
16
,
4
)),
nn
.
Linear
(
4
,
8
),
nn
.
Linear
(
8
,
4
),
_maybe_wrap
(
nn
.
Sequential
(
_maybe_wrap
(
nn
.
Linear
(
4
,
16
)),
nn
.
Linear
(
16
,
16
),)),
_maybe_wrap
(
nn
.
Linear
(
16
,
4
)),
nn
.
Linear
(
4
,
8
),
)
def
get_input
(
self
,
device
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment