OpenDAS / fairscale · Commit 5ed53a3e (unverified)

Authored Jul 26, 2022 by Min Xu, committed by GitHub on Jul 26, 2022

[fix] handle EMA in the state_dict (#1044)

* [fix] handle EMA in the state_dict
* better fix

Parent: 4cb293e8
Showing 1 changed file with 42 additions and 16 deletions

fairscale/nn/data_parallel/fully_sharded_data_parallel.py (+42, -16)
@@ -275,7 +275,7 @@ class FullyShardedDataParallel(nn.Module):
             Default: False
         state_dict_device (torch.device, Optional):
             device for parameters returned by :func:`state_dict`. If not given,
-            this will default to ``compute_dtype``. Note that only the device
+            this will default to ``compute_device``. Note that only the device
             type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
         clear_autocast_cache (bool):
             When using mixed precision training with `torch.amp.autocast`, if the model weights
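The docstring now points at ``compute_device`` rather than ``compute_dtype``. As a rough, hypothetical sketch of how this option is typically used (the toy module, gloo single-process group, and port below are illustrative assumptions, not part of this commit), ``state_dict_device`` can keep the gathered state_dict on CPU to save GPU memory:

```python
# Hypothetical single-process sketch; model, process group, and port are
# assumptions for illustration, not taken from this commit.
import os
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# With state_dict_device set, tensors returned by state_dict() are moved to
# the CPU device type instead of staying on compute_device.
model = FSDP(torch.nn.Linear(8, 8), state_dict_device=torch.device("cpu"))
state = model.state_dict()
```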
@@ -2532,23 +2532,49 @@ def _post_state_dict_hook(
     if state_dict_on_rank_0_only and dist.get_rank() != 0:
         state_dict.clear()
         return state_dict
-    # Assuming we are in a ``summon_full_params()`` context, we need to clone
-    # each tensor so that it does not get freed (in-place) when the context
-    # exits. At the same time, this hook can be called multiple times
-    # recursively, so we need to make sure that we only clone each tensor at
-    # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
-    # which keeps track of tensors that are no longer at risk of being freed.
+
+    def apply_to_tensor(obj: torch.Tensor) -> torch.Tensor:
+        """Apply needed operations on a tensor."""
+        assert isinstance(obj, torch.Tensor), f"Expect a tensor, got {type(obj)}"
+        # Already applied?
+        if getattr(obj, "_has_been_cloned", False):
+            return obj
+        if obj.device.type != module.state_dict_device.type:
+            # Move to right device. This is often used to save GPU memory.
+            obj = obj.to(device=module.state_dict_device)
+        elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
+            # If we are in a ``summon_full_params()`` context, we need to clone
+            # each tensor so that it does not get freed (in-place) when the context
+            # exits. At the same time, this hook can be called multiple times
+            # recursively, so we need to make sure that we only clone each tensor at
+            # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
+            # which keeps track of tensors that are no longer at risk of being freed.
+            #
+            # "elif" because .to() clones the object too.
+            obj = obj.clone()
+        # Both .to() and .clone() copies a new object. So we set this flag.
+        obj._has_been_cloned = True
+        return obj
+
+    # State_dict is supposed to be a flat dict (not nested). The
+    # keys are encoded with hierarchy. Therefore, we can loop
+    # over the dict here. (See else case below for additional notes.)
     for key in state_dict.keys():
-        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
+        # Skip keys without right prefix.
+        if not key.startswith(prefix):
             continue
-        if state_dict[key].device.type != module.state_dict_device.type:
-            state_dict[key] = state_dict[key].to(device=module.state_dict_device)
-            state_dict[key]._has_been_cloned = True
-        elif module.training_state == TrainingState.SUMMON_FULL_PARAMS:
-            # We copy the state_dict since full param will be freed after we
-            # exit the ``summon_full_params()`` context.
-            state_dict[key] = state_dict[key].clone()
-            state_dict[key]._has_been_cloned = True
+        elif isinstance(state_dict[key], torch.Tensor):
+            state_dict[key] = apply_to_tensor(state_dict[key])
+        else:
+            # For example, EMA module from data2vec is a dict of tensors.
+            logging.warning(
+                f"Got an unexpected data type in state_dict"
+                f"key={key} value_type={type(state_dict[key])}"
+            )
 
     # Remove "_fsdp_wrapped_module." prefix
     replace_by_prefix_(state_dict, prefix + "_fsdp_wrapped_module.", prefix)
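The reason for the new ``isinstance`` guard: the data2vec EMA module stores a dict of tensors under a single state_dict key, so the old hook, which unconditionally accessed ``.device`` on every value, would fail on such entries. Below is a self-contained toy sketch of the failure mode and the guarded handling; the example state_dict is an illustration, not code from this repository:

```python
# Illustrative toy example, not fairscale code: mimic a state_dict that
# contains a non-tensor entry, as the data2vec EMA module does.
import logging
import torch

state_dict = {
    "weight": torch.zeros(2, 2),
    "_ema": {"weight": torch.zeros(2, 2)},  # dict of tensors, not a tensor
}

for key, value in state_dict.items():
    if isinstance(value, torch.Tensor):
        # Tensors can safely be moved or cloned, as apply_to_tensor() does.
        state_dict[key] = value.to("cpu")
    else:
        # Non-tensor values are left alone and only logged, which is the
        # behaviour this commit introduces instead of raising AttributeError.
        logging.warning("unexpected type in state_dict: key=%s type=%s", key, type(value))
```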