Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
12bfa97a
Unverified
Commit
12bfa97a
authored
Apr 13, 2022
by
Stas Bekman
Committed by
GitHub
Apr 13, 2022
Browse files
[from_pretrained] refactor find_mismatched_keys (#16706)
parent
9f8bfe70
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
22 deletions
+29
-22
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+29
-22
No files found.
src/transformers/modeling_utils.py
View file @
12bfa97a
...
@@ -1966,8 +1966,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1966,8 +1966,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"properly saved?"
"properly saved?"
)
)
if
state_dict
is
not
None
:
def
_find_mismatched_keys
(
# Whole checkpoint
state_dict
,
model_state_dict
,
loaded_keys
,
add_prefix_to_model
,
remove_prefix_from_model
,
ignore_mismatched_sizes
,
):
mismatched_keys
=
[]
mismatched_keys
=
[]
if
ignore_mismatched_sizes
:
if
ignore_mismatched_sizes
:
for
checkpoint_key
in
loaded_keys
:
for
checkpoint_key
in
loaded_keys
:
...
@@ -1988,6 +1994,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1988,6 +1994,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
)
del
state_dict
[
checkpoint_key
]
del
state_dict
[
checkpoint_key
]
return
mismatched_keys
if
state_dict
is
not
None
:
# Whole checkpoint
mismatched_keys
=
_find_mismatched_keys
(
state_dict
,
model_state_dict
,
loaded_keys
,
add_prefix_to_model
,
remove_prefix_from_model
,
ignore_mismatched_sizes
,
)
error_msgs
=
_load_state_dict_into_model
(
model_to_load
,
state_dict
,
start_prefix
)
error_msgs
=
_load_state_dict_into_model
(
model_to_load
,
state_dict
,
start_prefix
)
else
:
else
:
# Sharded checkpoint
# Sharded checkpoint
...
@@ -1996,30 +2014,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1996,30 +2014,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
resolved_archive_file
=
[
resolved_archive_file
]
resolved_archive_file
=
[
resolved_archive_file
]
error_msgs
=
[]
error_msgs
=
[]
mismatched_keys
=
[]
for
shard_file
in
resolved_archive_file
:
for
shard_file
in
resolved_archive_file
:
state_dict
=
load_state_dict
(
shard_file
)
state_dict
=
load_state_dict
(
shard_file
)
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
# matching the weights in the model.
mismatched_keys
=
[]
mismatched_keys
+=
_find_mismatched_keys
(
if
ignore_mismatched_sizes
:
state_dict
,
for
checkpoint_key
in
loaded_keys
:
model_state_dict
,
model_key
=
checkpoint_key
loaded_keys
,
if
remove_prefix_from_model
:
add_prefix_to_model
,
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
remove_prefix_from_model
,
model_key
=
f
"
{
prefix
}
.
{
checkpoint_key
}
"
ignore_mismatched_sizes
,
elif
add_prefix_to_model
:
)
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key
=
"."
.
join
(
checkpoint_key
.
split
(
"."
)[
1
:])
if
(
model_key
in
model_state_dict
and
state_dict
[
checkpoint_key
].
shape
!=
model_state_dict
[
model_key
].
shape
):
mismatched_keys
.
append
(
(
checkpoint_key
,
state_dict
[
checkpoint_key
].
shape
,
model_state_dict
[
model_key
].
shape
)
)
del
state_dict
[
checkpoint_key
]
error_msgs
+=
_load_state_dict_into_model
(
model_to_load
,
state_dict
,
start_prefix
)
error_msgs
+=
_load_state_dict_into_model
(
model_to_load
,
state_dict
,
start_prefix
)
if
len
(
error_msgs
)
>
0
:
if
len
(
error_msgs
)
>
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment