Unverified Commit 04b2f13c authored by Sylvain Gugger, committed by GitHub

🚨🚨🚨 Enforce single model initialization (#21431)

* Enforce single model initialization

* Add OneFormer example for problem 3

* Do it the Stas way

* Actually rename the uses...

* Rewrite test

* Try to change the test this way

* Fix all init slow/fast tests

* Break connection

* Fix more tests

* Fix test for initialization

* Remove custom test

* Quality

* Fix last failing tests

* The end?
parent 2020ac4b
......@@ -492,6 +492,48 @@ model = BrandNewBertModel(BrandNewBertConfig())
The above command will create a model according to the default parameters as defined in `BrandNewBertConfig()` with
random weights, thus making sure that the `init()` methods of all components work.
Note that all random initialization should happen in the `_init_weights` method of your `BrandNewBertPreTrainedModel`
class. It should initialize all leaf modules depending on the variables of the config. Here is an example with the
BERT `_init_weights` method:
```py
def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
```
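
Note that `_init_weights` is not usually called directly: every model's `__init__` ends with a `self.post_init()` call which runs it over all submodules, and `from_pretrained` applies it to the parts of the model the checkpoint does not cover.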
You can define additional custom schemes if some modules need a special initialization. For instance, in
`Wav2Vec2ForPreTraining`, the last two linear layers need to keep the initialization of the regular PyTorch `nn.Linear`,
while all the others should use the initialization above. This is coded like this:
```py
def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, Wav2Vec2ForPreTraining):
        module.project_hid.reset_parameters()
        module.project_q.reset_parameters()
        module.project_hid._is_hf_initialized = True
        module.project_q._is_hf_initialized = True
    elif isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
```
The `_is_hf_initialized` flag is used internally to make sure we only initialize a submodule once. By setting it to
`True` for `module.project_q` and `module.project_hid`, we make sure the custom initialization we did is not overridden
later on: the `_init_weights` function won't be applied to them again.
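
Concretely, the loading machinery walks the model with a small wrapper that checks the flag before delegating to `_init_weights`. Here is a simplified sketch of that wrapper (it mirrors the `_initialize_weights` method this commit adds to `PreTrainedModel` in `modeling_utils.py`):

```py
def _initialize_weights(self, module):
    """Initialize the weights of a module if they are not already initialized."""
    if getattr(module, "_is_hf_initialized", False):
        # this submodule was already taken care of (custom init or loaded weights)
        return
    self._init_weights(module)
    module._is_hf_initialized = True
```

Since `model.apply` visits every submodule exactly once, each submodule is therefore initialized at most once.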
**6. Write a conversion script**
Next, you should write a conversion script that lets you convert the checkpoint you used to debug *brand_new_bert* in
......
......@@ -436,6 +436,17 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
)
def set_initialized_submodules(model, state_dict_keys):
    """
    Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
    dict.
    """
    for module_name, module in model.named_modules():
        loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")]
        if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
            module._is_hf_initialized = True
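
For illustration, here is a minimal sketch of what this marking achieves (module names are hypothetical):

```py
# Hypothetical example: the checkpoint covers the encoder but not the classifier.
state_dict_keys = ["encoder.layer.weight", "encoder.layer.bias"]
set_initialized_submodules(model, state_dict_keys)
# `model.encoder.layer` (and `model.encoder`) get `_is_hf_initialized = True`:
# all of their weights were loaded. `model.classifier` stays unflagged and will
# be randomly initialized exactly once later, by model.apply(model._initialize_weights).
```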
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
......@@ -1176,7 +1187,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
Initialize the weights. This method should be overridden by derived class.
"""
raise NotImplementedError(f"Make sure `_init_weights` is implemented for {self.__class__}")
pass
    def _initialize_weights(self, module):
        """
        Initialize the weights if they are not already initialized.
        """
        if getattr(module, "_is_hf_initialized", False):
            return
        self._init_weights(module)
        module._is_hf_initialized = True
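
As a hypothetical sanity check of the resulting behavior, a second initialization pass is a no-op because every submodule was flagged during the first one:

```py
# Hypothetical check -- the model and submodule path are made up for illustration.
model = BrandNewBertModel(BrandNewBertConfig())  # __init__ ends with post_init()
weight_before = model.encoder.layers[0].dense.weight.clone()
model.apply(model._initialize_weights)  # second pass: every flag is already True
assert torch.equal(weight_before, model.encoder.layers[0].dense.weight)
```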
    def tie_weights(self):
        """
......@@ -1505,7 +1525,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    def init_weights(self):
        """
-        If needed prunes and maybe initializes weights.
+        If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
+        initialization logic in `_init_weights`.
        """
        # Prune heads if needed
        if self.config.pruned_heads:
......@@ -1513,7 +1534,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        if _init_weights:
            # Initialize weights
-            self.apply(self._init_weights)
+            self.apply(self._initialize_weights)

            # Tie weights should be skipped when not initializing all weights
            # since from_pretrained(...) calls tie weights anyways
......@@ -2713,11 +2734,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
        if _fast_init:
-            uninitialized_modules = model.retrieve_modules_from_names(
-                missing_keys, add_prefix=add_prefix_to_model, remove_prefix=remove_prefix_from_model
-            )
-            for module in uninitialized_modules:
-                model._init_weights(module)
+            if remove_prefix_from_model:
+                _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
+            elif add_prefix_to_model:
+                _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
+            else:
+                _loaded_keys = loaded_keys
+            set_initialized_submodules(model, _loaded_keys)
+            # This will only initialize submodules that are not marked as initialized by the line above.
+            model.apply(model._initialize_weights)
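
In plain terms, this replaces "look up the modules that match the missing keys and initialize them by name" with a two-step scheme: first flag every submodule fully covered by the checkpoint, then run the regular recursive pass, which skips flagged submodules. The intended net effect, sketched with a hypothetical checkpoint:

```py
# Loading a base checkpoint into a model with a new head (names hypothetical):
model = BrandNewBertForSequenceClassification.from_pretrained("author/brand-new-bert-base")
# - encoder submodules: fully covered by the checkpoint -> flagged, never re-initialized
# - classification head: absent from the checkpoint -> randomly initialized exactly once
```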
        # Set some modules to fp32 if any
        if keep_in_fp32_modules is not None:
......
......@@ -1067,10 +1067,12 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
                module.text_projection.weight,
                std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
            )
            module.text_projection._is_hf_initialized = True
            nn.init.normal_(
                module.visual_projection.weight,
                std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
            )
            module.visual_projection._is_hf_initialized = True
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
......
......@@ -1473,8 +1473,9 @@ class BartForSequenceClassification(BartPretrainedModel):
            config.num_labels,
            config.classifier_dropout,
        )
-        self.model._init_weights(self.classification_head.dense)
-        self.model._init_weights(self.classification_head.out_proj)
+        # Initialize weights and apply final processing
+        self.post_init()
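
This works because `BartPretrainedModel._init_weights` (not shown in this hunk) already covers plain `nn.Linear` modules, so the single `post_init()` call now reaches the head's `dense` and `out_proj` layers too. Roughly, the relevant branch looks like this:

```py
# Rough sketch of the relevant branch of BartPretrainedModel._init_weights:
def _init_weights(self, module):
    std = self.config.init_std
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()
```

The same reasoning applies to the BigBirdPegasus, LED, MBart, MVP and PLBart changes below.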
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......@@ -1601,7 +1602,8 @@ class BartForQuestionAnswering(BartPretrainedModel):
        self.model = BartModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

-        self.model._init_weights(self.qa_outputs)
+        # Initialize weights and apply final processing
+        self.post_init()
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......
......@@ -2658,8 +2658,9 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
            config.num_labels,
            config.classifier_dropout,
        )
-        self.model._init_weights(self.classification_head.dense)
-        self.model._init_weights(self.classification_head.out_proj)
+        # Initialize weights and apply final processing
+        self.post_init()
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......@@ -2785,7 +2786,8 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
        self.model = BigBirdPegasusModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

-        self.model._init_weights(self.qa_outputs)
+        # Initialize weights and apply final processing
+        self.post_init()
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......
......@@ -1186,6 +1186,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
        base_model = FSMTModel(config)
        self.model = base_model

        # Initialize weights and apply final processing
        self.post_init()
@add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(FSMT_GENERATION_EXAMPLE)
......
......@@ -2543,8 +2543,9 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
            config.num_labels,
            config.classifier_dropout,
        )
-        self.led._init_weights(self.classification_head.dense)
-        self.led._init_weights(self.classification_head.out_proj)
+        # Initialize weights and apply final processing
+        self.post_init()
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......@@ -2672,7 +2673,8 @@ class LEDForQuestionAnswering(LEDPreTrainedModel):
        self.led = LEDModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

-        self.led._init_weights(self.qa_outputs)
+        # Initialize weights and apply final processing
+        self.post_init()
@add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......
......@@ -866,6 +866,9 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
        self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels])

        # Initialize weights and apply final processing
        self.post_init()

    @property
    def channels(self):
        return [self.out_feature_channels[name] for name in self.out_features]
......
......@@ -1447,8 +1447,9 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
            config.num_labels,
            config.classifier_dropout,
        )
-        self.model._init_weights(self.classification_head.dense)
-        self.model._init_weights(self.classification_head.out_proj)
+        # Initialize weights and apply final processing
+        self.post_init()
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......@@ -1574,7 +1575,8 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
        self.model = MBartModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

-        self.model._init_weights(self.qa_outputs)
+        # Initialize weights and apply final processing
+        self.post_init()
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......
......@@ -1610,8 +1610,8 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
            config.classifier_dropout,
        )
-        self.model._init_weights(self.classification_head.dense)
-        self.model._init_weights(self.classification_head.out_proj)
+        # Initialize weights and apply final processing
+        self.post_init()

    def set_lightweight_tuning(self):
        self.model.set_lightweight_tuning()
......@@ -1737,7 +1737,8 @@ class MvpForQuestionAnswering(MvpPreTrainedModel):
        self.model = MvpModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

-        self.model._init_weights(self.qa_outputs)
+        # Initialize weights and apply final processing
+        self.post_init()

    def set_lightweight_tuning(self):
        self.model.set_lightweight_tuning()
......
......@@ -2801,6 +2801,7 @@ class OneFormerPreTrainedModel(PreTrainedModel):
        elif isinstance(module, OneFormerTransformerDecoder):
            nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std)
            nn.init.constant_(module.query_input_projection.bias, 0)
            module.query_input_projection._is_hf_initialized = True
        elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention):
            nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
            thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads)
......
......@@ -1420,8 +1420,9 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
            config.num_labels,
            config.classifier_dropout,
        )
-        self.model._init_weights(self.classification_head.dense)
-        self.model._init_weights(self.classification_head.out_proj)
+        # Initialize weights and apply final processing
+        self.post_init()
@add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......
......@@ -301,6 +301,12 @@ class UperNetPreTrainedModel(PreTrainedModel):
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        if isinstance(module, UperNetPreTrainedModel):
            module.backbone.init_weights()
            module.decode_head.init_weights()
            module.auxiliary_head.init_weights()
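
A subtlety worth noting: `nn.Module.apply(fn)` calls `fn` on every child and then on the module itself, so the `isinstance(module, UperNetPreTrainedModel)` branch fires exactly once, for the root model, and delegates to the sub-networks' own `init_weights` methods. A self-contained illustration:

```py
# Minimal demonstration that apply() visits children first, then the module itself.
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(2, 2)

visited = []
Tiny().apply(lambda m: visited.append(type(m).__name__))
print(visited)  # ["Linear", "Tiny"]
```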
    def init_weights(self):
        """Initialize the weights"""
        self.backbone.init_weights()
......
......@@ -1049,8 +1049,14 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
    def _init_weights(self, module):
        """Initialize the weights"""
+        # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
+        if isinstance(module, Wav2Vec2ForPreTraining):
+            module.project_hid.reset_parameters()
+            module.project_q.reset_parameters()
+            module.project_hid._is_hf_initialized = True
+            module.project_q._is_hf_initialized = True
        # gumbel softmax requires special init
-        if isinstance(module, Wav2Vec2GumbelVectorQuantizer):
+        elif isinstance(module, Wav2Vec2GumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
......@@ -1345,13 +1351,12 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
        self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)

-        # Initialize weights and apply final processing
-        self.post_init()
-        # make sure that project_hid & project_q are initialized like normal linear layers
        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)

+        # Initialize weights and apply final processing
+        self.post_init()
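
The old ordering relied on creating `project_hid` and `project_q` *after* `post_init()` so that they silently kept PyTorch's default `nn.Linear` initialization: anything constructed after `post_init()` is never seen by the init pass. The new code keeps `post_init()` last and encodes the exception explicitly in `_init_weights` via `reset_parameters()` and the `_is_hf_initialized` flag. A sketch of the fragile pattern being removed (class and layer names hypothetical):

```py
# Hypothetical anti-pattern the reordering eliminates:
class FragileModel(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2Model(config)
        self.post_init()                    # initializes everything created so far
        self.late_head = nn.Linear(32, 2)   # never visited -> default init only
```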
    def set_gumbel_temperature(self, temperature: int):
        """
        Set the Gumbel softmax temperature to a given value. Only necessary for training
......
......@@ -1089,8 +1089,14 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
    def _init_weights(self, module):
        """Initialize the weights"""
+        # Wav2Vec2ConformerForPreTraining last 2 linear layers need standard Linear init.
+        if isinstance(module, Wav2Vec2ConformerForPreTraining):
+            module.project_hid.reset_parameters()
+            module.project_q.reset_parameters()
+            module.project_hid._is_hf_initialized = True
+            module.project_q._is_hf_initialized = True
        # gumbel softmax requires special init
-        if isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
+        elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
......@@ -1381,13 +1387,12 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
        self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)

-        # Initialize weights and apply final processing
-        self.post_init()
-        # make sure that project_hid & project_q are initialized like normal linear layers
        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)

+        # Initialize weights and apply final processing
+        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
    def set_gumbel_temperature(self, temperature: int):
        """
......
......@@ -962,7 +962,6 @@ class WavLMAdapterLayer(nn.Module):
        return hidden_states


-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->WavLM, wav2vec2->wavlm
class WavLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
......
......@@ -1496,3 +1496,6 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
    def test_retain_grad_hidden_states_attentions(self):
        # decoder cannot keep gradients
        return

    def test_save_load_fast_init_from_base(self):
        pass
......@@ -410,8 +410,13 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            # Skip the check for the backbone
            backbone_params = []
            for name, module in model.named_modules():
                if module.__class__.__name__ == "DetaBackboneWithPositionalEncodings":
                    backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
                    break

            for name, param in model.named_parameters():
                if param.requires_grad:
                    if (
                        "level_embed" in name
......@@ -419,6 +424,7 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
or "value_proj" in name
or "output_proj" in name
or "reference_points" in name
or name in backbone_params
):
continue
self.assertIn(
......
......@@ -24,7 +24,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
......@@ -242,6 +242,29 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
            loss = model(**inputs).loss
            loss.backward()

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            # Skip the check for the backbone
            backbone_params = []
            for name, module in model.named_modules():
                if module.__class__.__name__ == "DPTViTHybridEmbeddings":
                    backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
                    break

            for name, param in model.named_parameters():
                if param.requires_grad:
                    if name in backbone_params:
                        continue
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )
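
For context, `_config_zero_init` returns a copy of the config with the initializer hyperparameters collapsed to (near) zero, so every parameter that went through `_init_weights` ends up with a mean of 0.0 (or 1.0 for norm scales); backbone parameters initialized by other means are collected and skipped above. A hedged sketch of the helper's idea, not its exact implementation:

```py
import copy

def _config_zero_init(config):
    config = copy.deepcopy(config)
    for key in config.__dict__:
        if "initializer_range" in key or "init_std" in key:
            setattr(config, key, 1e-10)  # collapse random init toward zero
    return config
```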
    @slow
    def test_model_from_pretrained(self):
        for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
......@@ -24,7 +24,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_torch_available():
......@@ -256,6 +256,29 @@ class DPTModelTest(ModelTesterMixin, unittest.TestCase):
            loss = model(**inputs).loss
            loss.backward()

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            # Skip the check for the backbone
            backbone_params = []
            for name, module in model.named_modules():
                if module.__class__.__name__ == "DPTViTHybridEmbeddings":
                    backbone_params = [f"{name}.{key}" for key in module.state_dict().keys()]
                    break

            for name, param in model.named_parameters():
                if param.requires_grad:
                    if name in backbone_params:
                        continue
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    @slow
    def test_model_from_pretrained(self):
        for model_name in DPT_PRETRAINED_MODEL_ARCHIVE_LIST[1:]:
......