chenpangpang / transformers · Commits

Commit 645f1742 (unverified)
Authored Sep 09, 2022 by Sylvain Gugger; committed by GitHub on Sep 09, 2022

Exit early in load if no weights are in the sharded state dict (#18937)

parent 660e0b97

Changes: 1 changed file, with 19 additions and 16 deletions

src/transformers/modeling_utils.py (+19, -16)
@@ -418,22 +418,25 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     def load(module: nn.Module, state_dict, prefix=""):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
         args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
-        if is_deepspeed_zero3_enabled():
-            import deepspeed
-
-            # In sharded models, each shard has only part of the full state_dict, so only gather
-            # parameters that are in the current state_dict.
-            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
-            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
-            if len(params_to_gather) > 0:
-                # because zero3 puts placeholders in model params, this context
-                # manager gathers (unpartitions) the params of the current layer, then loads from
-                # the state dict and then re-partitions them again
-                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        module._load_from_state_dict(*args)
-        else:
-            module._load_from_state_dict(*args)
+        # Parameters of module and children will start with prefix. We can exit early if there are none in this
+        # state_dict
+        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                # In sharded models, each shard has only part of the full state_dict, so only gather
+                # parameters that are in the current state_dict.
+                named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
+                params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
+                if len(params_to_gather) > 0:
+                    # because zero3 puts placeholders in model params, this context
+                    # manager gathers (unpartitions) the params of the current layer, then loads from
+                    # the state dict and then re-partitions them again
+                    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
+                        if torch.distributed.get_rank() == 0:
+                            module._load_from_state_dict(*args)
+            else:
+                module._load_from_state_dict(*args)
 
         for name, child in module._modules.items():
             if child is not None:
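For context on the change: each shard of a sharded checkpoint holds only a subset of the model's weights, so for most modules the recursive load helper finds nothing to copy. The new startswith(prefix) check lets such a module skip its _load_from_state_dict call (and, under ZeRO-3, the per-module parameter bookkeeping) while child modules are still visited as before. Below is a minimal sketch of the same early-exit pattern in plain PyTorch without DeepSpeed; load_with_early_exit is a hypothetical helper written here for illustration, not the function changed by this commit.

from torch import nn


def load_with_early_exit(module: nn.Module, state_dict: dict, prefix: str = "") -> None:
    # Every parameter and buffer under `module` has a name starting with `prefix`, so if
    # no key in this (possibly partial) state dict carries the prefix there is nothing to
    # copy into this module -- that is the early exit introduced by the commit.
    if any(key.startswith(prefix) for key in state_dict):
        # `_load_from_state_dict` only fills tensors owned directly by `module`; its
        # positional arguments are (state_dict, prefix, local_metadata, strict,
        # missing_keys, unexpected_keys, error_msgs), matching the `args` tuple above.
        module._load_from_state_dict(state_dict, prefix, {}, True, [], [], [])

    # As in the commit, child modules are still visited so that each one can run its own
    # prefix check against the same shard.
    for name, child in module._modules.items():
        if child is not None:
            load_with_early_exit(child, state_dict, prefix + name + ".")


# Usage sketch: pretend the current shard only contains the weights of the second layer.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
shard = {k: v for k, v in model.state_dict().items() if k.startswith("1.")}
load_with_early_exit(model, shard)  # the "0." prefix matches no keys, so that layer skips its copy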