"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "ea6908769769511bc8dacf4b70929ffa32bb9da7"
Unverified Commit d718c0c3 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Wav2Vec2ProcessorWithLM] add alpha & beta to batch decode & decode (#15465)

parent 1d94d575
...@@ -253,6 +253,10 @@ class Wav2Vec2ProcessorWithLM: ...@@ -253,6 +253,10 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp: Optional[float] = None, token_min_logp: Optional[float] = None,
hotwords: Optional[Iterable[str]] = None, hotwords: Optional[Iterable[str]] = None,
hotword_weight: Optional[float] = None, hotword_weight: Optional[float] = None,
alpha: Optional[float] = None,
beta: Optional[float] = None,
unk_score_offset: Optional[float] = None,
lm_score_boundary: Optional[bool] = None,
): ):
""" """
Batch decode output logits to audio transcription with language model support. Batch decode output logits to audio transcription with language model support.
...@@ -280,6 +284,14 @@ class Wav2Vec2ProcessorWithLM: ...@@ -280,6 +284,14 @@ class Wav2Vec2ProcessorWithLM:
List of words with extra importance, can be OOV for LM List of words with extra importance, can be OOV for LM
hotword_weight (`int`, *optional*): hotword_weight (`int`, *optional*):
Weight factor for hotword importance Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT. Weight factor for hotword importance Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
alpha (`float`, *optional*):
Weight for language model during shallow fusion
beta (`float`, *optional*):
Weight for length score adjustment during scoring
unk_score_offset (`float`, *optional*):
Amount of log score offset for unknown tokens
lm_score_boundary (`bool`, *optional*):
Whether to have kenlm respect boundaries when scoring
Returns: Returns:
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`. [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
...@@ -298,6 +310,11 @@ class Wav2Vec2ProcessorWithLM: ...@@ -298,6 +310,11 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP
hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT
# reset params at every forward call. It's just a `set` method in pyctcdecode
self.decoder.reset_params(
alpha=alpha, beta=beta, unk_score_offset=unk_score_offset, lm_score_boundary=lm_score_boundary
)
# create multiprocessing pool and list numpy arrays # create multiprocessing pool and list numpy arrays
logits_list = [array for array in logits] logits_list = [array for array in logits]
pool = get_context("fork").Pool(num_processes) pool = get_context("fork").Pool(num_processes)
...@@ -330,6 +347,10 @@ class Wav2Vec2ProcessorWithLM: ...@@ -330,6 +347,10 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp: Optional[float] = None, token_min_logp: Optional[float] = None,
hotwords: Optional[Iterable[str]] = None, hotwords: Optional[Iterable[str]] = None,
hotword_weight: Optional[float] = None, hotword_weight: Optional[float] = None,
alpha: Optional[float] = None,
beta: Optional[float] = None,
unk_score_offset: Optional[float] = None,
lm_score_boundary: Optional[bool] = None,
): ):
""" """
Decode output logits to audio transcription with language model support. Decode output logits to audio transcription with language model support.
...@@ -349,6 +370,14 @@ class Wav2Vec2ProcessorWithLM: ...@@ -349,6 +370,14 @@ class Wav2Vec2ProcessorWithLM:
List of words with extra importance which can be missing from the LM's vocabulary, e.g. ["huggingface"] List of words with extra importance which can be missing from the LM's vocabulary, e.g. ["huggingface"]
hotword_weight (`int`, *optional*): hotword_weight (`int`, *optional*):
Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT. Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
alpha (`float`, *optional*):
Weight for language model during shallow fusion
beta (`float`, *optional*):
Weight for length score adjustment during scoring
unk_score_offset (`float`, *optional*):
Amount of log score offset for unknown tokens
lm_score_boundary (`bool`, *optional*):
Whether to have kenlm respect boundaries when scoring
Returns: Returns:
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`. [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
...@@ -367,6 +396,11 @@ class Wav2Vec2ProcessorWithLM: ...@@ -367,6 +396,11 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP token_min_logp = token_min_logp if token_min_logp is not None else DEFAULT_MIN_TOKEN_LOGP
hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT hotword_weight = hotword_weight if hotword_weight is not None else DEFAULT_HOTWORD_WEIGHT
# reset params at every forward call. It's just a `set` method in pyctcdecode
self.decoder.reset_params(
alpha=alpha, beta=beta, unk_score_offset=unk_score_offset, lm_score_boundary=lm_score_boundary
)
# pyctcdecode # pyctcdecode
decoded_beams = self.decoder.decode_beams( decoded_beams = self.decoder.decode_beams(
logits, logits,
......
...@@ -17,7 +17,7 @@ import os ...@@ -17,7 +17,7 @@ import os
import shutil import shutil
import tempfile import tempfile
import unittest import unittest
from multiprocessing import Pool from multiprocessing import get_context
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
...@@ -196,7 +196,9 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): ...@@ -196,7 +196,9 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor = processor.batch_decode(logits).text decoded_processor = processor.batch_decode(logits).text
logits_list = [array for array in logits] logits_list = [array for array in logits]
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(Pool(), logits_list)] pool = get_context("fork").Pool()
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
pool.close()
self.assertListEqual(decoded_decoder, decoded_processor) self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor) self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
...@@ -223,19 +225,68 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): ...@@ -223,19 +225,68 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor = decoded_processor_out.text decoded_processor = decoded_processor_out.text
logits_list = [array for array in logits] logits_list = [array for array in logits]
pool = get_context("fork").Pool()
decoded_decoder_out = decoder.decode_beams_batch( decoded_decoder_out = decoder.decode_beams_batch(
Pool(), pool,
logits_list, logits_list,
beam_width=beam_width, beam_width=beam_width,
beam_prune_logp=beam_prune_logp, beam_prune_logp=beam_prune_logp,
token_min_logp=token_min_logp, token_min_logp=token_min_logp,
) )
pool.close()
decoded_decoder = [d[0][0] for d in decoded_decoder_out] decoded_decoder = [d[0][0] for d in decoded_decoder_out]
self.assertListEqual(decoded_decoder, decoded_processor) self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor) self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor)
def test_decoder_with_params_of_lm(self):
    """Check that LM params (alpha/beta/unk_score_offset/lm_score_boundary) passed to
    `batch_decode` are forwarded to the pyctcdecode decoder and yield the same
    transcriptions as calling `reset_params` on the decoder directly."""
    # Build a processor from the test fixtures.
    decoder = self.get_decoder()
    processor = Wav2Vec2ProcessorWithLM(
        tokenizer=self.get_tokenizer(),
        feature_extractor=self.get_feature_extractor(),
        decoder=decoder,
    )
    logits = self._get_dummy_logits()

    # Non-default LM parameters to exercise the pass-through.
    lm_params = {
        "alpha": 2.0,
        "beta": 5.0,
        "unk_score_offset": -20.0,
        "lm_score_boundary": True,
    }

    # Path 1: via the processor's batch_decode.
    texts_from_processor = processor.batch_decode(logits, **lm_params).text

    # Path 2: set the params on the decoder manually, then decode the same logits.
    decoder.reset_params(**lm_params)
    arrays = [array for array in logits]
    pool = get_context("fork").Pool()
    raw_beams = decoder.decode_beams_batch(pool, arrays)
    pool.close()
    texts_from_decoder = [beams[0][0] for beams in raw_beams]

    # Both paths must agree, and match the expected dummy transcriptions.
    self.assertListEqual(texts_from_decoder, texts_from_processor)
    self.assertListEqual(["<s> </s> <s> </s> </s>", "</s> </s> <s> </s> </s>"], texts_from_processor)

    # The params must have landed on the underlying kenlm model container.
    lm_model = processor.decoder.model_container[processor.decoder._model_key]
    self.assertEqual(lm_model.alpha, 2.0)
    self.assertEqual(lm_model.beta, 5.0)
    self.assertEqual(lm_model.unk_score_offset, -20.0)
    self.assertEqual(lm_model.score_boundary, True)
def test_decoder_download_ignores_files(self): def test_decoder_download_ignores_files(self):
processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm") processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment