Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
49b77b89
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5e09af2acde21f232a6ed2ad2972c8f2269dcecf"
Unverified
Commit
49b77b89
authored
Nov 02, 2022
by
Sylvain Gugger
Committed by
GitHub
Nov 02, 2022
Browse files
Quality (#20002)
parent
c6c9db3d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
1 deletion
+54
-1
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+2
-1
tests/test_modeling_common.py
tests/test_modeling_common.py
+52
-0
No files found.
src/transformers/modeling_utils.py
View file @
49b77b89
...
@@ -2467,7 +2467,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -2467,7 +2467,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
start_prefix
=
cls
.
base_model_prefix
+
"."
start_prefix
=
cls
.
base_model_prefix
+
"."
if
len
(
cls
.
base_model_prefix
)
>
0
and
hasattr
(
model
,
cls
.
base_model_prefix
)
and
not
has_prefix_module
:
if
len
(
cls
.
base_model_prefix
)
>
0
and
hasattr
(
model
,
cls
.
base_model_prefix
)
and
not
has_prefix_module
:
model_to_load
=
getattr
(
model
,
cls
.
base_model_prefix
)
model_to_load
=
getattr
(
model
,
cls
.
base_model_prefix
)
if
any
(
key
in
expected_keys_not_prefixed
for
key
in
loaded_keys
):
base_model_expected_keys
=
list
(
model_to_load
.
state_dict
().
keys
())
if
any
(
key
in
expected_keys_not_prefixed
and
key
not
in
base_model_expected_keys
for
key
in
loaded_keys
):
raise
ValueError
(
raise
ValueError
(
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
"properly saved?"
"properly saved?"
...
...
tests/test_modeling_common.py
View file @
49b77b89
...
@@ -117,6 +117,36 @@ if is_torch_available():
...
@@ -117,6 +117,36 @@ if is_torch_available():
)
)
from
transformers.modeling_utils
import
shard_checkpoint
from
transformers.modeling_utils
import
shard_checkpoint
# Fake pretrained models for tests
class BaseModel(PreTrainedModel):
    """Tiny two-layer model that stands in for a real pretrained base model in tests."""

    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        # Registration order matters: tests compare parameters() pairwise in order.
        self.linear = nn.Linear(4, 5)
        self.linear_2 = nn.Linear(5, 6)

    def forward(self, x):
        # Pass the input through both layers in sequence.
        hidden = self.linear(x)
        return self.linear_2(hidden)
class ModelWithHead(PreTrainedModel):
    """Fake model-with-head: wraps a ``BaseModel`` under the ``base`` prefix and adds head layers."""

    base_model_prefix = "base"
    config_class = PretrainedConfig

    def _init_weights(self, module):
        # Intentionally a no-op: the default nn.Linear initialization is sufficient for tests.
        pass

    def __init__(self, config):
        super().__init__(config)
        self.base = BaseModel(config)
        # linear is a common name between Base and Head on purpose.
        self.linear = nn.Linear(6, 3)
        self.linear2 = nn.Linear(3, 5)

    def forward(self, x):
        # Run the base model first, then the two head layers.
        base_output = self.base(x)
        return self.linear2(self.linear(base_output))
if
is_tf_available
():
if
is_tf_available
():
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -3039,6 +3069,28 @@ class ModelUtilsTest(TestCasePlus):
...
@@ -3039,6 +3069,28 @@ class ModelUtilsTest(TestCasePlus):
for
p1
,
p2
in
zip
(
safetensors_model
.
parameters
(),
pytorch_model
.
parameters
()):
for
p1
,
p2
in
zip
(
safetensors_model
.
parameters
(),
pytorch_model
.
parameters
()):
self
.
assertTrue
(
torch
.
allclose
(
p1
,
p2
))
self
.
assertTrue
(
torch
.
allclose
(
p1
,
p2
))
def test_base_model_to_head_model_load(self):
    """A pure base-model checkpoint loads into a model-with-head; a checkpoint that
    mixes unprefixed base keys with head-level keys is rejected as corrupted."""
    base_model = BaseModel(PretrainedConfig())
    with tempfile.TemporaryDirectory() as tmp_dir:
        base_model.save_pretrained(tmp_dir)

        # Can load a base model in a model with head
        model = ModelWithHead.from_pretrained(tmp_dir)
        for loaded_param, saved_param in zip(model.base.parameters(), base_model.parameters()):
            self.assertTrue(torch.allclose(loaded_param, saved_param))

        # It doesn't work if the state dict has a mix of keys of the head and base without prefix though.
        base_state_dict = base_model.state_dict()
        head_state_dict = model.state_dict()
        base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
        base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
        torch.save(base_state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))

        expected_message = "The state dictionary of the model you are trying to load is corrupted."
        with self.assertRaisesRegex(ValueError, expected_message):
            _ = ModelWithHead.from_pretrained(tmp_dir)
@
require_torch
@
require_torch
@
is_staging_test
@
is_staging_test
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment