"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "12313838d33373d06d35b48c3c501fa832f16443"
Unverified commit d1fcc90a authored by Sylvain Gugger, committed by GitHub

Fix from_pretrained with default base_model_prefix (#15814)

parent 7f921bcf
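
Note on the fix: PreTrainedModel.base_model_prefix defaults to the empty string, and every Python string starts with "", so the old detection below reported a prefix module for any checkpoint. A minimal sketch of the pitfall and the patched guard (key names are illustrative, not from the patch):

# Every string starts with "", so with the default empty
# base_model_prefix the old prefix detection was always True.
loaded_keys = ["linear.weight", "linear.bias"]
prefix = ""  # PreTrainedModel's class-level default

print(any(s.startswith(prefix) for s in loaded_keys))  # True (old behavior)

# Patched logic: only run the heuristic for a non-empty prefix.
if len(prefix) > 0:
    has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
else:
    has_prefix_module = False
print(has_prefix_module)  # False
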
@@ -1580,8 +1580,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         loaded_keys = list(state_dict.keys())
         prefix = model.base_model_prefix
 
-        has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
-        expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
+        if len(prefix) > 0:
+            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
+            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
+        else:
+            has_prefix_module = False
+            expects_prefix_module = False
 
         # key re-naming operations are never done on the keys
         # that are loaded, but always on the keys of the newly initialized model
@@ -1669,9 +1673,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         # Make sure we are able to load base models as well as derived models (with heads)
         start_prefix = ""
         model_to_load = model
-        if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
+        if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
             start_prefix = cls.base_model_prefix + "."
-        if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
+        if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
             model_to_load = getattr(model, cls.base_model_prefix)
             if any(key in expected_keys_not_prefixed for key in loaded_keys):
                 raise ValueError(
......
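
The second hunk guards the base-model/derived-model juggling the same way. A hedged sketch of the two cross-loading cases it covers, using an assumed "bert" prefix and made-up key names:

# Illustrative only: assumed prefix and keys, not code from the patch.
base_model_prefix = "bert"

# Case 1: checkpoint saved from a model with a head, loaded into the
# bare base model -> strip "bert." from each checkpoint key.
checkpoint_key = "bert.embeddings.word_embeddings.weight"
start_prefix = base_model_prefix + "."
print(checkpoint_key[len(start_prefix):])  # embeddings.word_embeddings.weight

# Case 2: bare base checkpoint loaded into a model with a head ->
# load into the base submodule instead: getattr(model, "bert").

With the default prefix "", hasattr(model, "") is False while the old has_prefix_module was always True, so case 1 fired and produced a bogus start_prefix of "."; the new len() checks skip both branches.
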
@@ -2105,7 +2105,10 @@ class ModelUtilsTest(TestCasePlus):
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_pretrained(tmp_dir)
-            model = NoSuperInitModel.from_pretrained(tmp_dir)
+            new_model = NoSuperInitModel.from_pretrained(tmp_dir)
+
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.equal(p1, p2))
 
     @require_torch
......
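
The old test rebound model to the reloaded instance, so nothing was compared; it now loads into new_model and checks the round trip parameter by parameter. The same pattern, self-contained with a toy torch module standing in for NoSuperInitModel:

import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = f"{tmp_dir}/weights.pt"
    torch.save(model.state_dict(), path)
    new_model = nn.Linear(4, 2)
    new_model.load_state_dict(torch.load(path))

# Every reloaded parameter must match the original exactly.
for p1, p2 in zip(model.parameters(), new_model.parameters()):
    assert torch.equal(p1, p2)
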
@@ -7,7 +7,6 @@ from .custom_configuration import CustomConfig, NoSuperInitConfig
 
 class CustomModel(PreTrainedModel):
     config_class = CustomConfig
-    base_model_prefix = "custom"
 
     def __init__(self, config):
         super().__init__(config)
@@ -22,7 +21,6 @@ class CustomModel(PreTrainedModel):
 
 class NoSuperInitModel(PreTrainedModel):
     config_class = NoSuperInitConfig
-    base_model_prefix = "custom"
 
     def __init__(self, config):
         super().__init__(config)
......
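
Neither CustomModel nor NoSuperInitModel has a "custom" submodule, so the removed base_model_prefix overrides were misleading; without them, these test models exercise the default empty prefix that the modeling_utils.py hunks now handle. The default is easy to confirm (assuming transformers is installed):

from transformers import PreTrainedModel

# Class-level default: an empty prefix, the case the new
# len(prefix) > 0 guards in from_pretrained address.
print(repr(PreTrainedModel.base_model_prefix))  # ''
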