Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
b6dc98cf
Unverified
Commit
b6dc98cf
authored
Feb 26, 2021
by
Myle Ott
Committed by
GitHub
Feb 26, 2021
Browse files
[fix] fix FSDP state_dict/load_state_dict for nested wrapped instances (#440)
parent
93d115c6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
186 additions
and
48 deletions
+186
-48
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+122
-41
fairscale/utils/testing.py
fairscale/utils/testing.py
+1
-1
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp.py
+63
-6
No files found.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
View file @
b6dc98cf
...
@@ -29,6 +29,7 @@ from fairscale.utils.containers import (
...
@@ -29,6 +29,7 @@ from fairscale.utils.containers import (
)
)
from
fairscale.utils.parallel
import
chunk_and_pad
,
validate_process_group
from
fairscale.utils.parallel
import
chunk_and_pad
,
validate_process_group
from
fairscale.utils.reduce_scatter_bucketer
import
ReduceScatterBucketer
from
fairscale.utils.reduce_scatter_bucketer
import
ReduceScatterBucketer
from
fairscale.utils.state_dict
import
replace_by_prefix_
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
collections
import
OrderedDict
# noqa: F401
from
collections
import
OrderedDict
# noqa: F401
...
@@ -172,11 +173,11 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -172,11 +173,11 @@ class FullyShardedDataParallel(nn.Module):
params
=
list
(
p
for
p
in
module
.
parameters
()
if
not
hasattr
(
p
,
"_is_sharded"
))
params
=
list
(
p
for
p
in
module
.
parameters
()
if
not
hasattr
(
p
,
"_is_sharded"
))
if
self
.
flatten_parameters
and
len
(
params
)
>
0
:
if
self
.
flatten_parameters
and
len
(
params
)
>
0
:
self
.
module
:
nn
.
Module
=
FlattenParamsWrapper
(
module
,
param_list
=
params
)
self
.
_fsdp_wrapped_
module
:
nn
.
Module
=
FlattenParamsWrapper
(
module
,
param_list
=
params
)
del
module
# free original module in case it helps garbage collection
del
module
# free original module in case it helps garbage collection
self
.
params
=
[
self
.
module
.
flat_param
]
self
.
params
=
[
self
.
_fsdp_wrapped_
module
.
flat_param
]
else
:
else
:
self
.
module
=
module
self
.
_fsdp_wrapped_
module
=
module
self
.
params
=
params
self
.
params
=
params
# Shard module parameters in place
# Shard module parameters in place
...
@@ -192,8 +193,23 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -192,8 +193,23 @@ class FullyShardedDataParallel(nn.Module):
# pass. This will be False when inside the no_sync context manager.
# pass. This will be False when inside the no_sync context manager.
self
.
require_backward_grad_sync
:
bool
=
True
self
.
require_backward_grad_sync
:
bool
=
True
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self
.
training_state
=
TrainingState
.
IDLE
self
.
training_state
=
TrainingState
.
IDLE
# Register hook after state_dict() to remove the "_fsdp_wrapped_module."
# prefix and before load_state_dict() to add it back.
self
.
_register_state_dict_hook
(
_post_state_dict_hook
)
self
.
_register_load_state_dict_pre_hook
(
_pre_load_state_dict_hook
)
# Flag to indicate whether state_dict() should automatically summon the
# full params. This defaults to True, but may be set to False if the
# user explicitly requests the local state dict via local_state_dict().
self
.
_return_full_state_dict
=
True
@
property
def
module
(
self
)
->
nn
.
Module
:
return
self
.
_fsdp_wrapped_module
# note: may be a FlattenParamsWrapper instance
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_all_buffers_to
(
self
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
None
:
def
_all_buffers_to
(
self
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
None
:
"""Move all buffers to the specified device and dtype, recursively."""
"""Move all buffers to the specified device and dtype, recursively."""
...
@@ -235,6 +251,7 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -235,6 +251,7 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
primitives will be used.
"""
"""
self
.
_lazy_init
()
assert
self
.
_is_root
,
"clip_grad_norm should only be called on the root (parent) instance"
assert
self
.
_is_root
,
"clip_grad_norm should only be called on the root (parent) instance"
assert
self
.
training_state
==
TrainingState
.
IDLE
assert
self
.
training_state
==
TrainingState
.
IDLE
...
@@ -374,7 +391,7 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -374,7 +391,7 @@ class FullyShardedDataParallel(nn.Module):
self
.
_reset_lazy_init
()
self
.
_reset_lazy_init
()
# TODO (Min): figuring out how to do typing for this overloaded function.
# TODO (Min): figuring out how to do typing for this overloaded function.
def
state_dict
(
self
,
*
args
,
**
kwargs
)
:
# type: ignore
def
state_dict
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
"OrderedDict[str, torch.Tensor]"
:
# type: ignore
"""
"""
Returns the whole (unsharded) state of the module. Parameters are not
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
sharded, so the resulting state_dict can be loaded directly by the
...
@@ -384,16 +401,28 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -384,16 +401,28 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
primitives will be used.
"""
"""
with
self
.
summon_full_params
()
:
if
self
.
mixed_precision
:
# Buffers dtype stays consistent with parameters.
# Buffers dtype stays consistent with parameters.
self
.
_all_buffers_to
(
dtype
=
torch
.
float32
)
self
.
_all_buffers_to
(
dtype
=
torch
.
float32
)
state_dict
=
self
.
module
.
state_dict
(
*
args
,
**
kwargs
)
if
self
.
_return_full_state_dict
:
# We copy the state_dict since full param will be freed after
if
self
.
training_state
!=
TrainingState
.
SUMMON_FULL_PARAMS
:
# we exit the summon_full_params() context.
with
self
.
summon_full_params
():
for
key
in
state_dict
.
keys
():
state_dict
=
super
().
state_dict
(
*
args
,
**
kwargs
)
state_dict
[
key
]
=
state_dict
[
key
].
clone
()
else
:
torch
.
cuda
.
synchronize
()
self
.
_lazy_init
()
state_dict
=
super
().
state_dict
(
*
args
,
**
kwargs
)
else
:
torch
.
cuda
.
synchronize
()
self
.
_lazy_init
()
if
self
.
flatten_parameters
:
assert
isinstance
(
self
.
module
,
FlattenParamsWrapper
)
state_dict
=
self
.
module
.
flat_state_dict
(
*
args
,
**
kwargs
)
else
:
state_dict
=
super
().
state_dict
(
*
args
,
**
kwargs
)
if
self
.
mixed_precision
:
# In case we are in mixed precision, restore buffers back to fp16.
# In case we are in mixed precision, restore buffers back to fp16.
self
.
_all_buffers_to
(
dtype
=
self
.
compute_dtype
)
self
.
_all_buffers_to
(
dtype
=
self
.
compute_dtype
)
return
state_dict
return
state_dict
...
@@ -405,12 +434,21 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -405,12 +434,21 @@ class FullyShardedDataParallel(nn.Module):
so the resulting state_dict can only be loaded after the Module has been
so the resulting state_dict can only be loaded after the Module has been
wrapped with FullyShardedDataParallel.
wrapped with FullyShardedDataParallel.
"""
"""
torch
.
cuda
.
synchronize
()
with
contextlib
.
ExitStack
()
as
stack
:
self
.
_lazy_init
()
# Tell any nested FSDP instances not to auto summon full params.
if
self
.
flatten_parameters
:
for
module
in
self
.
modules
():
# includes self
return
self
.
module
.
flat_state_dict
(
*
args
,
**
kwargs
)
# type: ignore
if
isinstance
(
module
,
FullyShardedDataParallel
):
else
:
stack
.
enter_context
(
module
.
_no_return_full_state_dict
())
return
self
.
module
.
state_dict
(
*
args
,
**
kwargs
)
return
self
.
state_dict
(
*
args
,
**
kwargs
)
@
contextlib
.
contextmanager
def
_no_return_full_state_dict
(
self
)
->
Generator
:
backup
=
self
.
_return_full_state_dict
self
.
_return_full_state_dict
=
False
try
:
yield
finally
:
self
.
_return_full_state_dict
=
backup
def
load_state_dict
(
def
load_state_dict
(
self
,
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
strict
:
bool
=
True
self
,
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
strict
:
bool
=
True
...
@@ -421,16 +459,25 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -421,16 +459,25 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
primitives will be used.
"""
"""
with
self
.
summon_full_params
():
if
self
.
_return_full_state_dict
:
output
=
self
.
module
.
load_state_dict
(
state_dict
,
strict
)
with
self
.
summon_full_params
():
return
output
return
self
.
module
.
load_state_dict
(
state_dict
,
strict
)
else
:
torch
.
cuda
.
synchronize
()
self
.
_lazy_init
()
return
self
.
module
.
load_state_dict
(
state_dict
,
strict
)
def
load_local_state_dict
(
def
load_local_state_dict
(
self
,
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
strict
:
bool
=
True
self
,
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
strict
:
bool
=
True
)
->
NamedTuple
:
)
->
NamedTuple
:
"""Load a local (sharded) state_dict."""
"""Load a local (sharded) state_dict."""
torch
.
cuda
.
synchronize
()
with
contextlib
.
ExitStack
()
as
stack
:
return
self
.
module
.
load_state_dict
(
state_dict
,
strict
)
# Tell any nested FSDP instances not to auto summon full params.
for
module
in
self
.
modules
():
# includes self
if
isinstance
(
module
,
FullyShardedDataParallel
):
stack
.
enter_context
(
module
.
_no_return_full_state_dict
())
output
=
self
.
load_state_dict
(
state_dict
,
strict
)
return
output
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
no_sync
(
self
)
->
Generator
:
def
no_sync
(
self
)
->
Generator
:
...
@@ -457,30 +504,44 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -457,30 +504,44 @@ class FullyShardedDataParallel(nn.Module):
m
.
require_backward_grad_sync
=
old_flag
m
.
require_backward_grad_sync
=
old_flag
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
summon_full_params
(
self
)
->
Generator
:
def
summon_full_params
(
self
,
recurse
:
bool
=
True
)
->
Generator
:
"""
"""
A context manager to expose full params for the underlying model.
A context manager to expose full params for the current FSDP instance.
Can be useful *after* forward/backward for a model to get the params
Can be useful *after* forward/backward for a model to get the params for
for additional processing or checking.
additional processing or checking.
By default this will recursively summon all params for nested FSDP
instances; this can be disabled by setting ``recurse=False``.
This can be used on inner FSDPs.
.. note::
This can be used on inner FSDPs.
This can *not* be used within a forward or backward pass. Nor
can forward
.. note::
This can *not* be used within a forward or backward pass. Nor
and backward be started from within this context.
can forward
and backward be started from within this context.
"""
"""
torch
.
cuda
.
synchronize
()
if
recurse
:
self
.
_lazy_init
()
with
contextlib
.
ExitStack
()
as
stack
:
self
.
assert_state
(
TrainingState
.
IDLE
)
# summon all params for any nested FlattenParamsWrapper instances
# Set the state so that we assert when trying to go into
for
module
in
self
.
modules
():
# forward/backward.
if
isinstance
(
module
,
FullyShardedDataParallel
):
self
.
training_state
=
TrainingState
.
SUMMON_FULL_PARAMS
stack
.
enter_context
(
module
.
summon_full_params
(
recurse
=
False
))
self
.
_rebuild_full_params
()
# yield to the caller, with full params in all nested instances
try
:
yield
yield
# exiting from the ExitStack will re-shard params
finally
:
return
self
.
_free_full_params
()
else
:
self
.
_use_fp32_param_shard
()
torch
.
cuda
.
synchronize
()
self
.
training_state
=
TrainingState
.
IDLE
self
.
_lazy_init
()
self
.
assert_state
(
TrainingState
.
IDLE
)
# Set the state so that we assert when trying to go into
# forward/backward.
self
.
training_state
=
TrainingState
.
SUMMON_FULL_PARAMS
self
.
_rebuild_full_params
()
try
:
yield
finally
:
self
.
_free_full_params
()
self
.
_use_fp32_param_shard
()
self
.
training_state
=
TrainingState
.
IDLE
def
_reset_lazy_init
(
self
)
->
None
:
def
_reset_lazy_init
(
self
)
->
None
:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
...
@@ -1011,3 +1072,23 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
...
@@ -1011,3 +1072,23 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
return
return
assert
data
.
storage
().
size
()
==
0
assert
data
.
storage
().
size
()
==
0
data
.
storage
().
resize_
(
size
.
numel
())
data
.
storage
().
resize_
(
size
.
numel
())
def
_post_state_dict_hook
(
module
:
nn
.
Module
,
state_dict
:
"OrderedDict[str, torch.Tensor]"
,
prefix
:
str
,
*
args
:
Any
)
->
"OrderedDict[str, torch.Tensor]"
:
if
module
.
training_state
==
TrainingState
.
SUMMON_FULL_PARAMS
:
# We copy the state_dict since full param will be freed after
# we exit the summon_full_params() context.
for
key
in
state_dict
.
keys
():
state_dict
[
key
]
=
state_dict
[
key
].
clone
()
# Remove "_fsdp_wrapped_module." prefix
replace_by_prefix_
(
state_dict
,
prefix
+
"_fsdp_wrapped_module."
,
prefix
)
return
state_dict
def
_pre_load_state_dict_hook
(
state_dict
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
"OrderedDict[str, torch.Tensor]"
],
prefix
:
str
,
*
args
:
Any
)
->
None
:
replace_by_prefix_
(
state_dict
,
prefix
,
prefix
+
"_fsdp_wrapped_module."
)
fairscale/utils/testing.py
View file @
b6dc98cf
...
@@ -415,7 +415,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
...
@@ -415,7 +415,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
shape_dtype_device_match
=
a
.
size
()
==
b
.
size
()
and
a
.
dtype
==
b
.
dtype
and
a
.
device
==
b
.
device
shape_dtype_device_match
=
a
.
size
()
==
b
.
size
()
and
a
.
dtype
==
b
.
dtype
and
a
.
device
==
b
.
device
assert
shape_dtype_device_match
assert
shape_dtype_device_match
return
True
return
True
except
AssertionError
as
e
:
except
(
AssertionError
,
RuntimeError
)
as
e
:
if
raise_exception
:
if
raise_exception
:
raise
e
raise
e
else
:
else
:
...
...
tests/nn/data_parallel/test_fsdp.py
View file @
b6dc98cf
...
@@ -434,11 +434,9 @@ class TestSaveLoadStateDict(DistributedTest):
...
@@ -434,11 +434,9 @@ class TestSaveLoadStateDict(DistributedTest):
ddp_model
.
state_dict
()
ddp_model
.
state_dict
()
ddp_model
.
state_dict
()
# second call
ddp_model
.
state_dict
()
# second call
@
parameterized
.
expand
([[
False
],
[
True
]],
name_func
=
rename_test
)
@
parameterized
.
expand
(
CONFIG_OPTIONS
,
name_func
=
rename_test
)
def
test_state_dict_after_forward_mixed_precision
(
self
,
mixed_precision
):
def
test_state_dict_after_forward
(
self
,
config
):
test_fn
=
functools
.
partial
(
test_fn
=
functools
.
partial
(
self
.
_test_module_state_dict
,
config
)
self
.
_test_module_state_dict
,
{
"flatten_parameters"
:
False
,
"mixed_precision"
:
mixed_precision
}
)
spawn_and_init
(
test_fn
)
spawn_and_init
(
test_fn
)
@
parameterized
.
expand
([[
False
],
[
True
]],
name_func
=
rename_test
)
@
parameterized
.
expand
([[
False
],
[
True
]],
name_func
=
rename_test
)
...
@@ -474,6 +472,62 @@ class TestSaveLoadStateDict(DistributedTest):
...
@@ -474,6 +472,62 @@ class TestSaveLoadStateDict(DistributedTest):
except
Exception
:
except
Exception
:
pass
pass
@
parameterized
.
expand
(
CONFIG_OPTIONS
,
name_func
=
rename_test
)
def
test_nested_wrapped_model
(
self
,
config
):
if
config
[
"mixed_precision"
]:
return
# TODO(myleott) this is broken until we support FP32 all-gather for state_dict
test_fn
=
functools
.
partial
(
self
.
_test_nested_wrapped_model
,
config
=
config
)
spawn_and_init
(
test_fn
)
@
parameterized
.
expand
(
CONFIG_OPTIONS
,
name_func
=
rename_test
)
def
test_nested_wrapped_model_local_state_dict
(
self
,
config
):
if
config
[
"mixed_precision"
]:
return
# TODO(myleott) this is broken until we support FP32 all-gather for state_dict
test_fn
=
functools
.
partial
(
self
.
_test_nested_wrapped_model_local_state_dict
,
config
=
config
)
spawn_and_init
(
test_fn
)
@
classmethod
def
_test_nested_wrapped_model
(
cls
,
rank
,
group
,
config
=
None
):
# Get reference state dict without any nested FSDP instances.
model
=
NestedWrappedModule
(
group
,
None
).
cuda
()
model
=
nn
.
parallel
.
DistributedDataParallel
(
model
,
device_ids
=
[
rank
],
output_device
=
rank
,
process_group
=
group
)
cls
.
_train_for_several_steps
(
model
,
2
,
autocast
=
config
[
"mixed_precision"
])
ref_state_dict
=
{
k
:
v
.
clone
()
for
k
,
v
in
model
.
module
.
state_dict
().
items
()}
# Create a nested FSDP-wrapped instance.
model
=
NestedWrappedModule
(
group
,
config
)
model
=
FullyShardedDataParallel
(
model
,
group
,
**
config
).
cuda
()
cls
.
_train_for_several_steps
(
model
,
2
,
autocast
=
config
[
"mixed_precision"
])
# Round-trip state dict save/load/save.
state_dict
=
{
k
:
v
.
clone
()
for
k
,
v
in
model
.
state_dict
().
items
()}
model
.
load_state_dict
(
state_dict
)
state_dict
=
model
.
state_dict
()
assert
ref_state_dict
.
keys
()
==
state_dict
.
keys
(),
f
"
{
ref_state_dict
.
keys
()
}
!=
{
state_dict
.
keys
()
}
"
for
key
in
ref_state_dict
.
keys
():
assert
objects_are_equal
(
ref_state_dict
[
key
],
state_dict
[
key
],
raise_exception
=
False
),
f
"
{
key
}
,
{
ref_state_dict
[
key
]
}
!=
{
state_dict
[
key
]
}
"
@
classmethod
def
_test_nested_wrapped_model_local_state_dict
(
cls
,
rank
,
group
,
config
=
None
,
local
=
None
):
# Create a nested FSDP-wrapped instance.
model
=
NestedWrappedModule
(
group
,
config
)
model
=
FullyShardedDataParallel
(
model
,
group
,
**
config
).
cuda
()
cls
.
_train_for_several_steps
(
model
,
2
,
autocast
=
config
[
"mixed_precision"
])
# Round trip state dict save/load/save.
ref_state_dict
=
{
k
:
v
.
clone
()
for
k
,
v
in
model
.
local_state_dict
().
items
()}
model
.
load_local_state_dict
(
ref_state_dict
)
state_dict
=
model
.
local_state_dict
()
assert
ref_state_dict
.
keys
()
==
state_dict
.
keys
(),
f
"
{
ref_state_dict
.
keys
()
}
!=
{
state_dict
.
keys
()
}
"
for
key
in
ref_state_dict
.
keys
():
assert
objects_are_equal
(
ref_state_dict
[
key
],
state_dict
[
key
],
raise_exception
=
False
),
f
"
{
key
}
,
{
ref_state_dict
[
key
]
}
!=
{
state_dict
[
key
]
}
"
class
TestHooks
(
DistributedTest
):
class
TestHooks
(
DistributedTest
):
# Feel free to modify these tests as the implementation changes.
# Feel free to modify these tests as the implementation changes.
...
@@ -689,7 +743,10 @@ class NestedWrappedModule(nn.Module):
...
@@ -689,7 +743,10 @@ class NestedWrappedModule(nn.Module):
torch
.
manual_seed
(
0
)
# keep everything deterministic
torch
.
manual_seed
(
0
)
# keep everything deterministic
self
.
module
=
nn
.
Sequential
(
self
.
module
=
nn
.
Sequential
(
nn
.
Linear
(
8
,
4
),
_maybe_wrap
(
nn
.
Linear
(
4
,
16
)),
_maybe_wrap
(
nn
.
Linear
(
16
,
4
)),
nn
.
Linear
(
4
,
8
),
nn
.
Linear
(
8
,
4
),
_maybe_wrap
(
nn
.
Sequential
(
_maybe_wrap
(
nn
.
Linear
(
4
,
16
)),
nn
.
Linear
(
16
,
16
),)),
_maybe_wrap
(
nn
.
Linear
(
16
,
4
)),
nn
.
Linear
(
4
,
8
),
)
)
def
get_input
(
self
,
device
):
def
get_input
(
self
,
device
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment