Unverified commit d83b0e0c, authored by Sylvain Gugger and committed by GitHub

Add a post init method to all models (#14431)

* Add a post init method to all models

* Fix tests

* Fix last tests

* Fix templates

* Add comment

* Forgot to save
parent 08816de1
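In short, every model's `__init__` now ends with `self.post_init()` instead of a bare `self.init_weights()`, so weight initialization and any remaining setup are funneled through one hook on the base class. The sketch below is illustrative only, not the actual `transformers.PreTrainedModel` implementation: the class names, the config handling, and the body of `post_init` are assumptions about the intent of the change.

```python
import torch.nn as nn


class TinyPreTrainedModel(nn.Module):
    """Hypothetical stand-in for the transformers base class (illustrative only)."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def init_weights(self):
        # Stand-in for the usual module-by-module weight initialization.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    def post_init(self):
        # Single hook run at the end of every subclass __init__:
        # initialize weights, then apply any final processing, e.g. honoring a
        # config flag at construction time (assumed behavior, for illustration).
        self.init_weights()
        if getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing = True


class TinyModel(TinyPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.dense = nn.Linear(4, 4)
        self.gradient_checkpointing = False
        # Previously: self.init_weights()
        # Initialize weights and apply final processing
        self.post_init()
```

With that pattern in mind, the repeated change below is mechanical: drop `self.init_weights()` and call `self.post_init()` as the last statement of `__init__`.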
@@ -503,8 +503,9 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
         self.layers = nn.ModuleList([TrOCRDecoderLayer(config) for _ in range(config.decoder_layers)])
 
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -784,7 +785,8 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel):
         self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.model.decoder.embed_tokens
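Note the decoder-specific variant of the pattern: `self.gradient_checkpointing = False` keeps its place, and the bare `self.init_weights()` is replaced by the comment plus `self.post_init()` at the very end of `__init__`, presumably so the hook only runs once every attribute has been set.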
...
@@ -1045,7 +1045,8 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
         else:
             self.encoder = UniSpeechEncoder(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
     def _mask_hidden_states(
@@ -1165,7 +1166,8 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
         self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
         self.dropout = nn.Dropout(config.final_dropout)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def set_gumbel_temperature(self, temperature: int):
         """
@@ -1337,7 +1339,8 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
         )
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def freeze_feature_extractor(self):
         """
@@ -1445,7 +1448,8 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def freeze_feature_extractor(self):
         """
...
@@ -1046,7 +1046,8 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
         else:
             self.encoder = UniSpeechSatEncoder(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
     def _mask_hidden_states(
@@ -1171,7 +1172,8 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
         if self.config.do_stable_layer_norm:
             self.layer_norm_for_extract.requires_grad = False
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def set_gumbel_temperature(self, temperature: int):
         """
@@ -1328,7 +1330,8 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
         )
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def freeze_feature_extractor(self):
         """
@@ -1436,7 +1439,8 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def freeze_feature_extractor(self):
         """
...
@@ -701,7 +701,8 @@ class VisualBertModel(VisualBertPreTrainedModel):
         if self.bypass_transformer:
             self.additional_layer = VisualBertLayer(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -877,7 +878,8 @@ class VisualBertForPreTraining(VisualBertPreTrainedModel):
         self.visual_bert = VisualBertModel(config)
         self.cls = VisualBertPreTrainingHeads(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1021,7 +1023,8 @@ class VisualBertForMultipleChoice(VisualBertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.cls = nn.Linear(config.hidden_size, 1)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(
         VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
@@ -1170,7 +1173,8 @@ class VisualBertForQuestionAnswering(VisualBertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.cls = nn.Linear(config.hidden_size, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@@ -1292,7 +1296,8 @@ class VisualBertForVisualReasoning(VisualBertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.cls = nn.Linear(config.hidden_size, config.num_labels)  # 2
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@@ -1448,7 +1453,8 @@ class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
         self.cls = VisualBertPreTrainingHeads(config)
         self.attention = VisualBertRegionToPhraseAttention(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
...
@@ -487,7 +487,8 @@ class ViTModel(ViTPreTrainedModel):
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.pooler = ViTPooler(config) if add_pooling_layer else None
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embeddings.patch_embeddings
@@ -603,7 +604,8 @@ class ViTForImageClassification(ViTPreTrainedModel):
         # Classifier head
         self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
...
@@ -1152,7 +1152,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
         self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def _mask_hidden_states(
         self,
@@ -1269,7 +1270,8 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
         self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
         # make sure that project_hid & project_q are initialized like normal linear layers
         self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
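One ordering detail worth noting: in `Wav2Vec2ForPreTraining`, `post_init()` is called before `project_hid` and `project_q` are created, matching the existing comment that those two layers should keep the normal `nn.Linear` initialization rather than the model-specific one.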
@@ -1480,7 +1482,8 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
         self.dropout = nn.Dropout(config.final_dropout)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Wav2Vec2BaseModelOutput, config_class=_CONFIG_FOR_DOC)
@@ -1563,7 +1566,8 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
         )
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def freeze_feature_extractor(self):
         """
@@ -1670,7 +1674,8 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def freeze_feature_extractor(self):
         """
...
@@ -469,7 +469,8 @@ class XLMModel(XLMPreTrainedModel):
                 if self.attentions[int(layer)].n_heads == config.n_heads:
                     self.prune_heads({int(layer): list(map(int, heads))})
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
 
     def get_input_embeddings(self):
@@ -687,7 +688,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.pred_layer = XLMPredLayer(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_output_embeddings(self):
         return self.pred_layer.proj
@@ -785,7 +787,8 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.sequence_summary = SequenceSummary(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -885,7 +888,8 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -989,7 +993,8 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.qa_outputs = SQuADHead(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=XLMForQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
@@ -1108,7 +1113,8 @@ class XLMForTokenClassification(XLMPreTrainedModel):
         self.dropout = nn.Dropout(config.dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1201,7 +1207,8 @@ class XLMForMultipleChoice(XLMPreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.num_labels, 1)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
...
@@ -955,7 +955,8 @@ class XLNetModel(XLNetPreTrainedModel):
         self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
         self.dropout = nn.Dropout(config.dropout)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.word_embedding
@@ -1311,7 +1312,8 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self.transformer = XLNetModel(config)
         self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_output_embeddings(self):
         return self.lm_loss
@@ -1493,7 +1495,8 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
        self.logits_proj = nn.Linear(config.d_model, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1600,7 +1603,8 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
         self.transformer = XLNetModel(config)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1697,7 +1701,8 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, 1)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
@@ -1800,7 +1805,8 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
         self.transformer = XLNetModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1913,7 +1919,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         self.end_logits = PoolerEndLogits(config)
         self.answer_class = PoolerAnswerClass(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward(XLNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=XLNetForQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
...
@@ -777,7 +777,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
         self.embeddings = {{cookiecutter.camelcase_modelname}}Embeddings(config)
         self.encoder = {{cookiecutter.camelcase_modelname}}Encoder(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -943,7 +944,8 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m
         self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
         self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1046,7 +1048,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
         self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
         self.cls = {{cookiecutter.camelcase_modelname}}OnlyMLMHead(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1217,7 +1220,8 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
         self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
         self.classifier = {{cookiecutter.camelcase_modelname}}ClassificationHead(config)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1309,7 +1313,8 @@ class {{cookiecutter.camelcase_modelname}}ForMultipleChoice({{cookiecutter.camel
         self.sequence_summary = SequenceSummary(config)
         self.classifier = nn.Linear(config.hidden_size, 1)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
@@ -1399,7 +1404,8 @@ class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter.
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1486,7 +1492,8 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
         self.{{cookiecutter.lowercase_modelname}} = {{cookiecutter.camelcase_modelname}}Model(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -2224,8 +2231,9 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
         self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
 
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def forward(
         self,
@@ -2388,8 +2396,9 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
         self.layers = nn.ModuleList([{{cookiecutter.camelcase_modelname}}DecoderLayer(config) for _ in range(config.decoder_layers)])
         self.layernorm_embedding = nn.LayerNorm(config.d_model)
 
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -2640,7 +2649,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
         self.encoder = {{cookiecutter.camelcase_modelname}}Encoder(config, self.shared)
         self.decoder = {{cookiecutter.camelcase_modelname}}Decoder(config, self.shared)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.shared
@@ -2755,7 +2765,8 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte
         self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
         self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_encoder(self):
         return self.model.get_encoder()
@@ -3170,7 +3181,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.model.decoder.embed_tokens
...
@@ -222,14 +222,6 @@ class ModelTesterMixin:
             config.gradient_checkpointing = True
             model = model_class(config)
 
-            # Model does not have gradient checkpointing activated yet, it will be done at the first forward.
-            self.assertFalse(model.is_gradient_checkpointing)
-
-            model.to(torch_device)
-
-            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            _ = model(**inputs)
-
-            # Model has gradient checkpointing activated after the first forward.
             self.assertTrue(model.is_gradient_checkpointing)
 
     def test_gradient_checkpointing_enable_disable(self):
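The deleted assertions mirror the behavioral shift: with `post_init()` running at the end of `__init__`, a config that sets `gradient_checkpointing=True` appears to take effect at construction time, so the test no longer needs a forward pass before checking the flag. A minimal sketch of the surviving expectation, reusing the hypothetical `TinyModel` from the sketch near the top (not the real test harness):

```python
from types import SimpleNamespace

config = SimpleNamespace(gradient_checkpointing=True)
model = TinyModel(config)

# Expected right after construction, with no forward pass needed.
assert model.gradient_checkpointing is True
```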
...