Unverified Commit d1fcc90a authored by Sylvain Gugger, committed by GitHub

Fix from_pretrained with default base_model_prefix (#15814)

parent 7f921bcf
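
`PreTrainedModel.base_model_prefix` defaults to `""`, and every Python string starts with the empty string, so the prefix-detection logic patched below spuriously reported a base-model prefix for any model class that keeps the default. A minimal sketch of the failure mode (the key names are illustrative):

```python
# Every string starts with "" -- so with the default base_model_prefix
# the old detection always concluded the checkpoint had a prefix module.
prefix = ""  # default PreTrainedModel.base_model_prefix
loaded_keys = ["linear.weight", "linear.bias"]

has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
print(has_prefix_module)  # True, even though no prefix module exists
```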
```diff
@@ -1580,8 +1580,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         loaded_keys = list(state_dict.keys())
         prefix = model.base_model_prefix
 
-        has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
-        expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
+        if len(prefix) > 0:
+            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
+            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
+        else:
+            has_prefix_module = False
+            expects_prefix_module = False
 
         # key re-naming operations are never done on the keys
         # that are loaded, but always on the keys of the newly initialized model
```
```diff
@@ -1669,9 +1673,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         # Make sure we are able to load base models as well as derived models (with heads)
         start_prefix = ""
         model_to_load = model
-        if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
+        if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
             start_prefix = cls.base_model_prefix + "."
-        if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
+        if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
             model_to_load = getattr(model, cls.base_model_prefix)
             if any(key in expected_keys_not_prefixed for key in loaded_keys):
                 raise ValueError(
```
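
The new `len(cls.base_model_prefix) > 0` guards also matter on their own: `hasattr(model, "")` is always `False`, so before the fix a class with the default empty prefix could take the first branch and end up with a `start_prefix` of `"."`. A minimal sketch of that old path, using a hypothetical `Dummy` stand-in for the model:

```python
class Dummy:
    pass

prefix = ""  # the default base_model_prefix
has_prefix_module = True  # spuriously True before the fix (see the sketch above)

start_prefix = ""
# getattr(obj, "") raises AttributeError, so hasattr(obj, "") is always False
# and the old, unguarded branch condition was met:
if not hasattr(Dummy(), prefix) and has_prefix_module:
    start_prefix = prefix + "."
print(start_prefix)  # "." -- checkpoint keys were then matched under a bogus prefix
```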
```diff
@@ -2105,7 +2105,10 @@ class ModelUtilsTest(TestCasePlus):
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_pretrained(tmp_dir)
-            model = NoSuperInitModel.from_pretrained(tmp_dir)
+            new_model = NoSuperInitModel.from_pretrained(tmp_dir)
+
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
 
     @require_torch
```
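
The updated test now checks that the round-trip actually preserves the weights rather than only that loading does not crash. A self-contained sketch of the same check outside the test suite; `TinyConfig` and `TinyModel` are illustrative names, not classes from the repo:

```python
import tempfile

import torch
from torch import nn

from transformers import PretrainedConfig, PreTrainedModel


class TinyConfig(PretrainedConfig):
    model_type = "tiny"

    def __init__(self, hidden_size=32, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class TinyModel(PreTrainedModel):
    # base_model_prefix is deliberately left at the default ""
    config_class = TinyConfig

    def __init__(self, config):
        super().__init__(config)
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        return self.linear(x)

    def _init_weights(self, module):
        pass  # no custom init needed for a save/load round-trip check


model = TinyModel(TinyConfig())
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    new_model = TinyModel.from_pretrained(tmp_dir)

# Before this fix, a class with the default empty base_model_prefix could
# fail to match its checkpoint keys in from_pretrained; now both copies agree.
for p1, p2 in zip(model.parameters(), new_model.parameters()):
    assert torch.equal(p1, p2)
```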
```diff
@@ -7,7 +7,6 @@ from .custom_configuration import CustomConfig, NoSuperInitConfig
 
 class CustomModel(PreTrainedModel):
     config_class = CustomConfig
-    base_model_prefix = "custom"
 
     def __init__(self, config):
         super().__init__(config)
```
```diff
@@ -22,7 +21,6 @@ class CustomModel(PreTrainedModel):
 
 class NoSuperInitModel(PreTrainedModel):
     config_class = NoSuperInitConfig
-    base_model_prefix = "custom"
 
     def __init__(self, config):
         super().__init__(config)
```
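
Both fixture models drop their explicit `base_model_prefix = "custom"`, so they now keep the default empty prefix and the save/load test above exercises exactly the code path this commit fixes.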