Unverified commit d51ca324, authored by Arthur, committed by GitHub
Browse files

fix tests (#19670)

parent 344e2664
...@@ -763,7 +763,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -763,7 +763,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(input_features, num_beams=5) generated_ids = model.generate(input_features, num_beams=5, max_length=20)
transcript = processor.tokenizer.batch_decode(generated_ids)[0] transcript = processor.tokenizer.batch_decode(generated_ids)[0]
EXPECTED_TRANSCRIPT = ( EXPECTED_TRANSCRIPT = (
...@@ -781,7 +781,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -781,7 +781,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(input_features, num_beams=5) generated_ids = model.generate(input_features, num_beams=5, max_length=20)
transcript = processor.tokenizer.decode(generated_ids[0]) transcript = processor.tokenizer.decode(generated_ids[0])
EXPECTED_TRANSCRIPT = ( EXPECTED_TRANSCRIPT = (
...@@ -801,8 +801,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -801,8 +801,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
xla_generate = tf.function(model.generate, jit_compile=True) xla_generate = tf.function(model.generate, jit_compile=True)
generated_ids = model.generate(input_features, num_beams=5) generated_ids = model.generate(input_features, num_beams=5, max_length=20)
generated_ids_xla = xla_generate(input_features, num_beams=5) generated_ids_xla = xla_generate(input_features, num_beams=5, max_length=20)
transcript = processor.tokenizer.decode(generated_ids[0]) transcript = processor.tokenizer.decode(generated_ids[0])
transcript_xla = processor.tokenizer.decode(generated_ids_xla[0]) transcript_xla = processor.tokenizer.decode(generated_ids_xla[0])
...@@ -824,10 +824,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -824,10 +824,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
generated_ids = model.generate( generated_ids = model.generate(input_features, do_sample=False, max_length=20)
input_features,
do_sample=False,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad" EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
...@@ -845,7 +842,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -845,7 +842,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
generated_ids = model.generate(input_features, do_sample=False) generated_ids = model.generate(input_features, do_sample=False, max_length=20)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました" EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
...@@ -855,6 +852,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -855,6 +852,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
generated_ids = model.generate( generated_ids = model.generate(
input_features, input_features,
do_sample=False, do_sample=False,
max_length=20,
) )
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
...@@ -862,7 +860,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -862,7 +860,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
self.assertEqual(transcript, EXPECTED_TRANSCRIPT) self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
generated_ids = model.generate(input_features, do_sample=False) generated_ids = model.generate(input_features, do_sample=False, max_length=20)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san" EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
...@@ -876,7 +874,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -876,7 +874,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(4) input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(input_features) generated_ids = model.generate(input_features, max_length=20)
# fmt: off # fmt: off
EXPECTED_LOGITS = tf.convert_to_tensor( EXPECTED_LOGITS = tf.convert_to_tensor(
...@@ -893,7 +891,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -893,7 +891,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
# fmt: off # fmt: off
EXPECTED_TRANSCRIPT = [ EXPECTED_TRANSCRIPT = [
' Mr. Quilter is the apostle of the middle classes, and we are glad to', ' Mr. Quilter is the apostle of the middle classes and we are glad to',
" Nor is Mr. Quilter's manner less interesting than his matter.", " Nor is Mr. Quilter's manner less interesting than his matter.",
" He tells us that at this festive season of the year, with Christmas and roast beef", " He tells us that at this festive season of the year, with Christmas and roast beef",
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all," " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
...@@ -911,7 +909,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -911,7 +909,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(4) input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(input_features) generated_ids = model.generate(input_features, max_length=20)
# fmt: off # fmt: off
EXPECTED_LOGITS = tf.convert_to_tensor( EXPECTED_LOGITS = tf.convert_to_tensor(
...@@ -950,8 +948,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase): ...@@ -950,8 +948,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
xla_generate = tf.function(model.generate, jit_compile=True) xla_generate = tf.function(model.generate, jit_compile=True)
generated_ids = model.generate(input_features) generated_ids = model.generate(input_features, max_length=20)
generated_ids_xla = xla_generate(input_features) generated_ids_xla = xla_generate(input_features, max_length=20)
# fmt: off # fmt: off
EXPECTED_LOGITS = tf.convert_to_tensor( EXPECTED_LOGITS = tf.convert_to_tensor(
......
...@@ -895,7 +895,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -895,7 +895,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
torch_device torch_device
) )
generated_ids = model.generate(input_features, num_beams=5) generated_ids = model.generate(input_features, num_beams=5, max_length=20)
transcript = processor.tokenizer.batch_decode(generated_ids)[0] transcript = processor.tokenizer.batch_decode(generated_ids)[0]
EXPECTED_TRANSCRIPT = ( EXPECTED_TRANSCRIPT = (
...@@ -918,7 +918,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -918,7 +918,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
torch_device torch_device
) )
generated_ids = model.generate(input_features, num_beams=5) generated_ids = model.generate(input_features, num_beams=5, max_length=20)
transcript = processor.tokenizer.decode(generated_ids[0]) transcript = processor.tokenizer.decode(generated_ids[0])
EXPECTED_TRANSCRIPT = ( EXPECTED_TRANSCRIPT = (
...@@ -944,6 +944,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -944,6 +944,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
generated_ids = model.generate( generated_ids = model.generate(
input_features, input_features,
do_sample=False, do_sample=False,
max_length=20,
) )
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
...@@ -966,7 +967,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -966,7 +967,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
) )
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
generated_ids = model.generate(input_features, do_sample=False) generated_ids = model.generate(input_features, do_sample=False, max_length=20)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました" EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
...@@ -976,6 +977,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -976,6 +977,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
generated_ids = model.generate( generated_ids = model.generate(
input_features, input_features,
do_sample=False, do_sample=False,
max_length=20,
) )
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
...@@ -983,7 +985,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -983,7 +985,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
self.assertEqual(transcript, EXPECTED_TRANSCRIPT) self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
generated_ids = model.generate(input_features, do_sample=False) generated_ids = model.generate(input_features, do_sample=False, max_length=20)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san" EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
...@@ -997,7 +999,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -997,7 +999,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(4) input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
generated_ids = model.generate(input_features) generated_ids = model.generate(input_features, max_length=20)
# fmt: off # fmt: off
EXPECTED_LOGITS = torch.tensor( EXPECTED_LOGITS = torch.tensor(
...@@ -1036,7 +1038,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -1036,7 +1038,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to( input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device torch_device
) )
generated_ids = model.generate(input_features).to("cpu") generated_ids = model.generate(input_features, max_length=20).to("cpu")
# fmt: off # fmt: off
EXPECTED_LOGITS = torch.tensor( EXPECTED_LOGITS = torch.tensor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment