"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "78a2b19fc84ed55c65f4bf20a901edb7ceb73c5f"
Unverified commit 6f3faf38, authored by Arthur, committed by GitHub

[WHISPER] Small patch (#21307)

* add small patch

* update tests; forced decoder ids do not take priority over the generation config

* fix two new tests

Parent: 140c6ede
@@ -936,18 +936,19 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
         self.no_timestamps_token_id = generate_config.no_timestamps_token_id
         self.timestamp_begin = generate_config.no_timestamps_token_id + 1
-        self.begin_index = len(generate_config.forced_decoder_ids) + 1
+        self.begin_index = len(generate_config.forced_decoder_ids) + 2
         if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
             self.begin_index -= 1
-        if generate_config.is_multilingual:
-            self.begin_index += 1
         self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index

     def __call__(self, input_ids, scores):
         # suppress <|notimestamps|> which is handled by without_timestamps
         scores[:, self.no_timestamps_token_id] = -float("inf")

-        if input_ids.shape[1] == self.begin_index:
+        if input_ids.shape[1] == self.begin_index - 1:
+            scores[:, :] = -float("inf")
             scores[:, self.timestamp_begin] = 0
+            return scores

         # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
         for k in range(input_ids.shape[0]):
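Reading the hunk above: the first free timestamp position now sits two slots past the forced prompt rather than one, and the separate is_multilingual bump is gone, presumably because the language token is now always counted inside forced_decoder_ids. An illustrative sketch of the arithmetic (not from the diff; the dummy config and ids 50259/50359 are stand-ins for a language and a task token):

class DummyGenerateConfig:
    no_timestamps_token_id = 50363
    # position 0 of the decoder input is <|startoftranscript|>, which is not
    # listed in forced_decoder_ids -- hence "+ 2" where the old code had "+ 1"
    forced_decoder_ids = [(1, 50259), (2, 50359)]

config = DummyGenerateConfig()
begin_index = len(config.forced_decoder_ids) + 2
if config.forced_decoder_ids[-1][1] == config.no_timestamps_token_id:
    begin_index -= 1  # the prompt already ends in <|notimestamps|>: shift back one

# after [<|startoftranscript|>, <|lang|>, <|task|>] the sequence length is
# 3 == begin_index - 1, which is exactly when __call__ now forces the first
# timestamp token
assert begin_index - 1 == 3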
@@ -699,8 +699,9 @@ def _test_large_generation(in_queue, out_queue, timeout):
         input_speech = _load_datasamples(1)
         input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
+        )

         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
@@ -728,26 +729,25 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
         input_speech = next(iter(ds))["audio"]["array"]
         input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
+        )

         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
         unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
         generated_ids = model.generate(
-            input_features,
-            do_sample=False,
-            max_length=20,
+            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
         )

         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         EXPECTED_TRANSCRIPT = " Kimura-san called me."
         unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
+        )

         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
@@ -779,10 +779,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
         # fmt: off
         EXPECTED_LOGITS = tf.convert_to_tensor(
             [
-                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
-                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
-                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
-                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
+                [50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
+                [50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
+                [50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
+                [50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
             ]
         )
         # fmt: on
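Each expected row changes in exactly one way: id 50259 is inserted after 50258, so the last content token of each row falls off the 20-token window. 50258 is <|startoftranscript|> and 50259 should be the English language token, meaning the rows now spell out the full forced prompt. One hedged way to check what the leading special ids decode to, assuming the openai/whisper-large tokenizer used elsewhere in these tests:

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large")
# the four special ids that now head every expected sequence
print(tokenizer.convert_ids_to_tokens([50258, 50259, 50358, 50363]))

The transcript hunk below shortens each expected string by one word for the same reason: max_length=20 now budgets one position for the language token, so the decoded text ends one token earlier, not because the model's output changed.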
@@ -791,10 +791,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
         # fmt: off
         EXPECTED_TRANSCRIPT = [
-            ' Mr. Quilter is the apostle of the middle classes and we are glad to',
+            " Mr. Quilter is the apostle of the middle classes and we are glad",
             " Nor is Mr. Quilter's manner less interesting than his matter.",
-            " He tells us that at this festive season of the year, with Christmas and roast beef",
-            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
+            " He tells us that at this festive season of the year, with Christmas and roast",
+            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all"
         ]
         # fmt: on
@@ -945,11 +945,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
             torch_device
         )

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
         generated_ids = model.generate(
-            input_features,
-            do_sample=False,
-            max_length=20,
+            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
         )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
@@ -971,26 +968,25 @@ class WhisperModelIntegrationTests(unittest.TestCase):
             torch_device
         )

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
+        )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
         generated_ids = model.generate(
-            input_features,
-            do_sample=False,
-            max_length=20,
+            input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
         )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = " Kimura-san called me."
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        generated_ids = model.generate(
+            input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
+        )
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

         EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
@@ -1009,10 +1005,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         # fmt: off
         EXPECTED_LOGITS = torch.tensor(
             [
-                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
-                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
-                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
-                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
+                [50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
+                [50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
+                [50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
+                [50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
             ]
         )
         # fmt: on
@@ -1021,10 +1017,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):
         # fmt: off
         EXPECTED_TRANSCRIPT = [
-            " Mr. Quilter is the apostle of the middle classes and we are glad to",
+            " Mr. Quilter is the apostle of the middle classes and we are glad",
             " Nor is Mr. Quilter's manner less interesting than his matter.",
-            " He tells us that at this festive season of the year, with Christmas and roast beef",
-            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
+            " He tells us that at this festive season of the year, with Christmas and roast",
+            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all",
         ]
         # fmt: on