Unverified commit d83b0e0c, authored by Sylvain Gugger and committed by GitHub

Add a post init method to all models (#14431)

* Add a post init method to all models

* Fix tests

* Fix last tests

* Fix templates

* Add comment

* Forgot to save
parent 08816de1
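
The commit replaces the direct `self.init_weights()` call at the end of every model constructor with a new `post_init()` hook. The sketch below only illustrates the pattern; the real `post_init` lives on `PreTrainedModel` and its body is not shown in this diff, so the implementation here is an assumption (weight initialization plus a single place for any other final processing), not a verbatim copy.

```python
# Minimal sketch of the pattern this commit applies (assumed behaviour, not the
# library's actual implementation).
import torch.nn as nn


class PreTrainedModelSketch(nn.Module):
    def init_weights(self):
        # Stand-in for the library's weight-initialization logic.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

    def post_init(self):
        # New hook called at the very end of every model __init__: run weight
        # initialization, then any other final processing, in one place.
        self.init_weights()


class ToyModelSketch(PreTrainedModelSketch):
    def __init__(self, hidden_size=16, num_labels=2):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)
        # Old pattern (removed by this commit): self.init_weights()
        # Initialize weights and apply final processing
        self.post_init()
```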
@@ -799,7 +799,8 @@ class MobileBertModel(MobileBertPreTrainedModel):
         self.pooler = MobileBertPooler(config) if add_pooling_layer else None
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -907,7 +908,8 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
         self.mobilebert = MobileBertModel(config)
         self.cls = MobileBertPreTrainingHeads(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1015,7 +1017,8 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
         self.cls = MobileBertOnlyMLMHead(config)
         self.config = config
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1111,7 +1114,8 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
         self.mobilebert = MobileBertModel(config)
         self.cls = MobileBertOnlyNSPHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
@@ -1218,7 +1222,8 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1318,7 +1323,8 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
         self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1421,7 +1427,8 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, 1)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(
         MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
@@ -1522,7 +1529,8 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
......
@@ -493,7 +493,8 @@ class MPNetModel(MPNetPreTrainedModel):
         self.encoder = MPNetEncoder(config)
         self.pooler = MPNetPooler(config) if add_pooling_layer else None
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -583,7 +584,8 @@ class MPNetForMaskedLM(MPNetPreTrainedModel):
         self.mpnet = MPNetModel(config, add_pooling_layer=False)
         self.lm_head = MPNetLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -691,7 +693,8 @@ class MPNetForSequenceClassification(MPNetPreTrainedModel):
         self.mpnet = MPNetModel(config, add_pooling_layer=False)
         self.classifier = MPNetClassificationHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -785,7 +788,8 @@ class MPNetForMultipleChoice(MPNetPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
@@ -877,7 +881,8 @@ class MPNetForTokenClassification(MPNetPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -985,7 +990,8 @@ class MPNetForQuestionAnswering(MPNetPreTrainedModel):
         self.mpnet = MPNetModel(config, add_pooling_layer=False)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
......
@@ -414,7 +414,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)])
         self.register_buffer("position_ids", torch.arange(config.n_positions))
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.tokens_embed
@@ -540,7 +541,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.transformer = OpenAIGPTModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.lm_head
@@ -629,7 +631,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         self.multiple_choice_head = SequenceSummary(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.lm_head
@@ -750,7 +753,8 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
         self.transformer = OpenAIGPTModel(config)
         self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
......
@@ -658,8 +658,9 @@ class PegasusEncoder(PegasusPreTrainedModel):
         self.layers = nn.ModuleList([PegasusEncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layer_norm = nn.LayerNorm(config.d_model)
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
     def resize_position_embeddings(self, new_num_position_embeddings: int):
         """
@@ -853,8 +854,9 @@ class PegasusDecoder(PegasusPreTrainedModel):
         self.layers = nn.ModuleList([PegasusDecoderLayer(config) for _ in range(config.decoder_layers)])
         self.layer_norm = nn.LayerNorm(config.d_model)
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -1142,7 +1144,8 @@ class PegasusModel(PegasusPreTrainedModel):
         self.encoder = PegasusEncoder(config, self.shared)
         self.decoder = PegasusDecoder(config, self.shared)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.shared
@@ -1293,7 +1296,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
         self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
         self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_encoder(self):
         return self.model.get_encoder()
@@ -1490,7 +1494,8 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.model.decoder.embed_tokens
......
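
In encoders and decoders that support gradient checkpointing (Pegasus above, and ProphetNet, Speech2Text and Speech2Text2 below), the diff also moves the initialization call after the `self.gradient_checkpointing = False` assignment, so `post_init()` becomes the last statement of `__init__` once every attribute is set. A toy constructor tail under that assumption (class and `post_init` body are illustrative stand-ins, not library code):

```python
# Toy constructor tail mirroring the new ordering for checkpointing-capable modules.
import torch.nn as nn


class EncoderSketch(nn.Module):
    def __init__(self, d_model=16, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))
        self.layer_norm = nn.LayerNorm(d_model)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def post_init(self):
        # Runs last, after gradient_checkpointing and all submodules exist.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
```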
@@ -1266,8 +1266,9 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
         self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.word_embeddings
@@ -1411,8 +1412,9 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
         self.layers = nn.ModuleList([ProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
         self.embeddings_layer_norm = LayerNorm(config.hidden_size)
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.word_embeddings
@@ -1765,7 +1767,8 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
         decoder_config.is_encoder_decoder = False
         self.decoder = ProphetNetDecoder(decoder_config, self.word_embeddings)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.word_embeddings
@@ -1882,7 +1885,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.lm_head
@@ -2092,7 +2096,8 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.prophetnet.decoder.word_embeddings
......
@@ -1974,7 +1974,8 @@ class ReformerModel(ReformerPreTrainedModel):
         self.embeddings = ReformerEmbeddings(config)
         self.encoder = ReformerEncoder(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -2188,7 +2189,8 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
         self.reformer = ReformerModel(config)
         self.lm_head = ReformerOnlyLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -2303,7 +2305,8 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
         self.reformer = ReformerModel(config)
         self.lm_head = ReformerOnlyLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -2390,7 +2393,8 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
         if config.is_decoder is True:
             logger.warning("You might want to disable causal masking for sequence classification")
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -2508,7 +2512,8 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
         # 2 * config.hidden_size because we use reversible residual layers
         self.qa_outputs = nn.Linear(2 * config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
......
@@ -765,7 +765,8 @@ class RemBertModel(RemBertPreTrainedModel):
         self.pooler = RemBertPooler(config) if add_pooling_layer else None
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -925,7 +926,8 @@ class RemBertForMaskedLM(RemBertPreTrainedModel):
         self.rembert = RemBertModel(config, add_pooling_layer=False)
         self.cls = RemBertOnlyMLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1027,7 +1029,8 @@ class RemBertForCausalLM(RemBertPreTrainedModel):
         self.rembert = RemBertModel(config, add_pooling_layer=False)
         self.cls = RemBertOnlyMLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1173,7 +1176,8 @@ class RemBertForSequenceClassification(RemBertPreTrainedModel):
         self.dropout = nn.Dropout(config.classifier_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1269,7 +1273,8 @@ class RemBertForMultipleChoice(RemBertPreTrainedModel):
         self.dropout = nn.Dropout(config.classifier_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
@@ -1361,7 +1366,8 @@ class RemBertForTokenClassification(RemBertPreTrainedModel):
         self.dropout = nn.Dropout(config.classifier_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1449,7 +1455,8 @@ class RemBertForQuestionAnswering(RemBertPreTrainedModel):
         self.rembert = RemBertModel(config, add_pooling_layer=False)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(REMBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
......
@@ -99,7 +99,8 @@ class RetriBertModel(RetriBertPreTrainedModel):
         self.ce_loss = nn.CrossEntropyLoss(reduction="mean")
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def embed_sentences_checkpointed(
         self,
......
@@ -723,7 +723,8 @@ class RobertaModel(RobertaPreTrainedModel):
         self.pooler = RobertaPooler(config) if add_pooling_layer else None
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -897,7 +898,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
         # The LM head weights require special treatment only when they are tied with the word embeddings
         self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -1050,7 +1052,8 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
         # The LM head weights require special treatment only when they are tied with the word embeddings
         self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -1169,7 +1172,8 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.classifier = RobertaClassificationHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1265,7 +1269,8 @@ class RobertaForMultipleChoice(RobertaPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
     @add_code_sample_docstrings(
@@ -1362,7 +1367,8 @@ class RobertaForTokenClassification(RobertaPreTrainedModel):
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1474,7 +1480,8 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
......
@@ -817,7 +817,8 @@ class RoFormerModel(RoFormerPreTrainedModel):
         self.encoder = RoFormerEncoder(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -973,7 +974,8 @@ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
         self.roformer = RoFormerModel(config)
         self.cls = RoFormerOnlyMLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1073,7 +1075,8 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
         self.roformer = RoFormerModel(config)
         self.cls = RoFormerOnlyMLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1238,7 +1241,8 @@ class RoFormerForSequenceClassification(RoFormerPreTrainedModel):
         self.roformer = RoFormerModel(config)
         self.classifier = RoFormerClassificationHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1330,7 +1334,8 @@ class RoFormerForMultipleChoice(RoFormerPreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
         self.classifier = nn.Linear(config.hidden_size, 1)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(
         ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
@@ -1422,7 +1427,8 @@ class RoFormerForTokenClassification(RoFormerPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1509,7 +1515,8 @@ class RoFormerForQuestionAnswering(RoFormerPreTrainedModel):
         self.roformer = RoFormerModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
......
@@ -467,7 +467,8 @@ class SegformerModel(SegformerPreTrainedModel):
         # hierarchical Transformer encoder
         self.encoder = SegformerEncoder(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def _prune_heads(self, heads_to_prune):
         """
@@ -541,7 +542,8 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
         # Classifier head
         self.classifier = nn.Linear(config.hidden_sizes[-1], config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@@ -696,7 +698,8 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         self.segformer = SegformerModel(config)
         self.decode_head = SegformerDecodeHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
......
@@ -798,7 +798,8 @@ class SEWModel(SEWPreTrainedModel):
         self.encoder = SEWEncoder(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
     def _mask_hidden_states(
@@ -924,7 +925,8 @@ class SEWForCTC(SEWPreTrainedModel):
             )
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def freeze_feature_extractor(self):
         """
@@ -1032,7 +1034,8 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def freeze_feature_extractor(self):
         """
......
@@ -1329,7 +1329,8 @@ class SEWDModel(SEWDPreTrainedModel):
         self.encoder = SEWDEncoder(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
     def _mask_hidden_states(
@@ -1455,7 +1456,8 @@ class SEWDForCTC(SEWDPreTrainedModel):
             )
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def freeze_feature_extractor(self):
         """
@@ -1563,7 +1565,8 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
         self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def freeze_feature_extractor(self):
         """
......
@@ -723,8 +723,9 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
         self.layers = nn.ModuleList([Speech2TextEncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layer_norm = nn.LayerNorm(config.d_model)
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
     def forward(
         self,
@@ -876,8 +877,9 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.d_model)
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -1130,7 +1132,8 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
         self.encoder = Speech2TextEncoder(config)
         self.decoder = Speech2TextDecoder(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.decoder.embed_tokens
@@ -1253,7 +1256,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
         self.model = Speech2TextModel(config)
         self.lm_head = nn.Linear(config.d_model, self.config.vocab_size, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_encoder(self):
         return self.model.get_encoder()
......
@@ -476,8 +476,9 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
         self.layers = nn.ModuleList([Speech2Text2DecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.init_weights()
         self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -751,7 +752,8 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.model.decoder.embed_tokens
......
@@ -619,7 +619,8 @@ class SplinterModel(SplinterPreTrainedModel):
         self.embeddings = SplinterEmbeddings(config)
         self.encoder = SplinterEncoder(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -834,7 +835,8 @@ class SplinterForQuestionAnswering(SplinterPreTrainedModel):
         self.splinter_qass = QuestionAwareSpanSelectionHead(config)
         self.question_token_id = config.question_token_id
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(SPLINTER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
......
@@ -553,7 +553,8 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
         self.encoder = SqueezeBertEncoder(config)
         self.pooler = SqueezeBertPooler(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -654,7 +655,8 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
         self.transformer = SqueezeBertModel(config)
         self.cls = SqueezeBertOnlyMLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -739,7 +741,8 @@ class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -836,7 +839,8 @@ class SqueezeBertForMultipleChoice(SqueezeBertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(
         SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
@@ -930,7 +934,8 @@ class SqueezeBertForTokenClassification(SqueezeBertPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1017,7 +1022,8 @@ class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
         self.transformer = SqueezeBertModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
......
@@ -814,7 +814,8 @@ class T5Stack(T5PreTrainedModel):
         self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -1267,7 +1268,8 @@ class T5Model(T5PreTrainedModel):
         decoder_config.num_layers = config.num_decoder_layers
         self.decoder = T5Stack(decoder_config, self.shared)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
         # Model parallel
         self.model_parallel = False
@@ -1457,7 +1459,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
         # Model parallel
         self.model_parallel = False
@@ -1731,7 +1734,8 @@ class T5EncoderModel(T5PreTrainedModel):
         encoder_config.is_encoder_decoder = False
         self.encoder = T5Stack(encoder_config, self.shared)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
         # Model parallel
         self.model_parallel = False
......
@@ -877,7 +877,8 @@ class TapasModel(TapasPreTrainedModel):
         self.pooler = TapasPooler(config) if add_pooling_layer else None
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
@@ -1016,7 +1017,8 @@ class TapasForMaskedLM(TapasPreTrainedModel):
         self.tapas = TapasModel(config, add_pooling_layer=False)
         self.cls = TapasOnlyMLMHead(config)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
@@ -1146,7 +1148,8 @@ class TapasForQuestionAnswering(TapasPreTrainedModel):
         if config.num_aggregation_labels > 0:
             self.aggregation_classifier = nn.Linear(config.hidden_size, config.num_aggregation_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=TableQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
@@ -1464,7 +1467,8 @@ class TapasForSequenceClassification(TapasPreTrainedModel):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
......
@@ -819,7 +819,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         else:  # learnable embeddings and absolute embeddings
             raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def get_input_embeddings(self):
         return self.word_emb
@@ -1021,7 +1022,8 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
             config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
         )
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     def tie_weights(self):
         """
@@ -1170,7 +1172,8 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
         self.num_labels = config.num_labels
         self.transformer = TransfoXLModel(config)
         self.score = nn.Linear(config.d_embed, self.num_labels, bias=False)
-        self.init_weights()
+        # Initialize weights and apply final processing
+        self.post_init()
     @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
......
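
For downstream code, the practical consequence is that custom heads built on top of these classes should end their own `__init__` with `self.post_init()` rather than `self.init_weights()` once they run on a release containing this commit. A usage sketch follows; `RobertaPreTrainedModel`, `RobertaModel` and `post_init` come from the diff above and the library, while `RobertaWithExtraHead` and its extra head are purely illustrative.

```python
# Hedged usage sketch against a transformers release that includes this commit.
import torch.nn as nn
from transformers import RobertaConfig
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel


class RobertaWithExtraHead(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.extra_head = nn.Linear(config.hidden_size, 3)
        # Initialize weights and apply final processing
        self.post_init()


# Tiny config so the example instantiates quickly.
config = RobertaConfig(hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64)
model = RobertaWithExtraHead(config)
print(sum(p.numel() for p in model.parameters()))
```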