Unverified Commit 54a2361a authored by JB (Don), committed by GitHub

Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True (#29024)

* Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True

* Testing for the non-safe-tensors case, since the default is safe-tensors already

* Running fixup/fix-copies

* Adding accelerate annotations to tests
parent ce47582d
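For context, the loading path this commit fixes can be exercised with a short script. This is an illustrative sketch, not part of the diff: the model name and the final assertion are examples, and `low_cpu_mem_usage=True` additionally requires the `accelerate` package. (The new tests force `safe_serialization=False`, since safetensors is already the default save format.)

```python
from transformers import BertForMaskedLM

# With low_cpu_mem_usage=True the model skeleton is first created on the
# "meta" device and checkpoint tensors are materialized afterwards.
# Aliased parameters, such as a prediction head's standalone `bias` and
# its `decoder.bias`, must be re-tied by _tie_weights() after loading,
# otherwise one of the aliases is left on "meta".
model = BertForMaskedLM.from_pretrained("bert-base-uncased", low_cpu_mem_usage=True)

# After this commit, the two names resolve to one tensor again.
assert model.cls.predictions.bias is model.cls.predictions.decoder.bias
```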
@@ -877,6 +877,10 @@ class AlbertMLMHead(nn.Module):
return prediction_scores
def _tie_weights(self) -> None:
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
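The branch above picks the tie direction based on where the data actually lives. A standalone sketch of the same logic, assuming PyTorch >= 2.0 for the `torch.device("meta")` context manager (sizes are arbitrary):

```python
import torch
import torch.nn as nn

# Building a module under the "meta" device mimics what accelerate does
# with low_cpu_mem_usage=True: parameters exist but carry no storage.
with torch.device("meta"):
    decoder = nn.Linear(128, 30000)

bias = nn.Parameter(torch.zeros(30000))  # stand-in for a loaded checkpoint tensor

if decoder.bias.device.type == "meta":
    decoder.bias = bias   # adopt the tensor that actually holds data
else:
    bias = decoder.bias   # legacy direction (TPU, or after a bias resize)

assert decoder.bias is bias
```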
@@ -915,6 +919,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.predictions.decoder = new_embeddings
self.predictions.bias = new_embeddings.bias
def get_input_embeddings(self) -> nn.Embedding:
return self.albert.embeddings.word_embeddings
......
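The extra `self.predictions.bias = new_embeddings.bias` line matters whenever the decoder is swapped out, e.g. by `resize_token_embeddings`. A toy reproduction of the pattern (the `ToyHead` class and sizes are hypothetical):

```python
import torch
import torch.nn as nn

class ToyHead(nn.Module):
    """Minimal stand-in for the prediction heads touched in this commit."""

    def __init__(self, hidden: int, vocab: int):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(vocab))
        self.decoder = nn.Linear(hidden, vocab)
        self.decoder.bias = self.bias  # keep both names aliased

head = ToyHead(128, 30000)
new_decoder = nn.Linear(128, 30522)  # e.g. produced by resize_token_embeddings

# The added set_output_embeddings logic: rebind both names. Without the
# second line, head.bias would keep pointing at the stale 30000-entry tensor.
head.decoder = new_decoder
head.bias = new_decoder.bias
assert head.bias.shape[0] == 30522
```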
@@ -778,6 +778,9 @@ class BertLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
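These one-line `_tie_weights` hooks only help because the base model re-runs them after loading. Roughly, the dispatch looks like the following simplified sketch (not the exact `PreTrainedModel.tie_weights` implementation):

```python
import torch.nn as nn

def retie_all(model: nn.Module) -> None:
    # Walk every submodule and let any module defining a _tie_weights()
    # hook re-establish its parameter aliases. Each prediction head in
    # this commit gains such a hook so this pass can repair the
    # bias/decoder.bias alias after a meta-device load.
    for module in model.modules():
        if hasattr(module, "_tie_weights"):
            module._tie_weights()
```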
@@ -1186,6 +1189,7 @@ class BertForPreTraining(BertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1295,6 +1299,7 @@ class BertLMHeadModel(BertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1448,6 +1453,7 @@ class BertForMaskedLM(BertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
@@ -851,6 +851,10 @@ class BertGenerationOnlyLMHead(nn.Module):
return logits
def _tie_weights(self):
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
@@ -879,6 +883,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
self.lm_head.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
......
@@ -1704,6 +1704,9 @@ class BigBirdLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -2263,6 +2266,7 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -2375,6 +2379,7 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -2516,6 +2521,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
@@ -523,6 +523,9 @@ class BlipTextLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -818,6 +821,7 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
def forward(
self,
......
@@ -1021,6 +1021,7 @@ class DebertaForMaskedLM(DebertaPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1117,6 +1118,9 @@ class DebertaLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......
@@ -1120,6 +1120,7 @@ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1217,6 +1218,9 @@ class DebertaV2LMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......
@@ -603,6 +603,9 @@ class ErnieLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -990,6 +993,7 @@ class ErnieForPreTraining(ErniePreTrainedModel):
# Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1104,6 +1108,7 @@ class ErnieForCausalLM(ErniePreTrainedModel):
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1264,6 +1269,7 @@ class ErnieForMaskedLM(ErniePreTrainedModel):
# Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
@@ -1661,6 +1661,9 @@ class FlavaMaskedPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, x):
x = self.transform(x)
x = self.decoder(x)
......
@@ -355,7 +355,11 @@ class FNetLMPredictionHead(nn.Module):
hidden_states = self.decoder(hidden_states)
return hidden_states
- def _tie_weights(self):
+ def _tie_weights(self) -> None:
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
@@ -624,6 +628,7 @@ class FNetForPreTraining(FNetPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=FNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -718,6 +723,7 @@ class FNetForMaskedLM(FNetPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
@@ -692,6 +692,9 @@ class FSMTDecoder(nn.Module):
self.output_projection = nn.Linear(embed_tokens_weight_shape[1], embed_tokens_weight_shape[0], bias=False)
self.output_projection.weight = self.embed_tokens.weight
def _tie_weights(self):
self.embed_tokens.weight = self.output_projection.weight
def forward(
self,
input_ids: torch.Tensor,
......
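FSMT is the odd one out: it shares a full projection matrix rather than a bias, and the tie direction is reversed because after an accelerate-style load it is `output_projection.weight` that holds the checkpoint data. A toy version (vocabulary and hidden sizes are hypothetical):

```python
import torch.nn as nn

embed_tokens = nn.Embedding(42024, 1024)
output_projection = nn.Linear(1024, 42024, bias=False)

# Tie at construction time: one tensor, two names.
output_projection.weight = embed_tokens.weight

# What the added FSMTDecoder._tie_weights() restores after loading:
# rebind the embedding to the tensor that received the data.
embed_tokens.weight = output_projection.weight
assert embed_tokens.weight is output_projection.weight
```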
@@ -868,6 +868,7 @@ class IBertForMaskedLM(IBertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
self.lm_head.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(IBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -952,7 +953,11 @@ class IBertLMHead(nn.Module):
return x
- def _tie_weights(self):
+ def _tie_weights(self) -> None:
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
......
@@ -594,6 +594,9 @@ class LayoutLMLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -873,6 +876,7 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
......
@@ -316,6 +316,9 @@ class MarkupLMLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......
@@ -657,6 +657,9 @@ class MegatronBertLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -1021,6 +1024,7 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1130,6 +1134,7 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
@@ -1288,6 +1293,7 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
@@ -650,6 +650,9 @@ class MobileBertLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self) -> None:
self.decoder.bias = self.bias
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.transform(hidden_states)
hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
@@ -938,8 +941,9 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
- def set_output_embeddings(self, new_embeddigs):
- self.cls.predictions.decoder = new_embeddigs
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+ self.cls.predictions.bias = new_embeddings.bias
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
# resize dense output embeddings first
@@ -1047,8 +1051,9 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
- def set_output_embeddings(self, new_embeddigs):
- self.cls.predictions.decoder = new_embeddigs
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+ self.cls.predictions.bias = new_embeddings.bias
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
# resize dense output embeddings first
......
@@ -584,6 +584,7 @@ class MPNetForMaskedLM(MPNetPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
self.lm_head.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -656,6 +657,9 @@ class MPNetLMHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
......
@@ -809,6 +809,9 @@ class MraLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -1042,6 +1045,7 @@ class MraForMaskedLM(MraPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(MRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
@@ -674,6 +674,9 @@ class NezhaLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -1039,6 +1042,7 @@ class NezhaForPreTraining(NezhaPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1147,6 +1151,7 @@ class NezhaForMaskedLM(NezhaPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
@@ -426,6 +426,9 @@ class NystromformerLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -664,6 +667,7 @@ class NystromformerForMaskedLM(NystromformerPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......