chenpangpang / transformers · Commits · 645f1742

Unverified commit 645f1742, authored Sep 09, 2022 by Sylvain Gugger, committed by GitHub on Sep 09, 2022

Exit early in load if no weights are in the sharded state dict (#18937)
Parent: 660e0b97

Showing 1 changed file with 19 additions and 16 deletions:

src/transformers/modeling_utils.py (+19, -16)
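For context, a sharded checkpoint splits a model's full state dict across several files, so any single shard covers only a subset of modules. The hunk below adds a prefix check so that modules with no weights in the current shard are not pointlessly loaded. A minimal illustration of that key layout (toy tensors and the `encoder`/`decoder` names are hypothetical, not transformers API):

```python
import torch

# A full state dict for a toy encoder/decoder model ...
full = {
    "encoder.weight": torch.randn(4, 4),
    "encoder.bias": torch.randn(4),
    "decoder.weight": torch.randn(4, 4),
    "decoder.bias": torch.randn(4),
}
# ... split into two shards, as happens when a large model is saved sharded.
shards = [
    {k: v for k, v in full.items() if k.startswith("encoder.")},
    {k: v for k, v in full.items() if k.startswith("decoder.")},
]
# While loading shards[0], no key starts with "decoder.", so the patched
# load() below has nothing to do for the decoder module and can exit early.
assert not any(k.startswith("decoder.") for k in shards[0])
```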
```diff
@@ -418,22 +418,25 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     def load(module: nn.Module, state_dict, prefix=""):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
         args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
-        if is_deepspeed_zero3_enabled():
-            import deepspeed
-
-            # In sharded models, each shard has only part of the full state_dict, so only gather
-            # parameters that are in the current state_dict.
-            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
-            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
-            if len(params_to_gather) > 0:
-                # because zero3 puts placeholders in model params, this context
-                # manager gathers (unpartitions) the params of the current layer, then loads from
-                # the state dict and then re-partitions them again
-                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        module._load_from_state_dict(*args)
-        else:
-            module._load_from_state_dict(*args)
+        # Parameters of module and children will start with prefix. We can exit early if there are none in this
+        # state_dict
+        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                # In sharded models, each shard has only part of the full state_dict, so only gather
+                # parameters that are in the current state_dict.
+                named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
+                params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
+                if len(params_to_gather) > 0:
+                    # because zero3 puts placeholders in model params, this context
+                    # manager gathers (unpartitions) the params of the current layer, then loads from
+                    # the state dict and then re-partitions them again
+                    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
+                        if torch.distributed.get_rank() == 0:
+                            module._load_from_state_dict(*args)
+            else:
+                module._load_from_state_dict(*args)
 
         for name, child in module._modules.items():
             if child is not None:
```
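To see the patched control flow in isolation, here is a minimal, runnable sketch of the same early-exit pattern in plain PyTorch, without the DeepSpeed branch; `load_shard` and the toy `encoder`/`decoder` module are hypothetical stand-ins for illustration, not the transformers implementation:

```python
import torch
from torch import nn


def load_shard(module: nn.Module, state_dict: dict, prefix: str = ""):
    # Same early-exit idea as the patch: all parameters of this module and
    # its children have names starting with `prefix`, so if no key in the
    # shard does, there is nothing for _load_from_state_dict to do here.
    if any(key.startswith(prefix) for key in state_dict):
        module._load_from_state_dict(state_dict, prefix, {}, True, [], [], [])
    # Recursion into children is unchanged, matching the diff's context lines.
    for name, child in module._modules.items():
        if child is not None:
            load_shard(child, state_dict, prefix + name + ".")


model = nn.ModuleDict({"encoder": nn.Linear(4, 4), "decoder": nn.Linear(4, 4)})
# A shard carrying only the decoder weights, as with a sharded checkpoint.
shard = {"decoder.weight": torch.ones(4, 4), "decoder.bias": torch.zeros(4)}
load_shard(model, shard)
print(torch.equal(model["decoder"].weight, torch.ones(4, 4)))  # True
```

In the real method the skipped branch also avoids the per-module `named_parameters` scan and, under ZeRO-3, the setup around `deepspeed.zero.GatheredParameters`, so modules absent from the current shard cost almost nothing per shard visited.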