"...git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "42b8d462609f19e5336cc25721a76d525f97b448"
Unverified Commit 11b2e45c authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[WHISPER] Update modeling tests (#20162)



* Update modeling tests

* update tokenization test

* typo

* nit

* fix expected attention outputs

* Apply suggestions from code review
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests from review
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>

* remove problematic kwargs passed to the padding function
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent f60eec40
...@@ -307,7 +307,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): ...@@ -307,7 +307,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
max_length=max_length if max_length else self.n_samples, max_length=max_length if max_length else self.n_samples,
truncation=truncation, truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of, pad_to_multiple_of=pad_to_multiple_of,
**kwargs,
) )
# make sure list is in array format # make sure list is in array format
input_features = padded_inputs.get("input_features").transpose(2, 0, 1) input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
......
...@@ -650,13 +650,15 @@ def _test_large_logits_librispeech(in_queue, out_queue, timeout): ...@@ -650,13 +650,15 @@ def _test_large_logits_librispeech(in_queue, out_queue, timeout):
input_speech = _load_datasamples(1) input_speech = _load_datasamples(1)
processor = WhisperProcessor.from_pretrained("openai/whisper-large") processor = WhisperProcessor.from_pretrained("openai/whisper-large")
processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="tf") processed_inputs = processor(
audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="tf"
)
input_features = processed_inputs.input_features input_features = processed_inputs.input_features
labels = processed_inputs.labels decoder_input_ids = processed_inputs.labels
logits = model( logits = model(
input_features, input_features,
decoder_input_ids=labels, decoder_input_ids=decoder_input_ids,
output_hidden_states=False, output_hidden_states=False,
output_attentions=False, output_attentions=False,
use_cache=False, use_cache=False,
......
...@@ -853,13 +853,15 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -853,13 +853,15 @@ class WhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
processor = WhisperProcessor.from_pretrained("openai/whisper-large") processor = WhisperProcessor.from_pretrained("openai/whisper-large")
processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="pt") processed_inputs = processor(
audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="pt"
)
input_features = processed_inputs.input_features.to(torch_device) input_features = processed_inputs.input_features.to(torch_device)
labels = processed_inputs.labels.to(torch_device) decoder_input_ids = processed_inputs.labels.to(torch_device)
logits = model( logits = model(
input_features, input_features,
decoder_input_ids=labels, decoder_input_ids=decoder_input_ids,
output_hidden_states=False, output_hidden_states=False,
output_attentions=False, output_attentions=False,
use_cache=False, use_cache=False,
......
...@@ -96,7 +96,7 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -96,7 +96,7 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
@slow @slow
def test_tokenizer_integration(self): def test_tokenizer_integration(self):
# fmt: off # fmt: off
expected_encoding = {'input_ids': [[41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13], [13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13], [464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # noqa: E501 expected_encoding = {'input_ids': [[50257, 50362, 41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13, 50256], [50257, 50362, 13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 
11685, 13, 50256], [50257, 50362, 464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13, 50256]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # noqa: E501
# fmt: on # fmt: on
self.tokenizer_integration_test_util( self.tokenizer_integration_test_util(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment