"LiveCDGUI/git@developer.sourcefind.cn:dadigang/Ventoy.git" did not exist on "1d6658419043c103f9c8382fec0196d0288deb55"
Unverified commit 38a4bf79, authored by Raushan Turganbay, committed by GitHub

Encoder-decoder models: move embedding scale to nn.Module (#30410)

* move scaling to nn.Module

* let the test be here for now (need to fix)

* failing tests

* last failing models

* Revert commit 4c14817f38

* clean-up

* oops forgot

* codestyle

* raise NotImplemented when possible

* Update tests/test_modeling_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* skip tests in respective modeling files

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 9d31b32e
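Background for the diff below: before this change, the affected encoder-decoder models multiplied token embeddings by `embed_scale` inside the model's `forward`, so embeddings obtained from `model.get_input_embeddings()(input_ids)` and passed back as `inputs_embeds` were missing the scale. Moving the scale into the embedding `nn.Module` makes the `input_ids` and `inputs_embeds` paths equivalent, which is what the new common test checks. A minimal sketch of the pattern (the class name and constructor signature here are illustrative, not necessarily the exact ones introduced by the PR):

```python
import torch
from torch import nn


class ScaledWordEmbedding(nn.Embedding):
    """An nn.Embedding that applies the embedding scale itself, so that
    get_input_embeddings()(input_ids) already matches what forward() expects."""

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None, embed_scale=1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale
```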
@@ -247,6 +247,10 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
     def test_inputs_embeds(self):
         pass

+    @unittest.skip(reason="Conditional DETR does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @unittest.skip(reason="Conditional DETR does not have a get_input_embeddings method")
     def test_model_common_attributes(self):
         pass
@@ -253,6 +253,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
     def test_inputs_embeds(self):
         pass

+    @unittest.skip(reason="Deformable DETR does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @unittest.skip(reason="Deformable DETR does not have a get_input_embeddings method")
     def test_model_common_attributes(self):
         pass
@@ -303,6 +303,10 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_inputs_embeds(self):
         pass

+    @unittest.skip(reason="DETA does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @unittest.skip(reason="DETA does not have a get_input_embeddings method")
     def test_model_common_attributes(self):
         pass
@@ -247,6 +247,10 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_inputs_embeds(self):
         pass

+    @unittest.skip(reason="DETR does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @unittest.skip(reason="DETR does not have a get_input_embeddings method")
     def test_model_common_attributes(self):
         pass
@@ -321,6 +321,10 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_inputs_embeds(self):
         pass

+    @unittest.skip("Input ids is required for FSMT.")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @unittest.skip("model weights aren't tied in FSMT.")
     def test_tie_model_weights(self):
         pass
@@ -182,6 +182,14 @@ class GPTSanJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
     def test_model_parallelism(self):
         super().test_model_parallelism()

+    @unittest.skip(reason="Gptsan does not use inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(reason="Gptsan does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
 @require_torch
 class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -212,6 +220,14 @@ class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTes
     def test_model_parallelism(self):
         super().test_model_parallelism()

+    @unittest.skip(reason="Gptsan does not use inputs_embeds")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(reason="Gptsan does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @slow
     def test_logits(self):
         model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
@@ -382,6 +382,10 @@ class IBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         with torch.no_grad():
             model(**inputs)[0]

+    @unittest.skip("ibert overrides scaling to None if inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
 @require_torch
 class IBertModelIntegrationTest(unittest.TestCase):
@@ -180,6 +180,10 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase):
     def test_inputs_embeds():
         pass

+    @unittest.skip("input_embeds cannot be passed in without input_ids")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @unittest.skip("Model does not support padding right")
     def test_flash_attn_2_generate_padding_right(self):
         pass
@@ -466,6 +466,31 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
         with torch.no_grad():
             model(**inputs)[0]

+    # override because ImageGPT main input name is `pixel_values`
+    # NOTE: in latest transformers this is deprecated, `input_ids` should be used. TODO
+    def test_inputs_embeds_matches_input_ids(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+            with torch.no_grad():
+                out_ids = model(**inputs)[0]
+
+            pixel_values = inputs["pixel_values"]
+            del inputs["pixel_values"]
+
+            wte = model.get_input_embeddings()
+            inputs["inputs_embeds"] = wte(pixel_values)
+
+            with torch.no_grad():
+                out_embeds = model(**inputs)[0]
+
+            self.assertTrue(torch.allclose(out_embeds, out_ids))
+
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
             return
@@ -265,6 +265,10 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
         lm_heads = model.get_output_embeddings()
         self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))

+    @unittest.skip(reason="MusicGen does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     # skip as this model doesn't support all arguments tested
     def test_model_outputs_equivalence(self):
         pass
@@ -268,6 +268,10 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
         lm_heads = model.get_output_embeddings()
         self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))

+    @unittest.skip(reason="MusicGen melody does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @unittest.skip("this model doesn't support all arguments tested")
     def test_model_outputs_equivalence(self):
         pass
@@ -463,6 +463,10 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
     def test_inputs_embeds(self):
         pass

+    @unittest.skip(reason="SeamlessM4TSpeechEncoder doesn't have an embedding layer")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @unittest.skip(
         reason="Expected missing keys serve when using SeamlessM4TForXXX.from_pretrained from a checkpoint saved by SeamlessM4TModel.save_pretrained."
     )
@@ -479,6 +479,10 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
     def test_inputs_embeds(self):
         pass

+    @unittest.skip(reason="SeamlessM4TSpeechEncoder doesn't have an embedding layer")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @unittest.skip(
         reason="Expected missing keys serve when using SeamlessM4Tv2ForXXX.from_pretrained from a checkpoint saved by SeamlessM4Tv2Model.save_pretrained."
     )
@@ -261,6 +261,10 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
     def test_inputs_embeds(self):
         pass

+    @unittest.skip(reason="Table Transformer does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     @unittest.skip(reason="Table Transformer does not have a get_input_embeddings method")
     def test_model_common_attributes(self):
         pass
@@ -357,6 +357,13 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_model_outputs_equivalence(self):
         pass

+    @unittest.skip(
+        reason="""VilT samples image tokens from a multinomial distribution, resulting in not deterministic
+                  hidden states. Cannot test equivalence on logit level"""
+    )
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
     def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
@@ -2767,6 +2767,51 @@ class ModelTesterMixin:
         with torch.no_grad():
             model(**inputs)[0]

+    def test_inputs_embeds_matches_input_ids(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if model_class.__name__ not in get_values(MODEL_MAPPING_NAMES):
+                continue
+
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            model_forward_args = inspect.signature(model.forward).parameters
+            if "inputs_embeds" not in model_forward_args:
+                self.skipTest("This model doesn't use `inputs_embeds`")
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+            pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
+
+            wte = model.get_input_embeddings()
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                # some models infer position ids/attn mask differently when input ids
+                # contain pad tokens, so make sure no padding is left in the input ids
+                not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
+                input_ids[input_ids == pad_token_id] = not_pad_token_id
+                del inputs["input_ids"]
+                inputs_embeds = wte(input_ids)
+                with torch.no_grad():
+                    out_ids = model(input_ids=input_ids, **inputs)[0]
+                    out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
+            else:
+                encoder_input_ids = inputs["input_ids"]
+                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+                encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
+                decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
+                del inputs["input_ids"]
+                inputs.pop("decoder_input_ids", None)
+                inputs_embeds = wte(encoder_input_ids)
+                decoder_inputs_embeds = wte(decoder_input_ids)
+                with torch.no_grad():
+                    out_ids = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs)[0]
+                    out_embeds = model(
+                        inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **inputs
+                    )[0]
+
+            self.assertTrue(torch.allclose(out_embeds, out_ids))
+
     @require_torch_multi_gpu
     def test_multi_gpu_data_parallel_forward(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
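To illustrate why the new common test above failed before this change, here is a self-contained toy script (not from the PR; all class and variable names are made up for illustration) contrasting forward-time scaling with module-level scaling:

```python
import torch
from torch import nn

torch.manual_seed(0)
VOCAB, DIM, SCALE = 16, 8, 2.0
input_ids = torch.tensor([[1, 2, 3, 4]])


class ScaledEmbedding(nn.Embedding):
    """Embedding that carries its own scale (the pattern adopted by this commit)."""

    def __init__(self, num_embeddings, embedding_dim, embed_scale=1.0):
        super().__init__(num_embeddings, embedding_dim)
        self.embed_scale = embed_scale

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale


class ToyEncoder(nn.Module):
    """Stand-in for a model forward: embed the ids (optionally scaling in forward), then project."""

    def __init__(self, embed_tokens, scale_in_forward):
        super().__init__()
        self.embed_tokens = embed_tokens
        self.scale_in_forward = scale_in_forward
        self.proj = nn.Linear(DIM, DIM)

    def get_input_embeddings(self):
        return self.embed_tokens

    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            if self.scale_in_forward:
                inputs_embeds = inputs_embeds * SCALE
        return self.proj(inputs_embeds)


# Old pattern: the scale is applied in forward, so user-provided inputs_embeds miss it.
old = ToyEncoder(nn.Embedding(VOCAB, DIM), scale_in_forward=True)
print(torch.allclose(old(input_ids=input_ids),
                     old(inputs_embeds=old.get_input_embeddings()(input_ids))))  # False

# New pattern: the scale lives in the embedding module, so both paths agree.
new = ToyEncoder(ScaledEmbedding(VOCAB, DIM, embed_scale=SCALE), scale_in_forward=False)
print(torch.allclose(new(input_ids=input_ids),
                     new(inputs_embeds=new.get_input_embeddings()(input_ids))))  # True
```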