OpenDAS / fairscale · Commits · 14abed6e

[FSDP] [feat] Add state_dict_device option (#579)

Unverified commit 14abed6e, authored Apr 07, 2021 by Myle Ott, committed via GitHub on Apr 07, 2021.
Parent: 121b9db0

Showing 5 changed files with 289 additions and 169 deletions (+289 −169).
fairscale/nn/data_parallel/fully_sharded_data_parallel.py    +36  −15
stubs/torch/__init__.pyi                                       +4   −0
tests/ci_test_list_2.txt                                       +1   −0
tests/nn/data_parallel/test_fsdp.py                            +0 −154
tests/nn/data_parallel/test_fsdp_state_dict.py               +248   −0
fairscale/nn/data_parallel/fully_sharded_data_parallel.py

@@ -163,6 +163,10 @@ class FullyShardedDataParallel(nn.Module):
             with the proper state at each rank. This is useful for situations, like Mixture Of Experts,
             where all but a few parameters can fit on one node.
             Default: False
+        state_dict_device (torch.device, Optional):
+            device for parameters returned by :func:`state_dict`. If not given,
+            this will default to ``compute_device``. Note that only the device
+            type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
     """

     def __init__(
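For orientation, here is a minimal usage sketch of the new option (not part of the diff). It assumes torch.distributed is already initialized, that a CUDA device is available, and uses a toy nn.Linear purely for illustration.

    # Sketch only: toy module, assumes an initialized default process group and a CUDA device.
    import torch
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    fsdp = FSDP(nn.Linear(8, 8).cuda(), state_dict_device=torch.device("cpu"))
    sd = fsdp.state_dict()                                   # consolidated (full) state dict
    assert all(t.device.type == "cpu" for t in sd.values())  # tensors come back on the requested device type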
@@ -180,6 +184,7 @@ class FullyShardedDataParallel(nn.Module):
         bucket_cap_mb: int = 25,
         compute_device: Optional[torch.device] = None,
         no_broadcast_optim_state: Optional[bool] = False,
+        state_dict_device: Optional[torch.device] = None,
     ):
         super().__init__()
         self.process_group = process_group or dist.new_group()
@@ -194,26 +199,21 @@ class FullyShardedDataParallel(nn.Module):
         self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
+        self.compute_device = compute_device or _get_default_cuda_device(module)
         self.uncollected_opt_state: Dict[int, Dict] = {}
         self.no_broadcast_optim_state = no_broadcast_optim_state
+        self.state_dict_device = state_dict_device or self.compute_device
         self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
         self.numel_padded_per_param: List[int] = []
-        self.compute_device = compute_device

         if self.fp32_reduce_scatter and not self.mixed_precision:
             raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
         if self.cpu_offload and not self.mixed_precision:
             raise ValueError("cpu_offload requires mixed_precision=True")
-        if self.compute_device is None:
-            # Try to infer CUDA device from module parameters.
-            self.compute_device = next(module.parameters()).device
-            if self.compute_device.type != "cuda":
-                # Fall back to current CUDA device.
-                self.compute_device = torch.device("cuda")

         validate_process_group(self.compute_device, self.process_group)
         enable_pytorch_sync_bn(module)
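The net effect of the device handling above, restated as a standalone sketch (resolve_devices is a hypothetical helper, not part of fairscale): an explicit compute_device wins, otherwise the module's CUDA device is used, otherwise the current CUDA device; state_dict_device then falls back to compute_device.

    from typing import Optional, Tuple
    import torch
    import torch.nn as nn

    def resolve_devices(
        module: nn.Module,
        compute_device: Optional[torch.device] = None,
        state_dict_device: Optional[torch.device] = None,
    ) -> Tuple[torch.device, torch.device]:
        # Mirror of the constructor logic: infer the compute device first...
        if compute_device is None:
            compute_device = next(module.parameters()).device
            if compute_device.type != "cuda":
                compute_device = torch.device("cuda")
        # ...then let state_dict_device default to it.
        return compute_device, state_dict_device or compute_device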
@@ -545,7 +545,7 @@ class FullyShardedDataParallel(nn.Module):
         if self._return_full_state_dict:
             if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
-                with self.summon_full_params(volatile=True):
+                with self.summon_full_params(recurse=False, volatile=True):
                     state_dict = super().state_dict(*args, **kwargs)
             else:
                 state_dict = super().state_dict(*args, **kwargs)
@@ -1410,7 +1410,7 @@ class FullyShardedDataParallel(nn.Module):
                 sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
             else:
                 sd = dummy_tensor  # type: ignore
-            sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)  # type: ignore
+            sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)
             if should_collect_state:
                 assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
                 all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
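For intuition only, the surrounding loop's pattern (each rank takes a turn broadcasting its payload, or a dummy, while interested ranks collect the results) can be expressed with plain torch.distributed primitives. The sketch below uses torch.distributed.broadcast_object_list rather than fairscale's broadcast_object helper and assumes an initialized process group.

    # Sketch with plain torch.distributed, not fairscale's broadcast_object helper;
    # assumes dist.init_process_group() has already been called.
    import torch.distributed as dist

    def gather_all_states(my_state: dict, should_collect: bool) -> list:
        all_states = []
        for rank in range(dist.get_world_size()):
            payload = [my_state if rank == dist.get_rank() else None]
            dist.broadcast_object_list(payload, src=rank)  # rank `rank` sends, everyone else receives
            if should_collect:
                all_states.append(payload[0])
        return all_states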
@@ -1501,6 +1501,15 @@ class FullyShardedDataParallel(nn.Module):
         return full_optim_state_dict


+def _get_default_cuda_device(module: nn.Module) -> torch.device:
+    """Try to infer CUDA device from module parameters."""
+    compute_device = next(module.parameters()).device
+    if compute_device.type != "cuda":
+        # Fall back to current CUDA device.
+        compute_device = torch.device("cuda")
+    return compute_device
+
+
 @torch.no_grad()
 def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
     """
@@ -1534,13 +1543,25 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:
 def _post_state_dict_hook(
-    module: nn.Module, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
+    module: FullyShardedDataParallel, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
 ) -> "OrderedDict[str, torch.Tensor]":
-    if module.training_state == TrainingState.SUMMON_FULL_PARAMS:
-        # We copy the state_dict since full param will be freed after
-        # we exit the summon_full_params() context.
-        for key in state_dict.keys():
+    # Assuming we are in a ``summon_full_params()`` context, we need to clone
+    # each tensor so that it does not get freed (in-place) when the context
+    # exits. At the same time, this hook can be called multiple times
+    # recursively, so we need to make sure that we only clone each tensor at
+    # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
+    # which keeps track of tensors that are no longer at risk of being freed.
+    for key in state_dict.keys():
+        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
+            continue
+        if state_dict[key].device.type != module.state_dict_device.type:
+            state_dict[key] = state_dict[key].to(device=module.state_dict_device)
+            state_dict[key]._has_been_cloned = True
+        elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
+            # We copy the state_dict since full param will be freed after we
+            # exit the ``summon_full_params()`` context.
             state_dict[key] = state_dict[key].clone()
+            state_dict[key]._has_been_cloned = True

     # Remove "_fsdp_wrapped_module." prefix
     replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
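The clone-at-most-once guard above can be exercised in isolation. The following is a self-contained sketch (plain tensors, no FSDP; copy_once is a made-up name) showing that a second pass over the same dict is a no-op.

    import torch

    def copy_once(state_dict: dict, device: torch.device) -> dict:
        for key in state_dict.keys():
            if getattr(state_dict[key], "_has_been_cloned", False):
                continue  # already protected, skip
            if state_dict[key].device.type != device.type:
                state_dict[key] = state_dict[key].to(device=device)  # .to() across devices already copies
            else:
                state_dict[key] = state_dict[key].clone()
            state_dict[key]._has_been_cloned = True  # plain Tensors accept ad-hoc attributes
        return state_dict

    sd = {"w": torch.ones(2, 2)}
    copy_once(sd, torch.device("cpu"))
    t_after_first = sd["w"]
    copy_once(sd, torch.device("cpu"))
    assert sd["w"] is t_after_first  # second pass did not clone again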
stubs/torch/__init__.pyi

@@ -117,6 +117,10 @@ class Tensor:
     data: Tensor = ...
     names: List[str] = ...

+    #MODIFIED BY FULLY_SHARDED_DATA_PARALLEL
+    _has_been_cloned: Optional[bool] = ...
+    #END

     def __init__(self, *args, **kwargs) -> None: ...

     @property
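The stub entry exists so that assigning the ad-hoc `_has_been_cloned` attribute on a Tensor type-checks against the repo's local torch stubs; a minimal sketch of the pattern it covers:

    import torch

    t = torch.zeros(1)
    t._has_been_cloned = True                      # attribute declared in the stub above
    if not getattr(t, "_has_been_cloned", False):  # read with a default for tensors that never set it
        t = t.clone()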
tests/ci_test_list_2.txt

@@ -38,3 +38,4 @@ tests/nn/moe/test_top2gating.py
 tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
 tests/experimental/nn/test_offload.py
 tests/nn/data_parallel/test_fsdp_apply.py
+tests/nn/data_parallel/test_fsdp_state_dict.py
tests/nn/data_parallel/test_fsdp.py

The removed TestLocalStateDict and TestSaveLoadStateDict classes (154 lines) are moved, unchanged, into the new tests/nn/data_parallel/test_fsdp_state_dict.py shown below, so their bodies are not repeated in this hunk.

@@ -440,160 +440,6 @@ class TestSerialization(DistributedTest):
         optim.step()

-class TestLocalStateDict(DistributedTest):
-    ...
-
-class TestSaveLoadStateDict(DistributedTest):
-    ...

 class TestHooks(DistributedTest):
     # Feel free to modify these tests as the implementation changes.
     # They aspire to make sure that backward hooks are registered and used
tests/nn/data_parallel/test_fsdp_state_dict.py  (new file, 0 → 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import unittest

from parameterized import parameterized
import torch
from torch import nn

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import objects_are_equal

from .test_fsdp import (
    CONFIG_OPTIONS,
    DistributedTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
    rename_test,
    spawn_and_init,
)


class TestLocalStateDict(DistributedTest):
    @parameterized.expand([[True, True], [False, False]], name_func=rename_test)
    def test_load_local_state_dict(self, flatten_params, mixed_precision):
        test_fn = functools.partial(
            self._load_local_and_train, {"flatten_parameters": flatten_params, "mixed_precision": mixed_precision}
        )
        spawn_and_init(test_fn)

    @classmethod
    def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
        """Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
        model = self.get_wrapped_model(
            group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model, add_bn=False
        )  # Set bn=True here to show that BN doesn't get updated
        state_1 = model.local_state_dict()
        state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
        assert len(state_1) > 0
        model.load_local_state_dict(state_1)
        weight_key = "flat_param" if model.flatten_parameters else "embed_tokens.weight"

        state_1_weight = state_1[weight_key]
        assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32"
        if not model.flatten_parameters:
            # The weight will be sharded since we access module.state_dict directly
            state_1_module_weight = model.module.state_dict()[weight_key]
            torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
            torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
        self._train_for_several_steps(model, 1, model.mixed_precision)

        state_2 = model.local_state_dict()
        state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
        model.load_local_state_dict(state_2)

        assert state_1.keys() == state_2.keys()

        # Assert that parameters were updated since before training
        unchanged = []
        unwrapped_model = model.module.module if config["flatten_parameters"] else model.module
        buffers = {name for name, _ in unwrapped_model.named_buffers()}
        for k in state_1:
            if (state_before_training[k] == state_after_training[k]).all() and (k not in buffers):
                unchanged.append(k)
        if unchanged:
            raise AssertionError(f"params {unchanged} not changed after training")


class TestSaveLoadStateDict(DistributedTest):
    @parameterized.expand([[False], [True]], name_func=rename_test)
    def test_calling_state_dict_twice_mixed_precision(self, mixed_precision):
        test_fn = functools.partial(
            self._test_calling_state_dict_twice, {"flatten_parameters": False, "mixed_precision": mixed_precision}
        )
        spawn_and_init(test_fn)

    @classmethod
    def _test_calling_state_dict_twice(self, config, rank, group, **model_kwargs):
        ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config, **model_kwargs)
        autocast = ddp_model.mixed_precision
        self._train_for_several_steps(ddp_model, 1, autocast)
        ddp_model.state_dict()
        ddp_model.state_dict()  # second call

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_state_dict_after_forward(self, config):
        test_fn = functools.partial(self._test_module_state_dict, config)
        spawn_and_init(test_fn)

    @parameterized.expand([[False], [True]], name_func=rename_test)
    def test_state_dict_before_forward(self, mixed_precision):
        test_fn = functools.partial(
            self._test_state_dict_before_forward, {"flatten_parameters": False, "mixed_precision": mixed_precision}
        )
        spawn_and_init(test_fn)

    @classmethod
    def _test_state_dict_before_forward(cls, config, rank, group):
        ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
        sd = ddp_model.state_dict()
        wt = sd["embed_tokens.weight"]
        assert wt.dtype == torch.float32, f"got dtype {wt.dtype} expected torch.float32"
        cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision)

    @classmethod
    def _test_module_state_dict(cls, config, rank, group):
        ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
        autocast = ddp_model.mixed_precision
        cls._train_for_several_steps(ddp_model, 2, autocast)
        state_1 = ddp_model.state_dict()
        # You must make a new FullyShardedDataParallel instance to use module.load_state_dict
        unwrapped_model = TransformerWithSharedParams(group)
        unwrapped_model.load_state_dict(state_1)
        new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda()
        cls._train_for_several_steps(new_ddp_model, 2, autocast)
        try:
            ddp_model.load_state_dict(new_ddp_model.state_dict())
            assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
        except Exception:
            pass

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_nested_wrapped_model(self, config):
        test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
        spawn_and_init(test_fn)

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_nested_wrapped_model_local_state_dict(self, config):
        test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
        spawn_and_init(test_fn)

    @classmethod
    def _test_nested_wrapped_model(cls, rank, group, config=None):
        # Get reference state dict without any nested FSDP instances.
        model = NestedWrappedModule(group, None).cuda()
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
        ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}

        # Create a nested FSDP-wrapped instance.
        if config["mixed_precision"]:
            config["compute_dtype"] = torch.float32
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

        # Round-trip state dict save/load/save.
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        model.load_state_dict(state_dict)
        state_dict = model.state_dict()

        assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
        for key in ref_state_dict.keys():
            assert objects_are_equal(
                ref_state_dict[key], state_dict[key], raise_exception=False
            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"

    @classmethod
    def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
        # Create a nested FSDP-wrapped instance.
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

        # Round trip state dict save/load/save.
        ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
        model.load_local_state_dict(ref_state_dict)
        state_dict = model.local_state_dict()

        assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
        for key in ref_state_dict.keys():
            assert objects_are_equal(
                ref_state_dict[key], state_dict[key], raise_exception=False
            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"


class TestStateDictDeviceDtype(DistributedTest):
    @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
    def test_state_dict_device(self, mixed_precision, cpu_offload):
        test_fn = functools.partial(
            self._test_state_dict_device, {"cpu_offload": cpu_offload, "mixed_precision": mixed_precision}
        )
        spawn_and_init(test_fn)

    @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
    def test_state_dict_device_cuda(self, mixed_precision, cpu_offload):
        test_fn = functools.partial(
            self._test_state_dict_device,
            {"cpu_offload": cpu_offload, "mixed_precision": mixed_precision, "state_dict_device": torch.device("cuda")},
        )
        spawn_and_init(test_fn)

    @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
    def test_state_dict_device_cpu(self, mixed_precision, cpu_offload):
        test_fn = functools.partial(
            self._test_state_dict_device,
            {"cpu_offload": cpu_offload, "mixed_precision": mixed_precision, "state_dict_device": torch.device("cpu")},
        )
        spawn_and_init(test_fn)

    def test_state_dict_device_pure_fp16(self):
        test_fn = functools.partial(
            self._test_state_dict_device,
            {"cpu_offload": False, "mixed_precision": False, "compute_dtype": torch.float16},
            # pure_fp16 is similar to the --memory-efficient-fp16 option in fairseq
            pure_fp16=True,
        )
        spawn_and_init(test_fn)

    @classmethod
    def _test_state_dict_device(self, config, rank, group, pure_fp16=False, **model_kwargs):
        model = TransformerWithSharedParams(group, **model_kwargs)
        if pure_fp16:
            assert not config["mixed_precision"]
            model = model.half()
        fsdp_model = FullyShardedDataParallel(model, group, **config)
        if not config["cpu_offload"]:
            fsdp_model = fsdp_model.cuda()
        autocast = fsdp_model.mixed_precision or pure_fp16
        self._train_for_several_steps(fsdp_model, 1, autocast)

        sd = fsdp_model.state_dict()

        sd_device = config.get("state_dict_device")
        for k, v in sd.items():
            if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
                assert v.device.type == "cpu", v.device.type
            else:
                assert v.device.type == "cuda", v.device.type

        expected_dtype = torch.float16 if pure_fp16 else torch.float32
        buffers = {
            k.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "") for k, _ in fsdp_model.named_buffers()
        }
        for k, v in sd.items():
            if not torch.is_floating_point(v):
                continue
            if k in buffers:
                assert v.dtype == fsdp_model.buffer_dtype, f"{v.dtype} != {fsdp_model.buffer_dtype}"
            else:
                assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"


if __name__ == "__main__":
    unittest.main()