OpenDAS / fairscale / Commits / 14abed6e

Unverified commit 14abed6e, authored Apr 07, 2021 by Myle Ott, committed via GitHub on Apr 07, 2021.
Parent: 121b9db0

[FSDP] [feat] Add state_dict_device option (#579)

Showing 5 changed files with 289 additions and 169 deletions (+289 / -169):

fairscale/nn/data_parallel/fully_sharded_data_parallel.py (+36 / -15)
stubs/torch/__init__.pyi (+4 / -0)
tests/ci_test_list_2.txt (+1 / -0)
tests/nn/data_parallel/test_fsdp.py (+0 / -154)
tests/nn/data_parallel/test_fsdp_state_dict.py (+248 / -0)
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
```diff
@@ -163,6 +163,10 @@ class FullyShardedDataParallel(nn.Module):
             with the proper state at each rank. This is useful for situations, like Mixture Of Experts,
             where all but a few parameters can fit on one node.
             Default: False
+        state_dict_device (torch.device, Optional):
+            device for parameters returned by :func:`state_dict`. If not given,
+            this will default to ``compute_device``. Note that only the device
+            type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
     """

     def __init__(
```

```diff
@@ -180,6 +184,7 @@ class FullyShardedDataParallel(nn.Module):
         bucket_cap_mb: int = 25,
         compute_device: Optional[torch.device] = None,
         no_broadcast_optim_state: Optional[bool] = False,
+        state_dict_device: Optional[torch.device] = None,
     ):
         super().__init__()
         self.process_group = process_group or dist.new_group()
```
```diff
@@ -194,26 +199,21 @@ class FullyShardedDataParallel(nn.Module):
         self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
+        self.compute_device = compute_device or _get_default_cuda_device(module)
         self.uncollected_opt_state: Dict[int, Dict] = {}
         self.no_broadcast_optim_state = no_broadcast_optim_state
+        self.state_dict_device = state_dict_device or self.compute_device
         self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
         self.numel_padded_per_param: List[int] = []
-        self.compute_device = compute_device

         if self.fp32_reduce_scatter and not self.mixed_precision:
             raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
         if self.cpu_offload and not self.mixed_precision:
             raise ValueError("cpu_offload requires mixed_precision=True")
-        if self.compute_device is None:
-            # Try to infer CUDA device from module parameters.
-            self.compute_device = next(module.parameters()).device
-            if self.compute_device.type != "cuda":
-                # Fall back to current CUDA device.
-                self.compute_device = torch.device("cuda")

         validate_process_group(self.compute_device, self.process_group)
         enable_pytorch_sync_bn(module)
```
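Taken together, the hunks above wire the new constructor argument through to `self.state_dict_device`, defaulting to the compute device. A minimal usage sketch of the option, assuming a single-GPU process; the `nn.Linear` toy model and the process-group setup are illustrative, not part of this commit:

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Single-process setup purely for illustration; real runs use torchrun/mp.spawn.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)

model = nn.Linear(16, 16).cuda()

# Compute stays on GPU, but state_dict() tensors are materialized on CPU,
# avoiding a CUDA memory spike when checkpointing a large sharded model.
fsdp_model = FSDP(model, state_dict_device=torch.device("cpu"))

sd = fsdp_model.state_dict()
assert all(t.device.type == "cpu" for t in sd.values())
```

Only the device type matters here, which is why the docstring notes that "cuda:0" and "cuda:1" are treated the same.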
```diff
@@ -545,7 +545,7 @@ class FullyShardedDataParallel(nn.Module):
         if self._return_full_state_dict:
             if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
-                with self.summon_full_params(volatile=True):
+                with self.summon_full_params(recurse=False, volatile=True):
                     state_dict = super().state_dict(*args, **kwargs)
             else:
                 state_dict = super().state_dict(*args, **kwargs)
```
```diff
@@ -1410,7 +1410,7 @@ class FullyShardedDataParallel(nn.Module):
                 sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
             else:
                 sd = dummy_tensor  # type: ignore
-            sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)  # type: ignore
+            sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)
             if should_collect_state:
                 assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
                 all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
```
```diff
@@ -1501,6 +1501,15 @@ class FullyShardedDataParallel(nn.Module):
         return full_optim_state_dict


+def _get_default_cuda_device(module: nn.Module) -> torch.device:
+    """Try to infer CUDA device from module parameters."""
+    compute_device = next(module.parameters()).device
+    if compute_device.type != "cuda":
+        # Fall back to current CUDA device.
+        compute_device = torch.device("cuda")
+    return compute_device
+
+
 @torch.no_grad()
 def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
     """
```
```diff
@@ -1534,13 +1543,25 @@ def alloc_storage_(data: torch.Tensor, size: torch.Size) -> None:

 def _post_state_dict_hook(
-    module: nn.Module, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
+    module: FullyShardedDataParallel, state_dict: "OrderedDict[str, torch.Tensor]", prefix: str, *args: Any
 ) -> "OrderedDict[str, torch.Tensor]":
-    if module.training_state == TrainingState.SUMMON_FULL_PARAMS:
-        # We copy the state_dict since full param will be freed after
-        # we exit the summon_full_params() context.
-        for key in state_dict.keys():
-            state_dict[key] = state_dict[key].clone()
+    # Assuming we are in a ``summon_full_params()`` context, we need to clone
+    # each tensor so that it does not get freed (in-place) when the context
+    # exits. At the same time, this hook can be called multiple times
+    # recursively, so we need to make sure that we only clone each tensor at
+    # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
+    # which keeps track of tensors that are no longer at risk of being freed.
+    for key in state_dict.keys():
+        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
+            continue
+        if state_dict[key].device.type != module.state_dict_device.type:
+            state_dict[key] = state_dict[key].to(device=module.state_dict_device)
+            state_dict[key]._has_been_cloned = True
+        elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
+            # We copy the state_dict since full param will be freed after we
+            # exit the ``summon_full_params()`` context.
+            state_dict[key] = state_dict[key].clone()
+            state_dict[key]._has_been_cloned = True

     # Remove "_fsdp_wrapped_module." prefix
     replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
```
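The `_has_been_cloned` marker in the rewritten hook is the key detail: `_post_state_dict_hook` can fire once per nested FSDP instance over the same prefixed keys, so each tensor must be moved or cloned at most once. Below is a standalone sketch of that idempotent-copy pattern, using an illustrative `copy_or_move_once` helper rather than fairscale's actual hook:

```python
from collections import OrderedDict

import torch


def copy_or_move_once(state_dict: "OrderedDict[str, torch.Tensor]", device: torch.device) -> None:
    """Move or clone each tensor at most once, even if called repeatedly."""
    for key, tensor in state_dict.items():
        if getattr(tensor, "_has_been_cloned", False):
            continue  # already handled by an earlier (possibly recursive) call
        if tensor.device.type != device.type:
            new_tensor = tensor.to(device=device)  # cross-device move implies a copy
        else:
            new_tensor = tensor.clone()  # same device: clone to detach from soon-to-be-freed storage
        new_tensor._has_been_cloned = True  # mark so later calls skip this key
        state_dict[key] = new_tensor


sd = OrderedDict(weight=torch.ones(2, 2))
copy_or_move_once(sd, torch.device("cpu"))
copy_or_move_once(sd, torch.device("cpu"))  # second call is a no-op
```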
stubs/torch/__init__.pyi

```diff
@@ -117,6 +117,10 @@ class Tensor:
     data: Tensor = ...
     names: List[str] = ...

+    #MODIFIED BY FULLY_SHARDED_DATA_PARALLEL
+    _has_been_cloned: Optional[bool] = ...
+    #END
+
     def __init__(self, *args, **kwargs) -> None: ...

     @property
```
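The stub change is needed because `_has_been_cloned` is attached to tensors dynamically at runtime; without declaring it on the `Tensor` stub, mypy would reject the assignment in `_post_state_dict_hook`. A small sketch of the situation (illustrative, not project code):

```python
import torch

t = torch.zeros(1)

# Without the `_has_been_cloned: Optional[bool] = ...` entry in the Tensor stub,
# mypy flags this assignment as "Tensor has no attribute '_has_been_cloned'"
# (the code still runs fine at runtime).
t._has_been_cloned = True

# Reads go through getattr with a default, so a missing attribute is harmless.
already_cloned = getattr(t, "_has_been_cloned", False)
print(already_cloned)
```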
tests/ci_test_list_2.txt

```diff
@@ -38,3 +38,4 @@ tests/nn/moe/test_top2gating.py
 tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
 tests/experimental/nn/test_offload.py
 tests/nn/data_parallel/test_fsdp_apply.py
+tests/nn/data_parallel/test_fsdp_state_dict.py
```
tests/nn/data_parallel/test_fsdp.py

The 154 lines deleted here are the `TestLocalStateDict` and `TestSaveLoadStateDict` test classes, which move unchanged into the new `tests/nn/data_parallel/test_fsdp_state_dict.py` (reproduced in full below). Only the surrounding context is shown:

```diff
@@ -440,160 +440,6 @@ class TestSerialization(DistributedTest):
             optim.step()


 class TestHooks(DistributedTest):
     # Feel free to modify these tests as the implementation changes.
     # They aspire to make sure that backward hooks are registered and used
```
tests/nn/data_parallel/test_fsdp_state_dict.py (new file, mode 100644)

```python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import unittest

from parameterized import parameterized
import torch
from torch import nn

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import objects_are_equal

from .test_fsdp import (
    CONFIG_OPTIONS,
    DistributedTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
    rename_test,
    spawn_and_init,
)


class TestLocalStateDict(DistributedTest):
    @parameterized.expand([[True, True], [False, False]], name_func=rename_test)
    def test_load_local_state_dict(self, flatten_params, mixed_precision):
        test_fn = functools.partial(
            self._load_local_and_train, {"flatten_parameters": flatten_params, "mixed_precision": mixed_precision}
        )
        spawn_and_init(test_fn)

    @classmethod
    def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
        """Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
        model = self.get_wrapped_model(
            group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model, add_bn=False
        )  # Set bn=True here to show that BN doesn't get updated
        state_1 = model.local_state_dict()
        state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
        assert len(state_1) > 0
        model.load_local_state_dict(state_1)
        weight_key = "flat_param" if model.flatten_parameters else "embed_tokens.weight"

        state_1_weight = state_1[weight_key]
        assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32"
        if not model.flatten_parameters:
            # The weight will be sharded since we access module.state_dict directly
            state_1_module_weight = model.module.state_dict()[weight_key]
            torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
            torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
        self._train_for_several_steps(model, 1, model.mixed_precision)

        state_2 = model.local_state_dict()
        state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
        model.load_local_state_dict(state_2)

        assert state_1.keys() == state_2.keys()

        # Assert that parameters were updated since before training
        unchanged = []
        unwrapped_model = model.module.module if config["flatten_parameters"] else model.module
        buffers = {name for name, _ in unwrapped_model.named_buffers()}
        for k in state_1:
            if (state_before_training[k] == state_after_training[k]).all() and (k not in buffers):
                unchanged.append(k)
        if unchanged:
            raise AssertionError(f"params {unchanged} not changed after training")


class TestSaveLoadStateDict(DistributedTest):
    @parameterized.expand([[False], [True]], name_func=rename_test)
    def test_calling_state_dict_twice_mixed_precision(self, mixed_precision):
        test_fn = functools.partial(
            self._test_calling_state_dict_twice, {"flatten_parameters": False, "mixed_precision": mixed_precision}
        )
        spawn_and_init(test_fn)

    @classmethod
    def _test_calling_state_dict_twice(self, config, rank, group, **model_kwargs):
        ddp_model = self.get_wrapped_model(group, cuda_first=False, config=config, **model_kwargs)
        autocast = ddp_model.mixed_precision
        self._train_for_several_steps(ddp_model, 1, autocast)
        ddp_model.state_dict()
        ddp_model.state_dict()  # second call

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_state_dict_after_forward(self, config):
        test_fn = functools.partial(self._test_module_state_dict, config)
        spawn_and_init(test_fn)

    @parameterized.expand([[False], [True]], name_func=rename_test)
    def test_state_dict_before_forward(self, mixed_precision):
        test_fn = functools.partial(
            self._test_state_dict_before_forward, {"flatten_parameters": False, "mixed_precision": mixed_precision}
        )
        spawn_and_init(test_fn)

    @classmethod
    def _test_state_dict_before_forward(cls, config, rank, group):
        ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
        sd = ddp_model.state_dict()
        wt = sd["embed_tokens.weight"]
        assert wt.dtype == torch.float32, f"got dtype {wt.dtype} expected torch.float32"
        cls._train_for_several_steps(ddp_model, 1, ddp_model.mixed_precision)

    @classmethod
    def _test_module_state_dict(cls, config, rank, group):
        ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
        autocast = ddp_model.mixed_precision
        cls._train_for_several_steps(ddp_model, 2, autocast)
        state_1 = ddp_model.state_dict()
        # You must make a new FullyShardedDataParallel instance to use module.load_state_dict
        unwrapped_model = TransformerWithSharedParams(group)
        unwrapped_model.load_state_dict(state_1)
        new_ddp_model = FullyShardedDataParallel(unwrapped_model, group, **config).cuda()
        cls._train_for_several_steps(new_ddp_model, 2, autocast)
        try:
            ddp_model.load_state_dict(new_ddp_model.state_dict())
            assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
        except Exception:
            pass

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_nested_wrapped_model(self, config):
        test_fn = functools.partial(self._test_nested_wrapped_model, config=config)
        spawn_and_init(test_fn)

    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
    def test_nested_wrapped_model_local_state_dict(self, config):
        test_fn = functools.partial(self._test_nested_wrapped_model_local_state_dict, config=config)
        spawn_and_init(test_fn)

    @classmethod
    def _test_nested_wrapped_model(cls, rank, group, config=None):
        # Get reference state dict without any nested FSDP instances.
        model = NestedWrappedModule(group, None).cuda()
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])
        ref_state_dict = {k: v.clone() for k, v in model.module.state_dict().items()}

        # Create a nested FSDP-wrapped instance.
        if config["mixed_precision"]:
            config["compute_dtype"] = torch.float32
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

        # Round-trip state dict save/load/save.
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        model.load_state_dict(state_dict)
        state_dict = model.state_dict()

        assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
        for key in ref_state_dict.keys():
            assert objects_are_equal(
                ref_state_dict[key], state_dict[key], raise_exception=False
            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"

    @classmethod
    def _test_nested_wrapped_model_local_state_dict(cls, rank, group, config=None, local=None):
        # Create a nested FSDP-wrapped instance.
        model = NestedWrappedModule(group, config)
        model = FullyShardedDataParallel(model, group, **config).cuda()
        cls._train_for_several_steps(model, 2, autocast=config["mixed_precision"])

        # Round trip state dict save/load/save.
        ref_state_dict = {k: v.clone() for k, v in model.local_state_dict().items()}
        model.load_local_state_dict(ref_state_dict)
        state_dict = model.local_state_dict()

        assert ref_state_dict.keys() == state_dict.keys(), f"{ref_state_dict.keys()} != {state_dict.keys()}"
        for key in ref_state_dict.keys():
            assert objects_are_equal(
                ref_state_dict[key], state_dict[key], raise_exception=False
            ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"


class TestStateDictDeviceDtype(DistributedTest):
    @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
    def test_state_dict_device(self, mixed_precision, cpu_offload):
        test_fn = functools.partial(
            self._test_state_dict_device, {"cpu_offload": cpu_offload, "mixed_precision": mixed_precision}
        )
        spawn_and_init(test_fn)

    @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
    def test_state_dict_device_cuda(self, mixed_precision, cpu_offload):
        test_fn = functools.partial(
            self._test_state_dict_device,
            {"cpu_offload": cpu_offload, "mixed_precision": mixed_precision, "state_dict_device": torch.device("cuda")},
        )
        spawn_and_init(test_fn)

    @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
    def test_state_dict_device_cpu(self, mixed_precision, cpu_offload):
        test_fn = functools.partial(
            self._test_state_dict_device,
            {"cpu_offload": cpu_offload, "mixed_precision": mixed_precision, "state_dict_device": torch.device("cpu")},
        )
        spawn_and_init(test_fn)

    def test_state_dict_device_pure_fp16(self):
        test_fn = functools.partial(
            self._test_state_dict_device,
            {"cpu_offload": False, "mixed_precision": False, "compute_dtype": torch.float16},
            # pure_fp16 is similar to the --memory-efficient-fp16 option in fairseq
            pure_fp16=True,
        )
        spawn_and_init(test_fn)

    @classmethod
    def _test_state_dict_device(self, config, rank, group, pure_fp16=False, **model_kwargs):
        model = TransformerWithSharedParams(group, **model_kwargs)
        if pure_fp16:
            assert not config["mixed_precision"]
            model = model.half()
        fsdp_model = FullyShardedDataParallel(model, group, **config)
        if not config["cpu_offload"]:
            fsdp_model = fsdp_model.cuda()
        autocast = fsdp_model.mixed_precision or pure_fp16
        self._train_for_several_steps(fsdp_model, 1, autocast)

        sd = fsdp_model.state_dict()

        sd_device = config.get("state_dict_device")
        for k, v in sd.items():
            if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
                assert v.device.type == "cpu", v.device.type
            else:
                assert v.device.type == "cuda", v.device.type

        expected_dtype = torch.float16 if pure_fp16 else torch.float32
        buffers = {
            k.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "") for k, _ in fsdp_model.named_buffers()
        }
        for k, v in sd.items():
            if not torch.is_floating_point(v):
                continue
            if k in buffers:
                assert v.dtype == fsdp_model.buffer_dtype, f"{v.dtype} != {fsdp_model.buffer_dtype}"
            else:
                assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"


if __name__ == "__main__":
    unittest.main()
```
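To exercise just the relocated and newly added state-dict tests locally, something like the following works, assuming a CUDA machine and the repository's test dependencies are installed (the invocation is a sketch, not a documented project command):

```python
import sys

import pytest

# Equivalent to: python -m pytest -v tests/nn/data_parallel/test_fsdp_state_dict.py
sys.exit(pytest.main(["-v", "tests/nn/data_parallel/test_fsdp_state_dict.py"]))
```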