"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7addc9346c89563c0d36b30fa3534c58d3a1de05"
Unverified Commit d4ba6e1a authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix generation config for empty state dict (#21630)

parent 31728292
...@@ -2648,7 +2648,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2648,7 +2648,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_from_pipeline=from_pipeline, _from_pipeline=from_pipeline,
**kwargs, **kwargs,
) )
except OSError: except (OSError, TypeError):
logger.info( logger.info(
"Generation config file not found, using a generation config created from the model config." "Generation config file not found, using a generation config created from the model config."
) )
......
...@@ -325,6 +325,18 @@ class ModelTesterMixin: ...@@ -325,6 +325,18 @@ class ModelTesterMixin:
else: else:
check_save_load(first, second) check_save_load(first, second)
def test_from_pretrained_no_checkpoint(self):
    """Round-trip each model's weights through `from_pretrained` with no checkpoint on disk.

    Passing ``pretrained_model_name_or_path=None`` together with an in-memory
    ``state_dict`` must reproduce the original parameters exactly.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    for model_class in self.all_model_classes:
        original = model_class(config)
        # Reload purely from the in-memory state dict — no checkpoint path involved.
        reloaded = model_class.from_pretrained(
            pretrained_model_name_or_path=None, config=config, state_dict=original.state_dict()
        )
        # Every parameter must survive the round trip bit-for-bit.
        for original_param, reloaded_param in zip(original.parameters(), reloaded.parameters()):
            self.assertTrue(torch.equal(original_param, reloaded_param))
def test_save_load_keys_to_ignore_on_save(self): def test_save_load_keys_to_ignore_on_save(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -2776,15 +2788,6 @@ class ModelUtilsTest(TestCasePlus): ...@@ -2776,15 +2788,6 @@ class ModelUtilsTest(TestCasePlus):
BertModel.from_pretrained(TINY_T5) BertModel.from_pretrained(TINY_T5)
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out) self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
def test_model_from_pretrained_no_checkpoint(self):
    """Loading BertModel from a state dict (no checkpoint path) must yield identical weights."""
    config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
    reference = BertModel(config)
    weights = reference.state_dict()
    # No checkpoint on disk: weights are supplied directly via `state_dict`.
    restored = BertModel.from_pretrained(pretrained_model_name_or_path=None, config=config, state_dict=weights)
    for reference_param, restored_param in zip(reference.parameters(), restored.parameters()):
        self.assertTrue(torch.equal(reference_param, restored_param))
def test_model_from_config_torch_dtype(self): def test_model_from_config_torch_dtype(self):
# test that the model can be instantiated with dtype of user's choice - as long as it's a # test that the model can be instantiated with dtype of user's choice - as long as it's a
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the # float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment