chenpangpang / transformers · Commit 0001d056
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "00112c35393e4d81ef7593a3763dc626c0403e7b"

Correct missing keys + test (#3143)

Unverified commit, authored Mar 05, 2020 by Lysandre Debut; committed by GitHub on Mar 05, 2020.
Parent: 1741d740

Showing 2 changed files, with 24 additions and 0 deletions:

  src/transformers/modeling_utils.py    +9, -0
  tests/test_modeling_common.py         +15, -0
src/transformers/modeling_utils.py

@@ -539,6 +539,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             model_to_load = getattr(model, cls.base_model_prefix)
 
         load(model_to_load, prefix=start_prefix)
+
+        if model.__class__.__name__ != model_to_load.__class__.__name__:
+            base_model_state_dict = model_to_load.state_dict().keys()
+            head_model_state_dict_without_base_prefix = [
+                key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
+            ]
+
+            missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
+
         if len(missing_keys) > 0:
             logger.info(
                 "Weights of {} not initialized from pretrained model: {}".format(
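
For readers skimming the diff: the new block handles the case where a bare base model's weights were loaded into a model class that adds a head. It strips the base-model prefix from the head model's parameter names and records any head-only parameters as missing. A minimal standalone sketch of that computation, using explicit sets and invented key names (the real code takes them from state_dict()):

    # Sketch of the missing-keys computation above; key names are made up.
    base_model_prefix = "bert"

    # Parameter names of the full model with a task head.
    head_model_keys = [
        "bert.embeddings.word_embeddings.weight",
        "bert.encoder.layer.0.attention.self.query.weight",
        "classifier.weight",
        "classifier.bias",
    ]

    # Parameter names of the bare base model actually found in the checkpoint.
    base_model_keys = {
        "embeddings.word_embeddings.weight",
        "encoder.layer.0.attention.self.query.weight",
    }

    # Strip the "bert." prefix, as in the diff: split on the prefix, keep the tail.
    # Keys without the prefix (the head's) pass through unchanged.
    head_keys_without_prefix = {
        key.split(base_model_prefix + ".")[-1] for key in head_model_keys
    }

    # Head-only parameters were never initialized from the checkpoint.
    missing_keys = sorted(head_keys_without_prefix - base_model_keys)
    print(missing_keys)  # ['classifier.bias', 'classifier.weight']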
tests/test_modeling_common.py

@@ -526,6 +526,21 @@ class ModelTesterMixin:
         x = model.get_output_embeddings()
         self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
 
+    def test_correct_missing_keys(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            base_model_prefix = model.base_model_prefix
+
+            if hasattr(model, base_model_prefix):
+                with tempfile.TemporaryDirectory() as temp_dir_name:
+                    model.base_model.save_pretrained(temp_dir_name)
+                    model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)
+
+                    with self.subTest(msg="Missing keys for {}".format(model.__class__.__name__)):
+                        self.assertGreater(len(loading_info["missing_keys"]), 0)
+
     def test_tie_model_weights(self):
         if not self.test_torchscript:
             return
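
The test saves only the bare base model, reloads it through the head class, and asserts that the head's weights now show up in loading_info["missing_keys"]. A usage sketch of the same flow outside the test harness; the classes and checkpoint here are illustrative (and this downloads bert-base-uncased):

    import tempfile

    from transformers import BertForSequenceClassification, BertModel

    with tempfile.TemporaryDirectory() as tmp:
        # Save only the bare encoder, i.e. no classification head on disk.
        BertModel.from_pretrained("bert-base-uncased").save_pretrained(tmp)

        # Load that checkpoint into a head class; output_loading_info=True
        # additionally returns a dict describing how loading went.
        model, loading_info = BertForSequenceClassification.from_pretrained(
            tmp, output_loading_info=True
        )

        # With this commit, the uninitialized head parameters are reported:
        print(loading_info["missing_keys"])  # e.g. ['classifier.weight', 'classifier.bias']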