OpenDAS / fairscale · Commits · b6dc98cf

Commit b6dc98cf (unverified signature), authored Feb 26, 2021 by Myle Ott, committed via GitHub on Feb 26, 2021

[fix] fix FSDP state_dict/load_state_dict for nested wrapped instances (#440)
Parent: 93d115c6

Showing 3 changed files with 186 additions and 48 deletions (+186 -48):

    fairscale/nn/data_parallel/fully_sharded_data_parallel.py   +122  -41
    fairscale/utils/testing.py                                     +1   -1
    tests/nn/data_parallel/test_fsdp.py                           +63   -6
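In short, the commit routes state_dict()/load_state_dict() through per-instance hooks and a recursive summon_full_params() so that nested FSDP instances checkpoint correctly. A minimal usage sketch of the intended behavior (illustrative layer sizes; assumes torch.distributed is already initialized with a CUDA/NCCL backend and that FullyShardedDataParallel is importable from fairscale.nn.data_parallel):

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Outer FSDP instance wrapping a module that itself contains an inner FSDP instance.
    inner = FSDP(nn.Linear(4, 16).cuda())
    model = FSDP(nn.Sequential(nn.Linear(8, 4), inner, nn.Linear(16, 8)).cuda())

    full_sd = model.state_dict()           # full (unsharded) checkpoint, wrapper-free key names
    model.load_state_dict(full_sd)         # round-trips across the nested instances

    local_sd = model.local_state_dict()    # sharded (local) variant
    model.load_local_state_dict(local_sd)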
fairscale/nn/data_parallel/fully_sharded_data_parallel.py  (+122 -41)
@@ -29,6 +29,7 @@ from fairscale.utils.containers import (
 )
 from fairscale.utils.parallel import chunk_and_pad, validate_process_group
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
+from fairscale.utils.state_dict import replace_by_prefix_

 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
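The new import pulls in the key-renaming helper used by the hooks added at the bottom of this file. Its contract is roughly the following in-place prefix substitution (a sketch of the behavior, not the fairscale implementation; the function name here is hypothetical):

    from typing import Dict
    import torch

    def replace_by_prefix_sketch(state_dict: Dict[str, torch.Tensor], old_prefix: str, new_prefix: str) -> None:
        # Rename every key that starts with old_prefix so it starts with new_prefix instead,
        # mutating the dict in place (mirrors how the FSDP hooks use replace_by_prefix_).
        for key in list(state_dict.keys()):
            if key.startswith(old_prefix):
                state_dict[new_prefix + key[len(old_prefix):]] = state_dict.pop(key)

    sd = {"_fsdp_wrapped_module.weight": torch.zeros(1)}
    replace_by_prefix_sketch(sd, "_fsdp_wrapped_module.", "")
    assert list(sd.keys()) == ["weight"]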
@@ -172,11 +173,11 @@ class FullyShardedDataParallel(nn.Module):
         params = list(p for p in module.parameters() if not hasattr(p, "_is_sharded"))
         if self.flatten_parameters and len(params) > 0:
-            self.module: nn.Module = FlattenParamsWrapper(module, param_list=params)
+            self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=params)
             del module  # free original module in case it helps garbage collection
-            self.params = [self.module.flat_param]
+            self.params = [self._fsdp_wrapped_module.flat_param]
         else:
-            self.module = module
+            self._fsdp_wrapped_module = module
             self.params = params

         # Shard module parameters in place
@@ -192,8 +193,23 @@ class FullyShardedDataParallel(nn.Module):
         # pass. This will be False when inside the no_sync context manager.
         self.require_backward_grad_sync: bool = True

         # Enum to indicate if we're in the forward/backward pass, idle, etc.
         self.training_state = TrainingState.IDLE

+        # Register hook after state_dict() to remove the "_fsdp_wrapped_module."
+        # prefix and before load_state_dict() to add it back.
+        self._register_state_dict_hook(_post_state_dict_hook)
+        self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
+
+        # Flag to indicate whether state_dict() should automatically summon the
+        # full params. This defaults to True, but may be set to False if the
+        # user explicitly requests the local state dict via local_state_dict().
+        self._return_full_state_dict = True
+
+    @property
+    def module(self) -> nn.Module:
+        return self._fsdp_wrapped_module  # note: may be a FlattenParamsWrapper instance
+
     @torch.no_grad()
     def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         """Move all buffers to the specified device and dtype, recursively."""
@@ -235,6 +251,7 @@ class FullyShardedDataParallel(nn.Module):
         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
+        self._lazy_init()
         assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
         assert self.training_state == TrainingState.IDLE
@@ -374,7 +391,7 @@ class FullyShardedDataParallel(nn.Module):
         self._reset_lazy_init()

     # TODO (Min): figuring out how to do typing for this overloaded function.
-    def state_dict(self, *args, **kwargs):  # type: ignore
+    def state_dict(self, *args: Any, **kwargs: Any) -> "OrderedDict[str, torch.Tensor]":  # type: ignore
         """
         Returns the whole (unsharded) state of the module. Parameters are not
         sharded, so the resulting state_dict can be loaded directly by the
@@ -384,16 +401,28 @@ class FullyShardedDataParallel(nn.Module):
         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
-        with self.summon_full_params():
-            if self.mixed_precision:
-                # Buffers dtype stays consistent with parameters.
-                self._all_buffers_to(dtype=torch.float32)
-            state_dict = self.module.state_dict(*args, **kwargs)
-            # We copy the state_dict since full param will be freed after
-            # we exit the summon_full_params() context.
-            for key in state_dict.keys():
-                state_dict[key] = state_dict[key].clone()
-            if self.mixed_precision:
-                # In case we are in mixed precision, restore buffers back to fp16.
-                self._all_buffers_to(dtype=self.compute_dtype)
+        if self.mixed_precision:
+            # Buffers dtype stays consistent with parameters.
+            self._all_buffers_to(dtype=torch.float32)
+
+        if self._return_full_state_dict:
+            if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
+                with self.summon_full_params():
+                    state_dict = super().state_dict(*args, **kwargs)
+            else:
+                torch.cuda.synchronize()
+                self._lazy_init()
+                state_dict = super().state_dict(*args, **kwargs)
+        else:
+            torch.cuda.synchronize()
+            self._lazy_init()
+            if self.flatten_parameters:
+                assert isinstance(self.module, FlattenParamsWrapper)
+                state_dict = self.module.flat_state_dict(*args, **kwargs)
+            else:
+                state_dict = super().state_dict(*args, **kwargs)
+
+        if self.mixed_precision:
+            # In case we are in mixed precision, restore buffers back to fp16.
+            self._all_buffers_to(dtype=self.compute_dtype)
         return state_dict
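state_dict() now distinguishes whether it is already running inside summon_full_params() (e.g. when invoked from an outer nested instance) and whether the caller actually wants the full, unsharded dict. A usage sketch of the entry points, given `model`, an FSDP-wrapped module as in the overview sketch above:

    # Full checkpoint: summons full params internally, clones via the post hook, then re-shards.
    sd_full = model.state_dict()

    # Already inside summon_full_params(): state_dict() detects SUMMON_FULL_PARAMS and skips
    # a nested summon.
    with model.summon_full_params():
        sd_also_full = model.state_dict()

    # Local (sharded) checkpoint: local_state_dict() flips _return_full_state_dict off on
    # every nested instance before delegating to state_dict().
    sd_local = model.local_state_dict()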
@@ -405,12 +434,21 @@ class FullyShardedDataParallel(nn.Module):
         so the resulting state_dict can only be loaded after the Module has been
         wrapped with FullyShardedDataParallel.
         """
-        torch.cuda.synchronize()
-        self._lazy_init()
-        if self.flatten_parameters:
-            return self.module.flat_state_dict(*args, **kwargs)  # type: ignore
-        else:
-            return self.module.state_dict(*args, **kwargs)
+        with contextlib.ExitStack() as stack:
+            # Tell any nested FSDP instances not to auto summon full params.
+            for module in self.modules():  # includes self
+                if isinstance(module, FullyShardedDataParallel):
+                    stack.enter_context(module._no_return_full_state_dict())
+            return self.state_dict(*args, **kwargs)
+
+    @contextlib.contextmanager
+    def _no_return_full_state_dict(self) -> Generator:
+        backup = self._return_full_state_dict
+        self._return_full_state_dict = False
+        try:
+            yield
+        finally:
+            self._return_full_state_dict = backup

     def load_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
@@ -421,16 +459,25 @@ class FullyShardedDataParallel(nn.Module):
         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
-        with self.summon_full_params():
-            return self.module.load_state_dict(state_dict, strict)
+        if self._return_full_state_dict:
+            with self.summon_full_params():
+                output = self.module.load_state_dict(state_dict, strict)
+            return output
+        else:
+            torch.cuda.synchronize()
+            self._lazy_init()
+            return self.module.load_state_dict(state_dict, strict)

     def load_local_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
     ) -> NamedTuple:
         """Load a local (sharded) state_dict."""
-        torch.cuda.synchronize()
-        return self.module.load_state_dict(state_dict, strict)
+        with contextlib.ExitStack() as stack:
+            # Tell any nested FSDP instances not to auto summon full params.
+            for module in self.modules():  # includes self
+                if isinstance(module, FullyShardedDataParallel):
+                    stack.enter_context(module._no_return_full_state_dict())
+            output = self.load_state_dict(state_dict, strict)
+        return output

     @contextlib.contextmanager
     def no_sync(self) -> Generator:
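load_local_state_dict() mirrors local_state_dict(): it disables _return_full_state_dict on every nested instance so each shard is loaded directly, without gathering full params. The round trip exercised by the new test looks roughly like this (given `model` as in the overview sketch):

    # Sharded save/load round trip; keys and values should be unchanged.
    local_sd = {k: v.clone() for k, v in model.local_state_dict().items()}
    model.load_local_state_dict(local_sd)
    assert model.local_state_dict().keys() == local_sd.keys()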
@@ -457,30 +504,44 @@ class FullyShardedDataParallel(nn.Module):
             m.require_backward_grad_sync = old_flag

     @contextlib.contextmanager
-    def summon_full_params(self) -> Generator:
+    def summon_full_params(self, recurse: bool = True) -> Generator:
         """
-        A context manager to expose full params for the underlying model.
-        Can be useful *after* forward/backward for a model to get the params
-        for additional processing or checking.
+        A context manager to expose full params for the current FSDP instance.
+        Can be useful *after* forward/backward for a model to get the params for
+        additional processing or checking.
+
+        By default this will recursively summon all params for nested FSDP
+        instances; this can be disabled by setting ``recurse=False``.

         This can be used on inner FSDPs.
         .. note::
             This can *not* be used within a forward or backward pass. Nor can forward
             and backward be started from within this context.
         """
-        torch.cuda.synchronize()
-        self._lazy_init()
-        self.assert_state(TrainingState.IDLE)
-        # Set the state so that we assert when trying to go into
-        # forward/backward.
-        self.training_state = TrainingState.SUMMON_FULL_PARAMS
-        self._rebuild_full_params()
-        try:
-            yield
-        finally:
-            self._free_full_params()
-            self._use_fp32_param_shard()
-            self.training_state = TrainingState.IDLE
+        if recurse:
+            with contextlib.ExitStack() as stack:
+                # summon all params for any nested FlattenParamsWrapper instances
+                for module in self.modules():
+                    if isinstance(module, FullyShardedDataParallel):
+                        stack.enter_context(module.summon_full_params(recurse=False))
+                # yield to the caller, with full params in all nested instances
+                yield
+            # exiting from the ExitStack will re-shard params
+            return
+        else:
+            torch.cuda.synchronize()
+            self._lazy_init()
+            self.assert_state(TrainingState.IDLE)
+            # Set the state so that we assert when trying to go into
+            # forward/backward.
+            self.training_state = TrainingState.SUMMON_FULL_PARAMS
+            self._rebuild_full_params()
+            try:
+                yield
+            finally:
+                self._free_full_params()
+                self._use_fp32_param_shard()
+                self.training_state = TrainingState.IDLE

     def _reset_lazy_init(self) -> None:
         """Reset instance so :func:`_lazy_init` will run on the next forward."""
@@ -1011,3 +1072,23 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
         return
     assert data.storage().size() == 0
     data.storage().resize_(size.numel())
+
+
+def _post_state_dict_hook(
+    module: nn.Module, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
+) -> "OrderedDict[str, torch.Tensor]":
+    if module.training_state == TrainingState.SUMMON_FULL_PARAMS:
+        # We copy the state_dict since full param will be freed after
+        # we exit the summon_full_params() context.
+        for key in state_dict.keys():
+            state_dict[key] = state_dict[key].clone()
+
+    # Remove "_fsdp_wrapped_module." prefix
+    replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
+    return state_dict
+
+
+def _pre_load_state_dict_hook(
+    state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
+) -> None:
+    replace_by_prefix_(state_dict, prefix, prefix + "_fsdp_wrapped_module.")
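Because the hooks are registered per FSDP instance and PyTorch calls them with that instance's key prefix, each nesting level strips (or re-adds) exactly one "_fsdp_wrapped_module." segment. A worked example of the save/load symmetry for a top-level instance, where the hook prefix is the empty string (illustrative key names, and the net effect only; the exact hook invocation order is not shown):

    from collections import OrderedDict
    import torch
    from fairscale.utils.state_dict import replace_by_prefix_

    sd = OrderedDict(
        {"_fsdp_wrapped_module.weight": torch.zeros(1), "_fsdp_wrapped_module.bias": torch.zeros(1)}
    )

    # What _post_state_dict_hook does for a root instance (prefix == ""): strip the wrapper segment.
    replace_by_prefix_(sd, "" + "_fsdp_wrapped_module.", "")
    assert set(sd.keys()) == {"weight", "bias"}

    # What _pre_load_state_dict_hook does before loading: put the wrapper segment back.
    replace_by_prefix_(sd, "", "" + "_fsdp_wrapped_module.")
    assert set(sd.keys()) == {"_fsdp_wrapped_module.weight", "_fsdp_wrapped_module.bias"}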
fairscale/utils/testing.py  (+1 -1)
@@ -415,7 +415,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False) -> bool:
             shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
             assert shape_dtype_device_match
             return True
-    except AssertionError as e:
+    except (AssertionError, RuntimeError) as e:
         if raise_exception:
             raise e
         else:
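The broader except clause matters because PyTorch tensor comparisons can raise RuntimeError rather than AssertionError; one plausible case (not necessarily the one hit here) is an elementwise comparison of incompatible shapes:

    import torch

    a, b = torch.ones(2), torch.ones(3)
    try:
        assert (a == b).all()              # non-broadcastable shapes: raises RuntimeError
    except (AssertionError, RuntimeError) as e:
        print(type(e).__name__)            # RuntimeError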
tests/nn/data_parallel/test_fsdp.py  (+63 -6)
@@ -434,11 +434,9 @@ class TestSaveLoadStateDict(DistributedTest):
         ddp_model.state_dict()
         ddp_model.state_dict()  # second call

-    @parameterized.expand([[False], [True]], name_func=rename_test)
-    def test_state_dict_after_forward_mixed_precision(self, mixed_precision):
-        test_fn = functools.partial(
-            self._test_module_state_dict, {"flatten_parameters": False, "mixed_precision": mixed_precision}
-        )
+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_state_dict_after_forward(self, config):
+        test_fn = functools.partial(self._test_module_state_dict, config)
         spawn_and_init(test_fn)

     @parameterized.expand([[False], [True]], name_func=rename_test)
@@ -474,6 +472,62 @@ class TestSaveLoadStateDict(DistributedTest):
             except Exception:
                 pass

+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_nested_wrapped_model(self, config):
+        if config["mixed_precision"]:
+            return  # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
+        test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
+        spawn_and_init(test_fn)
+
+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_nested_wrapped_model_local_state_dict(self, config):
+        if config["mixed_precision"]:
+            return  # TODO(myleott) this is broken until we support FP32 all-gather for state_dict
+        test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
+        spawn_and_init(test_fn)
+
+    @classmethod
+    def _test_nested_wrapped_model(cls, rank, group, config=None):
+        # Get reference state dict without any nested FSDP instances.
+        model = NestedWrappedModule(group, None).cuda()
+        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
+        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
+        ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}
+
+        # Create a nested FSDP-wrapped instance.
+        model = NestedWrappedModule(group, config)
+        model = FullyShardedDataParallel(model, group, **config).cuda()
+        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
+
+        # Round-trip state dict save/load/save.
+        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
+        model.load_state_dict(state_dict)
+        state_dict = model.state_dict()
+
+        assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
+        for key in ref_state_dict.keys():
+            assert objects_are_equal(
+                ref_state_dict[key], state_dict[key], raise_exception=False
+            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
+
+    @classmethod
+    def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
+        # Create a nested FSDP-wrapped instance.
+        model = NestedWrappedModule(group, config)
+        model = FullyShardedDataParallel(model, group, **config).cuda()
+        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
+
+        # Round trip state dict save/load/save.
+        ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
+        model.load_local_state_dict(ref_state_dict)
+        state_dict = model.local_state_dict()
+
+        assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
+        for key in ref_state_dict.keys():
+            assert objects_are_equal(
+                ref_state_dict[key], state_dict[key], raise_exception=False
+            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
+

 class TestHooks(DistributedTest):
     # Feel free to modify these tests as the implementation changes.
@@ -689,7 +743,10 @@ class NestedWrappedModule(nn.Module):
...
@@ -689,7 +743,10 @@ class NestedWrappedModule(nn.Module):
torch
.
manual_seed
(
0
)
# keep everything deterministic
torch
.
manual_seed
(
0
)
# keep everything deterministic
self
.
module
=
nn
.
Sequential
(
self
.
module
=
nn
.
Sequential
(
nn
.
Linear
(
8
,
4
),
_maybe_wrap
(
nn
.
Linear
(
4
,
16
)),
_maybe_wrap
(
nn
.
Linear
(
16
,
4
)),
nn
.
Linear
(
4
,
8
),
nn
.
Linear
(
8
,
4
),
_maybe_wrap
(
nn
.
Sequential
(
_maybe_wrap
(
nn
.
Linear
(
4
,
16
)),
nn
.
Linear
(
16
,
16
),)),
_maybe_wrap
(
nn
.
Linear
(
16
,
4
)),
nn
.
Linear
(
4
,
8
),
)
)
def
get_input
(
self
,
device
):
def
get_input
(
self
,
device
):
...
...
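The new NestedWrappedModule layout places one wrapped block inside another, so the tests above exercise an FSDP instance nested inside an FSDP instance inside the outer FullyShardedDataParallel wrapper. Roughly, assuming the test helper _maybe_wrap applies FSDP when a config is given (the local maybe_wrap below is a hypothetical stand-in, and construction still requires an initialized process group):

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    def maybe_wrap(module, wrap=True):
        # Stand-in for the test's _maybe_wrap helper.
        return FSDP(module) if wrap else module

    # Mirrors the new layout: a wrapped Sequential that itself contains a wrapped Linear.
    module = nn.Sequential(
        nn.Linear(8, 4),
        maybe_wrap(nn.Sequential(maybe_wrap(nn.Linear(4, 16)), nn.Linear(16, 16))),
        maybe_wrap(nn.Linear(16, 4)),
        nn.Linear(4, 8),
    )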