Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
12bfa97a
Unverified
Commit
12bfa97a
authored
Apr 13, 2022
by
Stas Bekman
Committed by
GitHub
Apr 13, 2022
Browse files
[from_pretrained] refactor find_mismatched_keys (#16706)
parent
9f8bfe70
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
22 deletions
+29
-22
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+29
-22
No files found.
src/transformers/modeling_utils.py
View file @
12bfa97a
...
@@ -1966,8 +1966,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1966,8 +1966,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"properly saved?"
"properly saved?"
)
)
if
state_dict
is
not
None
:
def
_find_mismatched_keys
(
# Whole checkpoint
state_dict
,
model_state_dict
,
loaded_keys
,
add_prefix_to_model
,
remove_prefix_from_model
,
ignore_mismatched_sizes
,
):
mismatched_keys
=
[]
mismatched_keys
=
[]
if
ignore_mismatched_sizes
:
if
ignore_mismatched_sizes
:
for
checkpoint_key
in
loaded_keys
:
for
checkpoint_key
in
loaded_keys
:
...
@@ -1988,6 +1994,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1988,6 +1994,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
)
del
state_dict
[
checkpoint_key
]
del
state_dict
[
checkpoint_key
]
return
mismatched_keys
if
state_dict
is
not
None
:
# Whole checkpoint
mismatched_keys
=
_find_mismatched_keys
(
state_dict
,
model_state_dict
,
loaded_keys
,
add_prefix_to_model
,
remove_prefix_from_model
,
ignore_mismatched_sizes
,
)
error_msgs
=
_load_state_dict_into_model
(
model_to_load
,
state_dict
,
start_prefix
)
error_msgs
=
_load_state_dict_into_model
(
model_to_load
,
state_dict
,
start_prefix
)
else
:
else
:
# Sharded checkpoint
# Sharded checkpoint
...
@@ -1996,30 +2014,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1996,30 +2014,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
resolved_archive_file
=
[
resolved_archive_file
]
resolved_archive_file
=
[
resolved_archive_file
]
error_msgs
=
[]
error_msgs
=
[]
mismatched_keys
=
[]
for
shard_file
in
resolved_archive_file
:
for
shard_file
in
resolved_archive_file
:
state_dict
=
load_state_dict
(
shard_file
)
state_dict
=
load_state_dict
(
shard_file
)
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
# matching the weights in the model.
mismatched_keys
=
[]
mismatched_keys
+=
_find_mismatched_keys
(
if
ignore_mismatched_sizes
:
state_dict
,
for
checkpoint_key
in
loaded_keys
:
model_state_dict
,
model_key
=
checkpoint_key
loaded_keys
,
if
remove_prefix_from_model
:
add_prefix_to_model
,
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
remove_prefix_from_model
,
model_key
=
f
"
{
prefix
}
.
{
checkpoint_key
}
"
ignore_mismatched_sizes
,
elif
add_prefix_to_model
:
)
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key
=
"."
.
join
(
checkpoint_key
.
split
(
"."
)[
1
:])
if
(
model_key
in
model_state_dict
and
state_dict
[
checkpoint_key
].
shape
!=
model_state_dict
[
model_key
].
shape
):
mismatched_keys
.
append
(
(
checkpoint_key
,
state_dict
[
checkpoint_key
].
shape
,
model_state_dict
[
model_key
].
shape
)
)
del
state_dict
[
checkpoint_key
]
error_msgs
+=
_load_state_dict_into_model
(
model_to_load
,
state_dict
,
start_prefix
)
error_msgs
+=
_load_state_dict_into_model
(
model_to_load
,
state_dict
,
start_prefix
)
if
len
(
error_msgs
)
>
0
:
if
len
(
error_msgs
)
>
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment