Unverified commit d51ca324, authored by Arthur, committed by GitHub
Browse files

fix tests (#19670)

parent 344e2664
......@@ -763,7 +763,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(input_features, num_beams=5)
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
EXPECTED_TRANSCRIPT = (
......@@ -781,7 +781,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(input_features, num_beams=5)
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
transcript = processor.tokenizer.decode(generated_ids[0])
EXPECTED_TRANSCRIPT = (
......@@ -801,8 +801,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
xla_generate = tf.function(model.generate, jit_compile=True)
generated_ids = model.generate(input_features, num_beams=5)
generated_ids_xla = xla_generate(input_features, num_beams=5)
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
generated_ids_xla = xla_generate(input_features, num_beams=5, max_length=20)
transcript = processor.tokenizer.decode(generated_ids[0])
transcript_xla = processor.tokenizer.decode(generated_ids_xla[0])
......@@ -824,10 +824,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
generated_ids = model.generate(
input_features,
do_sample=False,
)
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
......@@ -845,7 +842,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
generated_ids = model.generate(input_features, do_sample=False)
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
......@@ -855,6 +852,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
generated_ids = model.generate(
input_features,
do_sample=False,
max_length=20,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
......@@ -862,7 +860,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
generated_ids = model.generate(input_features, do_sample=False)
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
......@@ -876,7 +874,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(input_features)
generated_ids = model.generate(input_features, max_length=20)
# fmt: off
EXPECTED_LOGITS = tf.convert_to_tensor(
......@@ -893,7 +891,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
# fmt: off
EXPECTED_TRANSCRIPT = [
' Mr. Quilter is the apostle of the middle classes, and we are glad to',
' Mr. Quilter is the apostle of the middle classes and we are glad to',
" Nor is Mr. Quilter's manner less interesting than his matter.",
" He tells us that at this festive season of the year, with Christmas and roast beef",
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
......@@ -911,7 +909,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
generated_ids = model.generate(input_features)
generated_ids = model.generate(input_features, max_length=20)
# fmt: off
EXPECTED_LOGITS = tf.convert_to_tensor(
......@@ -950,8 +948,8 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
xla_generate = tf.function(model.generate, jit_compile=True)
generated_ids = model.generate(input_features)
generated_ids_xla = xla_generate(input_features)
generated_ids = model.generate(input_features, max_length=20)
generated_ids_xla = xla_generate(input_features, max_length=20)
# fmt: off
EXPECTED_LOGITS = tf.convert_to_tensor(
......
......@@ -895,7 +895,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
torch_device
)
generated_ids = model.generate(input_features, num_beams=5)
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
EXPECTED_TRANSCRIPT = (
......@@ -918,7 +918,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
torch_device
)
generated_ids = model.generate(input_features, num_beams=5)
generated_ids = model.generate(input_features, num_beams=5, max_length=20)
transcript = processor.tokenizer.decode(generated_ids[0])
EXPECTED_TRANSCRIPT = (
......@@ -944,6 +944,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
generated_ids = model.generate(
input_features,
do_sample=False,
max_length=20,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
......@@ -966,7 +967,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
generated_ids = model.generate(input_features, do_sample=False)
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
......@@ -976,6 +977,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
generated_ids = model.generate(
input_features,
do_sample=False,
max_length=20,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
......@@ -983,7 +985,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
generated_ids = model.generate(input_features, do_sample=False)
generated_ids = model.generate(input_features, do_sample=False, max_length=20)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
......@@ -997,7 +999,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
generated_ids = model.generate(input_features)
generated_ids = model.generate(input_features, max_length=20)
# fmt: off
EXPECTED_LOGITS = torch.tensor(
......@@ -1036,7 +1038,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
generated_ids = model.generate(input_features).to("cpu")
generated_ids = model.generate(input_features, max_length=20).to("cpu")
# fmt: off
EXPECTED_LOGITS = torch.tensor(
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment