chenpangpang / transformers / Commits

Unverified commit d4ba6e1a, authored Feb 14, 2023 by Sylvain Gugger, committed by GitHub on Feb 14, 2023

Fix generation config for empty state dict (#21630)

parent 31728292

Changes: 2 changed files with 13 additions and 10 deletions (+13 / -10)
src/transformers/modeling_utils.py (+1 / -1)
tests/test_modeling_common.py (+12 / -9)
src/transformers/modeling_utils.py

@@ -2648,7 +2648,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     _from_pipeline=from_pipeline,
                     **kwargs,
                 )
-            except OSError:
+            except (OSError, TypeError):
                 logger.info(
                     "Generation config file not found, using a generation config created from the model config."
                 )
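Why the wider except clause matters: when from_pretrained is called with pretrained_model_name_or_path=None (loading a model from a bare state dict), the generation-config lookup has no path to resolve and fails with a TypeError rather than the OSError raised for a checkpoint that merely lacks a generation_config.json. A minimal sketch of the two failure modes, assuming a transformers release that ships GenerationConfig; the temporary directory stands in for a checkpoint directory without a generation config:

```python
import tempfile

from transformers import GenerationConfig

with tempfile.TemporaryDirectory() as empty_dir:
    for name_or_path in (empty_dir, None):
        try:
            GenerationConfig.from_pretrained(name_or_path)
        except OSError:
            # A real location with no generation_config.json: the case the
            # old `except OSError` already covered.
            print(f"{name_or_path!r} -> OSError")
        except TypeError:
            # No location at all (state-dict-only loading): the case this
            # commit adds to the except clause.
            print(f"{name_or_path!r} -> TypeError")
```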
tests/test_modeling_common.py

@@ -325,6 +325,18 @@ class ModelTesterMixin:
         else:
             check_save_load(first, second)
 
+    def test_from_pretrained_no_checkpoint(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            state_dict = model.state_dict()
+
+            new_model = model_class.from_pretrained(
+                pretrained_model_name_or_path=None, config=config, state_dict=state_dict
+            )
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.equal(p1, p2))
+
     def test_save_load_keys_to_ignore_on_save(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
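The hunk below removes the Bert-only version of this test from ModelUtilsTest; the mixin version added above generalizes it, so every model class in the common test suite now exercises the no-checkpoint loading path.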
@@ -2776,15 +2788,6 @@ class ModelUtilsTest(TestCasePlus):
             BertModel.from_pretrained(TINY_T5)
         self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
 
-    def test_model_from_pretrained_no_checkpoint(self):
-        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
-        model = BertModel(config)
-        state_dict = model.state_dict()
-        new_model = BertModel.from_pretrained(
-            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
-        )
-        for p1, p2 in zip(model.parameters(), new_model.parameters()):
-            self.assertTrue(torch.equal(p1, p2))
 
     def test_model_from_config_torch_dtype(self):
         # test that the model can be instantiated with dtype of user's choice - as long as it's a
         # float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
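For reference, a usage sketch of the loading path the relocated test covers, mirroring the removed Bert-specific test; it requires torch and transformers, plus Hub access to fetch the tiny config:

```python
import torch
from transformers import BertConfig, BertModel

# Instantiate a model in memory; no checkpoint exists on disk or on the Hub.
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
model = BertModel(config)

# Load weights straight from the in-memory state dict. For model classes that
# look up a generation config, this call could previously escape with an
# uncaught TypeError; with this fix it falls back to a generation config
# created from the model config.
new_model = BertModel.from_pretrained(
    pretrained_model_name_or_path=None, config=config, state_dict=model.state_dict()
)

# The reloaded parameters match the originals exactly.
for p1, p2 in zip(model.parameters(), new_model.parameters()):
    assert torch.equal(p1, p2)
```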