chenpangpang / transformers · Commits · d4ba6e1a
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7addc9346c89563c0d36b30fa3534c58d3a1de05"
Unverified commit d4ba6e1a, authored Feb 14, 2023 by Sylvain Gugger, committed via GitHub on Feb 14, 2023.
Fix generation config for empty state dict (#21630)
Parent: 31728292
Showing 2 changed files with 13 additions and 10 deletions:

    src/transformers/modeling_utils.py   (+1, -1)
    tests/test_modeling_common.py        (+12, -9)
src/transformers/modeling_utils.py

@@ -2648,7 +2648,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     _from_pipeline=from_pipeline,
                     **kwargs,
                 )
-            except OSError:
+            except (OSError, TypeError):
                 logger.info(
                     "Generation config file not found, using a generation config created from the model config."
                 )
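Why TypeError joins the except clause: when a model is loaded with no checkpoint, e.g. from_pretrained(None, config=config, state_dict=state_dict), the generation-config lookup inside PreTrainedModel.from_pretrained receives None as its path and fails with a TypeError rather than the OSError raised for a missing file, so the fallback to a config-derived generation config never ran. Below is a minimal sketch of the now-guarded fallback, assuming transformers as of this commit (Feb 2023); the GenerationConfig.from_model_config call paraphrases the surrounding library code rather than quoting it.

# Hedged sketch of the fallback this commit repairs; not a quote of the
# library internals.
from transformers import BertConfig, GenerationConfig

config = BertConfig()

try:
    # With no checkpoint, the path argument ends up as None; this raises
    # TypeError, not the OSError raised for a missing config file.
    generation_config = GenerationConfig.from_pretrained(None)
except (OSError, TypeError):
    # Same fallback the library logs about: build a generation config
    # from the model config instead of crashing.
    generation_config = GenerationConfig.from_model_config(config)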
tests/test_modeling_common.py

@@ -325,6 +325,18 @@ class ModelTesterMixin:
         else:
             check_save_load(first, second)
 
+    def test_from_pretrained_no_checkpoint(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            state_dict = model.state_dict()
+
+            new_model = model_class.from_pretrained(
+                pretrained_model_name_or_path=None, config=config, state_dict=state_dict
+            )
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.equal(p1, p2))
+
     def test_save_load_keys_to_ignore_on_save(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -2776,15 +2788,6 @@ class ModelUtilsTest(TestCasePlus):
             BertModel.from_pretrained(TINY_T5)
         self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
 
-    def test_model_from_pretrained_no_checkpoint(self):
-        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
-        model = BertModel(config)
-        state_dict = model.state_dict()
-        new_model = BertModel.from_pretrained(
-            pretrained_model_name_or_path=None, config=config, state_dict=state_dict
-        )
-        for p1, p2 in zip(model.parameters(), new_model.parameters()):
-            self.assertTrue(torch.equal(p1, p2))
 
     def test_model_from_config_torch_dtype(self):
         # test that the model can be instantiated with dtype of user's choice - as long as it's a
         # float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
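The regression test moves from the Bert-specific ModelUtilsTest into ModelTesterMixin, so every model class in the suite now exercises the no-checkpoint path rather than BertModel alone. Below is a standalone version of the round trip it checks, runnable outside the test harness; the tiny config sizes are illustrative, not taken from the test suite.

# Standalone round-trip check mirroring the new mixin test; config sizes
# are assumptions chosen for speed.
import torch
from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = BertModel(config)
state_dict = model.state_dict()

# With no checkpoint path, weights come entirely from the in-memory state_dict;
# before this commit, this call crashed with an uncaught TypeError.
new_model = BertModel.from_pretrained(
    pretrained_model_name_or_path=None, config=config, state_dict=state_dict
)

# Every parameter should survive the round trip bit-for-bit.
for p1, p2 in zip(model.parameters(), new_model.parameters()):
    assert torch.equal(p1, p2)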