Unverified Commit 54a2361a authored by JB (Don)'s avatar JB (Don) Committed by GitHub
Browse files

Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True (#29024)

* Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True

* Testing for the non-safe-tensors case, since the default is safe-tensors already

* Running fixup/fix-copies

* Adding accelerate annotations to tests
parent ce47582d
......@@ -681,6 +681,9 @@ class QDQBertLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......@@ -1022,6 +1025,7 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
......@@ -1188,6 +1192,7 @@ class QDQBertForMaskedLM(QDQBertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
......@@ -797,6 +797,9 @@ class RealmLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......@@ -1391,6 +1394,7 @@ class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(
REALM_INPUTS_DOCSTRING.format("batch_size, num_candidates, sequence_length")
......
......@@ -1768,9 +1768,13 @@ class ReformerOnlyLMHead(nn.Module):
hidden_states = self.decoder(hidden_states)
return hidden_states
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
def _tie_weights(self) -> None:
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
class ReformerPreTrainedModel(PreTrainedModel):
......@@ -2208,6 +2212,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
self.lm_head.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
......@@ -2328,6 +2333,7 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
self.lm_head.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
......
......@@ -749,6 +749,9 @@ class RoCBertLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......@@ -1094,6 +1097,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel):
# Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
......@@ -1286,6 +1290,7 @@ class RoCBertForMaskedLM(RoCBertPreTrainedModel):
# Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
def forward(
......@@ -1423,6 +1428,7 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
......
......@@ -657,6 +657,9 @@ class RoFormerLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self) -> None:
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......@@ -954,6 +957,7 @@ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......@@ -1053,6 +1057,7 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
......
......@@ -400,6 +400,9 @@ class SqueezeBertLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self) -> None:
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......@@ -658,6 +661,7 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
......@@ -699,6 +699,9 @@ class TapasLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......@@ -978,6 +981,7 @@ class TapasForMaskedLM(TapasPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
......
......@@ -894,6 +894,7 @@ class ViltForMaskedLM(ViltPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.mlm_score.decoder = new_embeddings
self.mlm_score.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
......@@ -1040,6 +1041,9 @@ class ViltMLMHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, x):
x = self.transform(x)
x = self.decoder(x)
......
......@@ -489,6 +489,9 @@ class VisualBertLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......@@ -869,6 +872,7 @@ class VisualBertForPreTraining(VisualBertPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
......
......@@ -852,6 +852,7 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
self.lm_head.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(XLM_ROBERTA_XL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
......@@ -1011,6 +1012,7 @@ class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
self.lm_head.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(XLM_ROBERTA_XL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......@@ -1099,9 +1101,13 @@ class XLMRobertaXLLMHead(nn.Module):
return x
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
def _tie_weights(self) -> None:
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
@add_start_docstrings(
......
......@@ -627,6 +627,9 @@ class YosoLMPredictionHead(nn.Module):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
......@@ -861,6 +864,7 @@ class YosoForMaskedLM(YosoPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias
@add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
......
......@@ -578,6 +578,18 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
def test_two_stage_training(self):
model_class = DeformableDetrForObjectDetection
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -528,6 +528,18 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
# Inspired by tests.test_modeling_common.ModelTesterMixin.test_tied_weights_keys
def test_tied_weights_keys(self):
for model_class in self.all_model_classes:
......
......@@ -325,6 +325,18 @@ class EncodecModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
def test_hidden_states_output(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
def test_determinism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -766,6 +766,18 @@ class LxmertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
return tf_inputs_dict
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
@require_torch
class LxmertModelIntegrationTest(unittest.TestCase):
......
......@@ -372,6 +372,18 @@ class MarianModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
......
......@@ -1273,6 +1273,18 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
def test_tied_weights_keys(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
# override since changing `output_hidden_states` / `output_attentions` from the top-level model config won't work
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
......@@ -1258,6 +1258,18 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
def test_tied_weights_keys(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
# override since changing `output_hidden_states` / `output_attentions` from the top-level model config won't work
# Ignore copy
def test_retain_grad_hidden_states_attentions(self):
......
......@@ -356,6 +356,18 @@ class SEWModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_model_common_attributes(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
......
......@@ -460,6 +460,18 @@ class SEWDModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_feed_forward_chunking(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
@slow
def test_model_from_pretrained(self):
model = SEWDModel.from_pretrained("asapp/sew-d-tiny-100k")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment