Unverified Commit bac2d29a authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Attempting to test automatically the `_keys_to_ignore`. (#20042)



* Attempting to test automatically the `_keys_to_ignore`.

* Style.

* First fix pass.

* Moving test on its own.

* Another batch.

* Second round removing BatchNorm

* Fixing layoutlmv{2,3} + support older Python.

* Disable the missed-missing warning.

* Removing dodgy additions.

* Big pass.

* mbart.

* More corrections.

* Fixup.

* Updating test_correct_missing_keys

* Add escape hatch for when the head has no extra params, so it doesn't need
the missing-keys check.

* Fixing test.

* Greener.

* Green ! (except for weird splinter bug).

* Adding a test about `named_parameters` usage.

* Shorten message.

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* After rebase modifications.

* More explicit condition checking.

* Fixing slow tests issues.

* Remove extra pdb.

* Remove print.

* Attempt to make failure consistent + fixing roc_bert.

* Removing the seed (all tests passing with it).
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent d606d566
...@@ -1296,6 +1296,8 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -1296,6 +1296,8 @@ class XLNetModel(XLNetPreTrainedModel):
XLNET_START_DOCSTRING, XLNET_START_DOCSTRING,
) )
class XLNetLMHeadModel(XLNetPreTrainedModel): class XLNetLMHeadModel(XLNetPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_loss.weight"]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.attn_type = config.attn_type self.attn_type = config.attn_type
......
...@@ -852,6 +852,12 @@ class YosoModel(YosoPreTrainedModel): ...@@ -852,6 +852,12 @@ class YosoModel(YosoPreTrainedModel):
@add_start_docstrings("""YOSO Model with a `language modeling` head on top.""", YOSO_START_DOCSTRING) @add_start_docstrings("""YOSO Model with a `language modeling` head on top.""", YOSO_START_DOCSTRING)
class YosoForMaskedLM(YosoPreTrainedModel): class YosoForMaskedLM(YosoPreTrainedModel):
_keys_to_ignore_on_load_missing = [
"cls.predictions.decoder.bias",
"cls.predictions.decoder.weight",
"embeddings.position_ids",
]
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
......
...@@ -119,8 +119,6 @@ class AutoModelTest(unittest.TestCase): ...@@ -119,8 +119,6 @@ class AutoModelTest(unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
self.assertIsInstance(model, BertForPreTraining) self.assertIsInstance(model, BertForPreTraining)
# Only one value should not be initialized and in the missing keys. # Only one value should not be initialized and in the missing keys.
missing_keys = loading_info.pop("missing_keys")
self.assertListEqual(["cls.predictions.decoder.bias"], missing_keys)
for key, value in loading_info.items(): for key, value in loading_info.items():
self.assertEqual(len(value), 0) self.assertEqual(len(value), 0)
......
...@@ -424,7 +424,6 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -424,7 +424,6 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
is_encoder_decoder = True is_encoder_decoder = True
fx_compatible = True fx_compatible = True
test_pruning = False test_pruning = False
test_missing_keys = False
def setUp(self): def setUp(self):
self.model_tester = BartModelTester(self) self.model_tester = BartModelTester(self)
...@@ -1445,6 +1444,7 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un ...@@ -1445,6 +1444,7 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
fx_comptatible = True fx_comptatible = True
test_pruning = False test_pruning = False
is_encoder_decoder = False is_encoder_decoder = False
test_missing_keys = False
def setUp( def setUp(
self, self,
......
...@@ -1468,11 +1468,24 @@ class ModelTesterMixin: ...@@ -1468,11 +1468,24 @@ class ModelTesterMixin:
base_model_prefix = model.base_model_prefix base_model_prefix = model.base_model_prefix
if hasattr(model, base_model_prefix): if hasattr(model, base_model_prefix):
extra_params = {k: v for k, v in model.named_parameters() if not k.startswith(base_model_prefix)}
extra_params.update({k: v for k, v in model.named_buffers() if not k.startswith(base_model_prefix)})
# Some models define this as None
if model._keys_to_ignore_on_load_missing:
for key in model._keys_to_ignore_on_load_missing:
extra_params.pop(key, None)
if not extra_params:
# In that case, we *are* on a head model, but every
# single key is not actual parameters and this is
# tested in `test_tied_model_weights_key_ignore` test.
continue
with tempfile.TemporaryDirectory() as temp_dir_name: with tempfile.TemporaryDirectory() as temp_dir_name:
model.base_model.save_pretrained(temp_dir_name) model.base_model.save_pretrained(temp_dir_name)
model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True) model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)
with self.subTest(msg=f"Missing keys for {model.__class__.__name__}"): self.assertGreater(len(loading_info["missing_keys"]), 0, model.__class__.__name__)
self.assertGreater(len(loading_info["missing_keys"]), 0)
def test_tie_model_weights(self): def test_tie_model_weights(self):
if not self.test_torchscript: if not self.test_torchscript:
...@@ -1522,6 +1535,54 @@ class ModelTesterMixin: ...@@ -1522,6 +1535,54 @@ class ModelTesterMixin:
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape) # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head)) # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
def test_tied_model_weights_key_ignore(self):
    """Check that `_keys_to_ignore_on_load_missing` is neither too loose nor stale.

    The trick: save each model, then overwrite its checkpoint with an empty
    state dict so that on reload *every* real parameter is reported missing.
    Any reported missing key that does not correspond to an actual parameter
    or buffer of the model is then a key the model should have ignored.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    for model_class in self.all_model_classes:
        model_tied = model_class(config)
        with tempfile.TemporaryDirectory() as d:
            model_tied.save_pretrained(d)
            # Nuke ALL weights on file, so every parameter should yell on
            # load. We then detect if we yell too much, or too little.
            with open(os.path.join(d, "pytorch_model.bin"), "wb") as f:
                torch.save({}, f)
            model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)

            # NOTE: tied duplicates could instead be detected via
            # `state_dict()` + `tensor.data_ptr()` (same pointer == same
            # underlying data, modulo stride), e.g. for GPT2 `lm_head.weight`
            # is in `state_dict()` but not in `named_parameters()` because it
            # shares storage with `transformer.wte.weight`.
            prefix = f"{model_reloaded.base_model_prefix}."

            # Collect every real parameter AND buffer of the reloaded model,
            # normalized to checkpoint-style names (base-model prefix dropped).
            params = dict(model_reloaded.named_parameters())
            params.update(dict(model_reloaded.named_buffers()))
            param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params}

            missing_keys = set(infos["missing_keys"])

            # Keys reported missing that are not real parameters/buffers:
            # these should have been listed in `_keys_to_ignore_on_load_missing`.
            extra_missing = missing_keys - param_names
            self.assertEqual(
                extra_missing,
                set(),
                f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing}",
            )

            # The reverse check (real parameters that were *not* reported
            # missing, i.e. stale ignore entries) is deliberately disabled:
            # missed_missing = param_names - missing_keys
            # self.assertEqual(
            #     missed_missing,
            #     set(),
            #     f"This model {model_class.__name__} ignores keys {missed_missing} but they look like real"
            #     " parameters",
            # )
def test_model_outputs_equivalence(self): def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment