Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
12bfa97a
Unverified
Commit
12bfa97a
authored
Apr 13, 2022
by
Stas Bekman
Committed by
GitHub
Apr 13, 2022
Browse files
[from_pretrained] refactor find_mismatched_keys (#16706)
parent
9f8bfe70
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
22 deletions
+29
-22
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+29
-22
No files found.
src/transformers/modeling_utils.py
View file @
12bfa97a
...
...
@@ -1966,8 +1966,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"properly saved?"
)
if
state_dict
is
not
None
:
# Whole checkpoint
def
_find_mismatched_keys
(
state_dict
,
model_state_dict
,
loaded_keys
,
add_prefix_to_model
,
remove_prefix_from_model
,
ignore_mismatched_sizes
,
):
mismatched_keys
=
[]
if
ignore_mismatched_sizes
:
for
checkpoint_key
in
loaded_keys
:
...
...
@@ -1988,6 +1994,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
del
state_dict
[
checkpoint_key
]
return
mismatched_keys
if
state_dict
is
not
None
:
# Whole checkpoint
mismatched_keys
=
_find_mismatched_keys
(
state_dict
,
model_state_dict
,
loaded_keys
,
add_prefix_to_model
,
remove_prefix_from_model
,
ignore_mismatched_sizes
,
)
error_msgs
=
_load_state_dict_into_model
(
model_to_load
,
state_dict
,
start_prefix
)
else
:
# Sharded checkpoint
...
...
@@ -1996,30 +2014,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
resolved_archive_file
=
[
resolved_archive_file
]
error_msgs
=
[]
mismatched_keys
=
[]
for
shard_file
in
resolved_archive_file
:
state_dict
=
load_state_dict
(
shard_file
)
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys
=
[]
if
ignore_mismatched_sizes
:
for
checkpoint_key
in
loaded_keys
:
model_key
=
checkpoint_key
if
remove_prefix_from_model
:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key
=
f
"
{
prefix
}
.
{
checkpoint_key
}
"
elif
add_prefix_to_model
:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key
=
"."
.
join
(
checkpoint_key
.
split
(
"."
)[
1
:])
if
(
model_key
in
model_state_dict
and
state_dict
[
checkpoint_key
].
shape
!=
model_state_dict
[
model_key
].
shape
):
mismatched_keys
.
append
(
(
checkpoint_key
,
state_dict
[
checkpoint_key
].
shape
,
model_state_dict
[
model_key
].
shape
)
mismatched_keys
+=
_find_mismatched_keys
(
state_dict
,
model_state_dict
,
loaded_keys
,
add_prefix_to_model
,
remove_prefix_from_model
,
ignore_mismatched_sizes
,
)
del
state_dict
[
checkpoint_key
]
error_msgs
+=
_load_state_dict_into_model
(
model_to_load
,
state_dict
,
start_prefix
)
if
len
(
error_msgs
)
>
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment