Unverified Commit d83b0e0c authored by Sylvain Gugger, committed by GitHub

Add a post init method to all models (#14431)

* Add a post init method to all models

* Fix tests

* Fix last tests

* Fix templates

* Add comment

* Forgot to save
parent 08816de1
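For readers skimming the diff below: every model `__init__` that used to end with a bare `self.init_weights()` call now ends with `self.post_init()`, a hook defined once on the base `PreTrainedModel` that initializes the weights and then applies any final processing (such as honoring a gradient-checkpointing request from the config). A minimal, self-contained sketch of the idea follows; the class and config names here are illustrative stand-ins, not the exact library implementation.

```python
from types import SimpleNamespace

from torch import nn


class TinyPreTrainedModel(nn.Module):
    """Toy stand-in for transformers.PreTrainedModel, for illustration only."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = False

    def init_weights(self):
        # Simplified init; the real library dispatches to each model's _init_weights.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def post_init(self):
        # Single hook run at the end of every subclass __init__: initialize
        # weights, then apply final processing (here, an illustrative
        # config-driven gradient-checkpointing switch).
        self.init_weights()
        if getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing = True


class TinyModel(TinyPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()


config = SimpleNamespace(hidden_size=16, num_labels=2, gradient_checkpointing=True)
model = TinyModel(config)
print(model.gradient_checkpointing)  # True: final processing ran inside __init__
```

The per-model change in the diff is exactly the last two lines of `TinyModel.__init__` above: the explanatory comment plus a `post_init()` call, replacing the old direct `init_weights()` call.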
@@ -503,8 +503,9 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
         self.layers = nn.ModuleList([TrOCRDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -784,7 +785,8 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
         self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.model.decoder.embed_tokens
......
@@ -1045,7 +1045,8 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
         else:
             self.encoder = UniSpeechEncoder(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
     def _mask_hidden_states(
@@ -1165,7 +1166,8 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
         self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
         self.dropout = nn.Dropout(config.final_dropout)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def set_gumbel_temperature(self, temperature: int):
         """
@@ -1337,7 +1339,8 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
             )
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def freeze_feature_extractor(self):
         """
@@ -1445,7 +1448,8 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def freeze_feature_extractor(self):
         """
......
@@ -1046,7 +1046,8 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
         else:
             self.encoder = UniSpeechSatEncoder(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
     def _mask_hidden_states(
@@ -1171,7 +1172,8 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
         if self.config.do_stable_layer_norm:
             self.layer_norm_for_extract.requires_grad = False
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def set_gumbel_temperature(self, temperature: int):
         """
@@ -1328,7 +1330,8 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
             )
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def freeze_feature_extractor(self):
         """
@@ -1436,7 +1439,8 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def freeze_feature_extractor(self):
         """
......
@@ -701,7 +701,8 @@ class VisualBertModel(VisualBertPreTrainedModel):
         if self.bypass_transformer:
             self.additional_layer = VisualBertLayer(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -877,7 +878,8 @@ class VisualBertForPreTraining(VisualBertPreTrainedModel):
         self.visual_bert = VisualBertModel(config)
         self.cls = VisualBertPreTrainingHeads(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1021,7 +1023,8 @@ class VisualBertForMultipleChoice(VisualBertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.cls = nn.Linear(config.hidden_size, 1)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(
         VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
@@ -1170,7 +1173,8 @@ class VisualBertForQuestionAnswering(VisualBertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.cls = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@@ -1292,7 +1296,8 @@ class VisualBertForVisualReasoning(VisualBertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.cls = nn.Linear(config.hidden_size, config.num_labels)  # 2
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@@ -1448,7 +1453,8 @@ class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
         self.cls = VisualBertPreTrainingHeads(config)
         self.attention = VisualBertRegionToPhraseAttention(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
......
@@ -487,7 +487,8 @@ class ViTModel(ViTPreTrainedModel):
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.pooler = ViTPooler(config) if add_pooling_layer else None
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.patch_embeddings
@@ -603,7 +604,8 @@ class ViTForImageClassification(ViTPreTrainedModel):
         # Classifier head
         self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
......
@@ -1152,7 +1152,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
         self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def _mask_hidden_states(
         self,
@@ -1269,7 +1270,8 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
         self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
         # make sure that project_hid & project_q are initialized like normal linear layers
         self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
@@ -1480,7 +1482,8 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
         self.dropout = nn.Dropout(config.final_dropout)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Wav2Vec2BaseModelOutput, config_class=_CONFIG_FOR_DOC)
@@ -1563,7 +1566,8 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
             )
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def freeze_feature_extractor(self):
         """
@@ -1670,7 +1674,8 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def freeze_feature_extractor(self):
         """
......
@@ -469,7 +469,8 @@ class XLMModel(XLMPreTrainedModel):
                 if self.attentions[int(layer)].n_heads == config.n_heads:
                     self.prune_heads({int(layer): list(map(int, heads))})
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
     def get_input_embeddings(self):
@@ -687,7 +688,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.pred_layer = XLMPredLayer(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.pred_layer.proj
@@ -785,7 +787,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.sequence_summary = SequenceSummary(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -885,7 +888,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -989,7 +993,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.qa_outputs = SQuADHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=XLMForQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
@@ -1108,7 +1113,8 @@ class XLMForTokenClassification(XLMPreTrainedModel):
         self.dropout = nn.Dropout(config.dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1201,7 +1207,8 @@ class XLMForMultipleChoice(XLMPreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.num_labels, 1)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
......
@@ -955,7 +955,8 @@ class XLNetModel(XLNetPreTrainedModel):
         self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
         self.dropout = nn.Dropout(config.dropout)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.word_embedding
@@ -1311,7 +1312,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self.transformer = XLNetModel(config)
         self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.lm_loss
@@ -1493,7 +1495,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1600,7 +1603,8 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
         self.transformer = XLNetModel(config)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1697,7 +1701,8 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, 1)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
@@ -1800,7 +1805,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
         self.transformer = XLNetModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1913,7 +1919,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         self.end_logits = PoolerEndLogits(config)
         self.answer_class = PoolerAnswerClass(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=XLNetForQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
......
@@ -777,7 +777,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.embeddings = {{cookiecutter.camelcase_modelname}}Embeddings(config)
         self.encoder = {{cookiecutter.camelcase_modelname}}Encoder(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -943,7 +944,8 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
         self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1046,7 +1048,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
         self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1217,7 +1220,8 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
         self.classifier = {{cookiecutter.camelcase_modelname}}ClassificationHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1309,7 +1313,8 @@ class {{cookiecutter.camelcase_modelname}}ForMultipleChoice({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
         self.classifier = nn.Linear(config.hidden_size, 1)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
@@ -1399,7 +1404,8 @@ class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1486,7 +1492,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -2224,8 +2231,9 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
     def forward(
         self,
@@ -2388,8 +2396,9 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}DecoderLayer(config) for _ in range(config.decoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(config.d_model)
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -2640,7 +2649,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.encoder = {{cookiecutter.camelcase_modelname}}Encoder(config, self.shared)
         self.decoder = {{cookiecutter.camelcase_modelname}}Decoder(config, self.shared)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.shared
@@ -2755,7 +2765,8 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
         self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_encoder(self):
         return self.model.get_encoder()
@@ -3170,7 +3181,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.model.decoder.embed_tokens
......
@@ -222,14 +222,6 @@ class ModelTesterMixin:
             config.gradient_checkpointing = True
             model = model_class(config)
-            # Model does not have gradient checkpointing activated yet, it will be done at the first forward.
-            self.assertFalse(model.is_gradient_checkpointing)
-            model.to(torch_device)
-            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            _ = model(**inputs)
-            # Model has gradient checkpointing activated after the first forward.
-            self.assertTrue(model.is_gradient_checkpointing)
     def test_gradient_checkpointing_enable_disable(self):
......
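The test change above removes the old expectation that gradient checkpointing only becomes active after the first forward pass. With `post_init()` running at the end of `__init__`, a config-requested feature is expected to be in effect as soon as the model is constructed. A hedged usage sketch (the model class is an arbitrary example, and the deprecated `gradient_checkpointing` config argument is assumed to still be honored at construction time in this version):

```python
from transformers import BertConfig, BertForSequenceClassification

# post_init() runs at the end of __init__, so the config-requested feature
# should already be active here, without a warm-up forward pass.
config = BertConfig(gradient_checkpointing=True)
model = BertForSequenceClassification(config)
print(model.is_gradient_checkpointing)  # expected: True right after construction
```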