"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "5e09af2acde21f232a6ed2ad2972c8f2269dcecf"
Unverified Commit 49b77b89 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Quality (#20002)

parent c6c9db3d
...@@ -2467,7 +2467,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2467,7 +2467,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
start_prefix = cls.base_model_prefix + "." start_prefix = cls.base_model_prefix + "."
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module: if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix) model_to_load = getattr(model, cls.base_model_prefix)
if any(key in expected_keys_not_prefixed for key in loaded_keys): base_model_expected_keys = list(model_to_load.state_dict().keys())
if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
raise ValueError( raise ValueError(
"The state dictionary of the model you are trying to load is corrupted. Are you sure it was " "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
"properly saved?" "properly saved?"
......
...@@ -117,6 +117,36 @@ if is_torch_available(): ...@@ -117,6 +117,36 @@ if is_torch_available():
) )
from transformers.modeling_utils import shard_checkpoint from transformers.modeling_utils import shard_checkpoint
# Fake pretrained models for tests
class BaseModel(PreTrainedModel):
config_class = PretrainedConfig
def __init__(self, config):
super().__init__(config)
self.linear = nn.Linear(4, 5)
self.linear_2 = nn.Linear(5, 6)
def forward(self, x):
return self.linear_2(self.linear(x))
class ModelWithHead(PreTrainedModel):
base_model_prefix = "base"
config_class = PretrainedConfig
def _init_weights(self, module):
pass
def __init__(self, config):
super().__init__(config)
self.base = BaseModel(config)
# linear is a common name between Base and Head on purpose.
self.linear = nn.Linear(6, 3)
self.linear2 = nn.Linear(3, 5)
def forward(self, x):
return self.linear2(self.linear(self.base(x)))
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
...@@ -3039,6 +3069,28 @@ class ModelUtilsTest(TestCasePlus): ...@@ -3039,6 +3069,28 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()): for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
self.assertTrue(torch.allclose(p1, p2)) self.assertTrue(torch.allclose(p1, p2))
def test_base_model_to_head_model_load(self):
base_model = BaseModel(PretrainedConfig())
with tempfile.TemporaryDirectory() as tmp_dir:
base_model.save_pretrained(tmp_dir)
# Can load a base model in a model with head
model = ModelWithHead.from_pretrained(tmp_dir)
for p1, p2 in zip(model.base.parameters(), base_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
# It doesn't work if the state dict has a mix of keys of the head and base without prefix though.
base_state_dict = base_model.state_dict()
head_state_dict = model.state_dict()
base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
torch.save(base_state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
with self.assertRaisesRegex(
ValueError, "The state dictionary of the model you are trying to load is corrupted."
):
_ = ModelWithHead.from_pretrained(tmp_dir)
@require_torch @require_torch
@is_staging_test @is_staging_test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment