Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2020ac4b
Unverified
Commit
2020ac4b
authored
Feb 09, 2023
by
Sylvain Gugger
Committed by
GitHub
Feb 09, 2023
Browse files
Fix from_pretrained API with config and state_dict (#21542)
parent
1efe9c0b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
1 deletion
+13
-1
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+4
-1
tests/test_modeling_common.py
tests/test_modeling_common.py
+9
-0
No files found.
src/transformers/modeling_utils.py
View file @
2020ac4b
...
...
@@ -2770,7 +2770,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
del
state_dict
[
checkpoint_key
]
return
mismatched_keys
folder
=
os
.
path
.
sep
.
join
(
resolved_archive_file
[
0
].
split
(
os
.
path
.
sep
)[:
-
1
])
if
resolved_archive_file
is
not
None
:
folder
=
os
.
path
.
sep
.
join
(
resolved_archive_file
[
0
].
split
(
os
.
path
.
sep
)[:
-
1
])
else
:
folder
=
None
if
device_map
is
not
None
and
is_safetensors
:
param_device_map
=
expand_device_map
(
device_map
,
original_loaded_keys
)
...
...
tests/test_modeling_common.py
View file @
2020ac4b
...
...
@@ -2749,6 +2749,15 @@ class ModelUtilsTest(TestCasePlus):
BertModel
.
from_pretrained
(
TINY_T5
)
self
.
assertTrue
(
"You are using a model of type t5 to instantiate a model of type bert"
in
cl
.
out
)
def
test_model_from_pretrained_no_checkpoint
(
self
):
config
=
BertConfig
.
from_pretrained
(
"hf-internal-testing/tiny-random-bert"
)
model
=
BertModel
(
config
)
state_dict
=
model
.
state_dict
()
new_model
=
BertModel
.
from_pretrained
(
pretrained_model_name_or_path
=
None
,
config
=
config
,
state_dict
=
state_dict
)
for
p1
,
p2
in
zip
(
model
.
parameters
(),
new_model
.
parameters
()):
self
.
assertTrue
(
torch
.
equal
(
p1
,
p2
))
@
require_torch
def
test_model_from_config_torch_dtype
(
self
):
# test that the model can be instantiated with dtype of user's choice - as long as it's a
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment