Unverified commit 04b2f13c, authored by Sylvain Gugger, committed by GitHub

🚨🚨🚨 Enforce single model initialization (#21431)

* Enforce single model initialization

* Add OneFormer example for problem 3

* Do it the Stas way

* Actually rename the uses...

* Rewrite test

* Try to change the test this way

* Fix all init slow/fast tests

* Break connection

* Fix more tests

* Fix test for initialization

* Remove custom test

* Quality

* Fix last failing tests

* The end?
parent 2020ac4b
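
The pattern this PR enforces, shown as a minimal runnable sketch (`TinyModel` and its init values are illustrative assumptions; only the `_is_hf_initialized` flag and the `_initialize_weights`/`init_weights` hooks mirror names that appear in the diff below): each module is visited once, and modules already marked as initialized, e.g. because their weights were loaded from a checkpoint, are skipped. This is also why several expected constants in the ProphetNet and Reformer tests below had to be re-recorded: the order and count of RNG draws during initialization changed.

    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(10, 4)
            self.proj = nn.Linear(4, 4)

        def _init_weights(self, module):
            # Per-module initialization, as in PreTrainedModel subclasses.
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)

        def _initialize_weights(self, module):
            # The single-initialization guard: modules already initialized
            # (e.g. because their weights came from a checkpoint) are skipped.
            if getattr(module, "_is_hf_initialized", False):
                return
            self._init_weights(module)
            module._is_hf_initialized = True

        def init_weights(self):
            self.apply(self._initialize_weights)

    model = TinyModel()
    model.init_weights()  # initializes every module exactly once
    model.init_weights()  # no-op: every module already carries the flag
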
tests/models/layoutlmv2/test_modeling_layoutlmv2.py
@@ -15,9 +15,6 @@
 """ Testing suite for the PyTorch LayoutLMv2 model. """

-import os
-import random
-import tempfile
 import unittest

 from transformers.testing_utils import require_detectron2, require_torch, require_torch_multi_gpu, slow, torch_device

@@ -31,7 +28,6 @@ if is_torch_available():
     import torch

     from transformers import (
-        MODEL_MAPPING,
         LayoutLMv2Config,
         LayoutLMv2ForQuestionAnswering,
         LayoutLMv2ForSequenceClassification,
@@ -312,54 +308,6 @@ class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

-    def test_save_load_fast_init_from_base(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        base_class = MODEL_MAPPING[config.__class__]
-
-        if isinstance(base_class, tuple):
-            base_class = base_class[0]
-
-        for model_class in self.all_model_classes:
-            if model_class == base_class:
-                continue
-
-            # make a copy of model class to not break future tests
-            # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
-            class CopyClass(model_class):
-                pass
-
-            model_class_copy = CopyClass
-
-            # make sure that all keys are expected for test
-            model_class_copy._keys_to_ignore_on_load_missing = []
-
-            # make init deterministic, but make sure that
-            # non-initialized weights throw errors nevertheless
-            model_class_copy._init_weights = self._mock_init_weights
-
-            model = base_class(config)
-            state_dict = model.state_dict()
-
-            # this will often delete a single weight of a multi-weight module
-            # to test an edge case
-            random_key_to_del = random.choice(list(state_dict.keys()))
-            del state_dict[random_key_to_del]
-
-            # check that certain keys didn't get saved with the model
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin"))
-
-                model_fast_init = model_class_copy.from_pretrained(tmpdirname)
-                model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
-
-                for key in model_fast_init.state_dict().keys():
-                    if key == "layoutlmv2.visual_segment_embedding":
-                        # we skip the visual segment embedding as it has a custom initialization scheme
-                        continue
-                    max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
-                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
...
tests/models/prophetnet/test_modeling_prophetnet.py
@@ -436,10 +436,10 @@ class ProphetNetModelTester:
             decoder_attention_mask=decoder_attention_mask,
             labels=lm_labels,
         )
-        self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(4.5819, device=torch_device), atol=1e-3))
+        self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(4.5981, device=torch_device), atol=1e-3))
         expected_logit_slice = torch.tensor(
-            [-0.1565, 0.0418, 0.1207, 0.0030, 0.0665, 0.0467, 0.0412], device=torch_device
+            [-0.0648, 0.0790, 0.0360, 0.0089, 0.0039, -0.0639, 0.0131], device=torch_device
         )
         self.parent.assertTrue(torch.allclose(result.logits[0, :, 1], expected_logit_slice, atol=1e-3))
...
tests/models/reformer/test_modeling_reformer.py
@@ -1145,10 +1145,11 @@ class ReformerIntegrationTests(unittest.TestCase):
         hidden_states = model(input_ids=input_ids, attention_mask=attn_mask)[0]
         output_slice = hidden_states[1, -1, :5]
         expected_output_slice = torch.tensor(
-            [0.0256, -0.0121, 0.0636, 0.0024, -0.0393],
+            [0.1018, -0.2026, 0.2116, 0.0270, -0.1233],
             dtype=torch.float,
             device=torch_device,
         )
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

     def test_local_lm_model_grad(self):
@@ -1163,25 +1164,25 @@ class ReformerIntegrationTests(unittest.TestCase):
         input_ids, _ = self._get_input_ids_and_mask()
         loss = model(input_ids=input_ids, labels=input_ids)[0]
-        self.assertTrue(torch.allclose(loss, torch.tensor(5.7786, dtype=torch.float, device=torch_device), atol=1e-3))
+        self.assertTrue(torch.allclose(loss, torch.tensor(5.8019, dtype=torch.float, device=torch_device), atol=1e-3))
         loss.backward()
         # check last grads to cover all proable errors
         grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
         expected_grad_slice_word = torch.tensor(
-            [-0.0005, 0.0001, 0.0002, 0.0003, 0.0006],
+            [-0.0005, -0.0001, -0.0002, -0.0006, -0.0006],
             dtype=torch.float,
             device=torch_device,
         )
         grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
         expected_grad_slice_pos_fac_1 = torch.tensor(
-            [0.0037, -1.3793, -1.0231, -1.5230, -2.5306],
+            [-0.5235, 0.5704, 0.0922, -0.3140, 0.9928],
             dtype=torch.float,
             device=torch_device,
         )
         grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
         expected_grad_slice_pos_fac_2 = torch.tensor(
-            [-1.3165, 0.5168, 0.7785, 1.0811, -0.9830],
+            [1.7960, 1.7668, 0.5593, 0.0907, 1.8342],
             dtype=torch.float,
             device=torch_device,
         )
@@ -1203,24 +1204,24 @@ class ReformerIntegrationTests(unittest.TestCase):
         input_ids, _ = self._get_input_ids_and_mask()
         loss = model(input_ids=input_ids, labels=input_ids)[0]
-        self.assertTrue(torch.allclose(loss, torch.tensor(5.7819, dtype=torch.float, device=torch_device), atol=1e-3))
+        self.assertTrue(torch.allclose(loss, torch.tensor(5.7854, dtype=torch.float, device=torch_device), atol=1e-3))
         loss.backward()
         # check last grads to cover all proable errors
         grad_slice_word = model.reformer.embeddings.word_embeddings.weight.grad[0, :5]
         expected_grad_slice_word = torch.tensor(
-            [2.6357e-05, 4.3358e-04, -8.4985e-04, 1.0094e-04, 3.8954e-04],
+            [0.0004, 0.0003, 0.0006, -0.0004, 0.0002],
             dtype=torch.float,
             device=torch_device,
         )
         grad_slice_position_factor_1 = model.reformer.embeddings.position_embeddings.weights[0][1, 0, -5:]
         expected_grad_slice_pos_fac_1 = torch.tensor(
-            [-0.0984, 0.6283, 0.4282, 1.2960, 0.6897],
+            [-0.3792, 0.5593, -1.6993, 0.2033, 0.4131],
             dtype=torch.float,
             device=torch_device,
         )
         grad_slice_position_factor_2 = model.reformer.embeddings.position_embeddings.weights[1][0, 1, :5]
         expected_grad_slice_pos_fac_2 = torch.tensor(
-            [0.4626, -0.0231, -0.0172, 0.1081, 0.3805],
+            [-1.4212, -0.3201, -1.1944, 0.1258, 0.2856],
             dtype=torch.float,
             device=torch_device,
         )
...
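
The ProphetNet and Reformer hunks above only refresh pinned constants; the assertion pattern itself is unchanged. For reference, a tiny self-contained example of that pattern (the values here are made up, not taken from the tests):

    import torch

    # Re-recorded constants are compared with torch.allclose and an absolute
    # tolerance, so small hardware/kernel differences don't flip the test.
    computed = torch.tensor([4.59812, -0.06478])  # stand-in for a computed slice
    expected = torch.tensor([4.5981, -0.0648])    # the pinned constant
    assert torch.allclose(computed, expected, atol=1e-3)
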
tests/models/vit_hybrid/test_modeling_vit_hybrid.py
@@ -23,7 +23,7 @@ from transformers.testing_utils import require_accelerate, require_torch, requir
 from transformers.utils import cached_property, is_torch_available, is_vision_available

 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor

 if is_torch_available():
@@ -198,6 +198,28 @@ class ViTHybridModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

+    def test_initialization(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            # Skip the check for the backbone
+            for name, module in model.named_modules():
+                if module.__class__.__name__ == "ViTHybridPatchEmbeddings":
+                    backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
+                    break
+
+            for name, param in model.named_parameters():
+                if param.requires_grad:
+                    if name in backbone_params:
+                        continue
+                    self.assertIn(
+                        ((param.data.mean() * 1e9).round() / 1e9).item(),
+                        [0.0, 1.0],
+                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                    )
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in VIT_HYBRID_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
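
The new ViTHybrid test_initialization above relies on a trick used throughout these tests: _config_zero_init shrinks every initializer std to 1e-10, after which any correctly initialized parameter has a mean that rounds to 0.0 (near-zero random weights) or 1.0 (ones, e.g. LayerNorm scales). A minimal illustration of the rounding check (plain torch.nn, not the actual test harness):

    import torch.nn as nn

    # With a "zeroed" init std, a correctly initialized Linear weight has mean
    # ~0, and a LayerNorm weight is exactly 1 - the two values the test accepts.
    layer = nn.Linear(8, 8)
    nn.init.normal_(layer.weight, mean=0.0, std=1e-10)
    norm = nn.LayerNorm(8)

    for name, param in [("linear.weight", layer.weight), ("layernorm.weight", norm.weight)]:
        mean = ((param.data.mean() * 1e9).round() / 1e9).item()
        assert mean in [0.0, 1.0], name
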
tests/test_modeling_common.py
@@ -69,7 +69,6 @@ from transformers.testing_utils import (
     USER,
     CaptureLogger,
     TestCasePlus,
-    is_flaky,
     is_pt_flax_cross_test,
     is_pt_tf_cross_test,
     is_staging_test,
@@ -175,6 +174,9 @@ def _config_zero_init(config):
     for key in configs_no_init.__dict__.keys():
         if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
             setattr(configs_no_init, key, 1e-10)
+        if isinstance(getattr(configs_no_init, key, None), PretrainedConfig):
+            no_init_subconfig = _config_zero_init(getattr(configs_no_init, key))
+            setattr(configs_no_init, key, no_init_subconfig)
     return configs_no_init
@@ -182,6 +184,31 @@ TINY_T5 = "patrickvonplaten/t5-tiny-random"
 TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"


+def _mock_init_weights(self, module):
+    for name, param in module.named_parameters(recurse=False):
+        # Use the first letter of the name to get a value and go from a <> -13 to z <> 12
+        value = ord(name[0].lower()) - 110
+        param.data.fill_(value)
+
+
+def _mock_all_init_weights(self):
+    # Prune heads if needed
+    if self.config.pruned_heads:
+        self.prune_heads(self.config.pruned_heads)
+
+    import transformers.modeling_utils
+
+    if transformers.modeling_utils._init_weights:
+        for module in self.modules():
+            module._is_hf_initialized = False
+
+        # Initialize weights
+        self.apply(self._initialize_weights)
+
+        # Tie weights should be skipped when not initializing all weights
+        # since from_pretrained(...) calls tie weights anyways
+        self.tie_weights()
+
+
 @require_torch
 class ModelTesterMixin:
     model_tester = None
@@ -357,15 +384,10 @@ class ModelTesterMixin:
         model.gradient_checkpointing_disable()
         self.assertFalse(model.is_gradient_checkpointing)

-    def _mock_init_weights(self, module):
-        if hasattr(module, "weight") and module.weight is not None:
-            module.weight.data.fill_(3)
-        if hasattr(module, "bias") and module.bias is not None:
-            module.bias.data.fill_(3)
-
-    @is_flaky()
     def test_save_load_fast_init_from_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if config.__class__ not in MODEL_MAPPING:
+            return
         base_class = MODEL_MAPPING[config.__class__]

         if isinstance(base_class, tuple):
@@ -387,7 +409,8 @@ class ModelTesterMixin:
             # make init deterministic, but make sure that
             # non-initialized weights throw errors nevertheless
-            model_class_copy._init_weights = self._mock_init_weights
+            model_class_copy._init_weights = _mock_init_weights
+            model_class_copy.init_weights = _mock_all_init_weights

             model = base_class(config)
             state_dict = model.state_dict()
@@ -404,13 +427,16 @@ class ModelTesterMixin:
                 model_fast_init = model_class_copy.from_pretrained(tmpdirname)
                 model_slow_init = model_class_copy.from_pretrained(tmpdirname, _fast_init=False)
+                # Before we test anything
                 for key in model_fast_init.state_dict().keys():
                     max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
-                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+                    self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")

     def test_save_load_fast_init_to_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        if config.__class__ not in MODEL_MAPPING:
+            return
         base_class = MODEL_MAPPING[config.__class__]

         if isinstance(base_class, tuple):
@@ -432,7 +458,8 @@ class ModelTesterMixin:
             # make init deterministic, but make sure that
             # non-initialized weights throw errors nevertheless
-            base_class_copy._init_weights = self._mock_init_weights
+            base_class_copy._init_weights = _mock_init_weights
+            base_class_copy.init_weights = _mock_all_init_weights

             model = model_class(config)
             state_dict = model.state_dict()
@@ -454,7 +481,7 @@ class ModelTesterMixin:
                     max_diff = torch.max(
                         torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
                     ).item()
-                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
+                    self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")

     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
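
One change in tests/test_modeling_common.py worth calling out: _config_zero_init now recurses into nested sub-configs, so composite models (vision + text, encoder + decoder) also get near-zero initializer ranges in test_initialization. A rough sketch of that idea with stand-in classes (FakeConfig is illustrative, not the real PretrainedConfig API):

    import copy

    class FakeConfig:
        """Stand-in for PretrainedConfig (illustration only)."""
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

    def config_zero_init(config):
        configs_no_init = copy.deepcopy(config)
        for key in list(configs_no_init.__dict__.keys()):
            if "_range" in key or "_std" in key:
                setattr(configs_no_init, key, 1e-10)
            if isinstance(getattr(configs_no_init, key, None), FakeConfig):
                # The new behavior: zero out nested sub-configs too.
                setattr(configs_no_init, key, config_zero_init(getattr(configs_no_init, key)))
        return configs_no_init

    cfg = FakeConfig(initializer_range=0.02, text_config=FakeConfig(initializer_range=0.02))
    zeroed = config_zero_init(cfg)
    assert zeroed.initializer_range == 1e-10
    assert zeroed.text_config.initializer_range == 1e-10  # sub-config zeroed as well
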