Unverified Commit 4e98d594 authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

[FIX] Fix speech2test modeling tests (#29672)



* fix speech_to_test generation tests

* Add details to comment

* Update tests/models/speech_to_text/test_modeling_speech_to_text.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avatarYih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 9e4df7c4
...@@ -284,6 +284,18 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest ...@@ -284,6 +284,18 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
input_name = "input_features" input_name = "input_features"
def _get_input_ids_and_config(self, batch_size=2):
config, input_ids, attention_mask, max_length = GenerationTesterMixin._get_input_ids_and_config(self)
# `input_ids` is actually `input_features` which is a 3D tensor.
# We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
# attention mask of the same shape as `input_ids`.
if len(attention_mask.shape) > 2:
sequence_length = input_ids.shape[1]
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
return config, input_ids, attention_mask, max_length
def setUp(self): def setUp(self):
self.model_tester = Speech2TextModelTester(self) self.model_tester = Speech2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Speech2TextConfig) self.config_tester = ConfigTester(self, config_class=Speech2TextConfig)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment