Unverified Commit f5b5c5bd authored by Sylvain Gugger, committed by GitHub

Avoid unnecessary warnings when loading pretrained model (#5922)

* Avoid unnecessary warnings when loading pretrained model

* Fix test

* Add other keys to ignore

* keys_to_ignore_at_load -> authorized_missing_keys
parent 29afb576
@@ -938,6 +938,7 @@ class BartModel(PretrainedBartModel):
 )
 class BartForConditionalGeneration(PretrainedBartModel):
     base_model_prefix = "model"
+    authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]

     def __init__(self, config: BartConfig):
         super().__init__(config)
......
@@ -577,6 +577,8 @@ class GPT2Model(GPT2PreTrainedModel):
     GPT2_START_DOCSTRING,
 )
 class GPT2LMHeadModel(GPT2PreTrainedModel):
+    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+
     def __init__(self, config):
         super().__init__(config)
         self.transformer = GPT2Model(config)
......
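As a usage illustration (a hypothetical snippet, not part of the diff): with the patterns above, loading a GPT-2 checkpoint into the LM-head model should no longer warn about `lm_head.weight` or the per-layer `masked_bias` buffers, since both are recreated at initialization and are expected to be absent from the checkpoint.

```python
from transformers import GPT2LMHeadModel

# lm_head.weight is tied to the input embeddings and the masked_bias buffers are
# built in __init__, so their absence from the checkpoint is expected and is no
# longer reported in the "missing keys" warning when loading.
model = GPT2LMHeadModel.from_pretrained("gpt2")
```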
@@ -1027,6 +1027,8 @@ class T5Model(T5PreTrainedModel):
 @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
 class T5ForConditionalGeneration(T5PreTrainedModel):
+    authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"]
+
     def __init__(self, config):
         super().__init__(config)
         self.model_dim = config.d_model
......
@@ -17,6 +17,7 @@
 import inspect
 import logging
 import os
+import re
 from dataclasses import dataclass
 from typing import Callable, Dict, List, Optional, Tuple
@@ -289,9 +290,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re patterns of tensor names to ignore
+          when loading the model (and avoid unnecessary warnings).
     """
     config_class = None
     base_model_prefix = ""
+    authorized_missing_keys = None

     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -806,9 +810,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
             head_model_state_dict_without_base_prefix = [
                 key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
             ]
             missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

+        # Some models may have keys that are not in the state by design, removing them before needlessly warning
+        # the user.
+        if cls.authorized_missing_keys is not None:
+            for pat in cls.authorized_missing_keys:
+                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+
         if len(unexpected_keys) > 0:
             logger.warning(
                 f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
......
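For clarity, here is a self-contained sketch (hypothetical key names, mirroring the logic added above) of how the regex filtering behaves. Because `re.search` is used rather than `re.match`, a pattern may match anywhere in the key name, so `model.encoder.version` is filtered even though the pattern carries no prefix.

```python
import re

# Patterns taken from the BartForConditionalGeneration hunk above.
authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]

# Hypothetical keys reported as missing when loading a checkpoint.
missing_keys = ["final_logits_bias", "model.encoder.version", "model.decoder.layers.0.fc1.weight"]

# Keep only the keys that match none of the authorized patterns.
for pat in authorized_missing_keys:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

print(missing_keys)  # ['model.decoder.layers.0.fc1.weight'] -> still worth warning about
```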
@@ -311,6 +311,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
     all_generative_model_classes = (
         (GPT2LMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
+    test_missing_keys = False

     def setUp(self):
         self.model_tester = GPT2ModelTester(self)
......