Unverified Commit 03966cac authored by Vladislav Sokolovskii's avatar Vladislav Sokolovskii Committed by GitHub
Browse files

Wav2Vec2ProcessorWithLM can return N best hypotheses now (#22235)



* Wav2Vec2ProcessorWithLM can return N best hypotheses now
Signed-off-by: default avatarVladislav Sokolovskii <vladislav@parrothq.com>

* Wav2Vec2ProcessorWithLM n_best cannot be None
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Batch decoding can return N best hypotheses now

batch_decode was extended with the same functionality as decode
function, N best hypotheses per sample can be returned
Signed-off-by: default avatarVladislav Sokolovskii <vladislav@parrothq.com>

---------
Signed-off-by: default avatarVladislav Sokolovskii <vladislav@parrothq.com>
Co-authored-by: default avatarVladislav Sokolovskii <vladislav@parrothq.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 66d1eee6
...@@ -50,18 +50,18 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput): ...@@ -50,18 +50,18 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput):
text (list of `str` or `str`): text (list of `str` or `str`):
        Decoded logits in text form. Usually the speech transcription.         Decoded logits in text form. Usually the speech transcription.
logit_score (list of `float` or `float`): logit_score (list of `float` or `float`):
Total logit score of the beam associated with produced text. Total logit score of the beams associated with produced text.
lm_score (list of `float`): lm_score (list of `float`):
Fused lm_score of the beam associated with produced text. Fused lm_score of the beams associated with produced text.
word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`): word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
can be used to compute time stamps for each word. can be used to compute time stamps for each word.
""" """
text: Union[List[str], str] text: Union[List[List[str]], List[str], str]
logit_score: Union[List[float], float] = None logit_score: Union[List[List[float]], List[float], float] = None
lm_score: Union[List[float], float] = None lm_score: Union[List[List[float]], List[float], float] = None
word_offsets: Union[List[ListOfDict], ListOfDict] = None word_offsets: Union[List[List[ListOfDict]], List[ListOfDict], ListOfDict] = None
class Wav2Vec2ProcessorWithLM(ProcessorMixin): class Wav2Vec2ProcessorWithLM(ProcessorMixin):
...@@ -296,6 +296,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): ...@@ -296,6 +296,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
unk_score_offset: Optional[float] = None, unk_score_offset: Optional[float] = None,
lm_score_boundary: Optional[bool] = None, lm_score_boundary: Optional[bool] = None,
output_word_offsets: bool = False, output_word_offsets: bool = False,
n_best: int = 1,
): ):
""" """
Batch decode output logits to audio transcription with language model support. Batch decode output logits to audio transcription with language model support.
...@@ -350,6 +351,11 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): ...@@ -350,6 +351,11 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
output_word_offsets (`bool`, *optional*, defaults to `False`): output_word_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
and model downsampling rate to compute the time-stamps of transcribed words. and model downsampling rate to compute the time-stamps of transcribed words.
n_best (`int`, *optional*, defaults to `1`):
Number of best hypotheses to return. If `n_best` is greater than 1, the returned `text` will be a list
of lists of strings, `logit_score` will be a list of lists of floats, and `lm_score` will be a list of
lists of floats, where the length of the outer list will correspond to the batch size and the length of
                the inner list will correspond to the number of returned hypotheses. The value should be >= 1.
<Tip> <Tip>
...@@ -425,17 +431,40 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): ...@@ -425,17 +431,40 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
# extract text and scores # extract text and scores
batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], [] batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
for d in decoded_beams: for d in decoded_beams:
batch_texts.append(d[0][0]) batch_texts.append([beam[0] for beam in d])
logit_scores.append(d[0][-2]) logit_scores.append([beam[-2] for beam in d])
lm_scores.append(d[0][-1]) lm_scores.append([beam[-1] for beam in d])
word_offsets.append([{"word": t[0], "start_offset": t[1][0], "end_offset": t[1][1]} for t in d[0][1]])
# word_offsets.append([{"word": t[0], "start_offset": t[1][0], "end_offset": t[1][1]} for t in d[0][1]])
word_offsets.append(
[
[
{"word": word, "start_offset": start_offset, "end_offset": end_offset}
for word, (start_offset, end_offset) in beam[1]
]
for beam in d
]
)
word_offsets = word_offsets if output_word_offsets else None word_offsets = word_offsets if output_word_offsets else None
return Wav2Vec2DecoderWithLMOutput( if n_best == 1:
text=batch_texts, logit_score=logit_scores, lm_score=lm_scores, word_offsets=word_offsets return Wav2Vec2DecoderWithLMOutput(
) text=[hyps[0] for hyps in batch_texts],
logit_score=[hyps[0] for hyps in logit_scores],
lm_score=[hyps[0] for hyps in lm_scores],
word_offsets=[hyps[0] for hyps in word_offsets] if word_offsets is not None else None,
)
else:
return Wav2Vec2DecoderWithLMOutput(
text=[hyps[:n_best] for hyps in batch_texts],
logit_score=[hyps[:n_best] for hyps in logit_scores],
lm_score=[hyps[:n_best] for hyps in lm_scores],
word_offsets=[hyps[:n_best] for hyps in word_offsets] if word_offsets is not None else None,
)
def decode( def decode(
self, self,
...@@ -450,6 +479,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): ...@@ -450,6 +479,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
unk_score_offset: Optional[float] = None, unk_score_offset: Optional[float] = None,
lm_score_boundary: Optional[bool] = None, lm_score_boundary: Optional[bool] = None,
output_word_offsets: bool = False, output_word_offsets: bool = False,
n_best: int = 1,
): ):
""" """
Decode output logits to audio transcription with language model support. Decode output logits to audio transcription with language model support.
...@@ -480,6 +510,10 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): ...@@ -480,6 +510,10 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
output_word_offsets (`bool`, *optional*, defaults to `False`): output_word_offsets (`bool`, *optional*, defaults to `False`):
Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
and model downsampling rate to compute the time-stamps of transcribed words. and model downsampling rate to compute the time-stamps of transcribed words.
n_best (`int`, *optional*, defaults to `1`):
Number of best hypotheses to return. If `n_best` is greater than 1, the returned `text` will be a list
of strings, `logit_score` will be a list of floats, and `lm_score` will be a list of floats, where the
length of these lists will correspond to the number of returned hypotheses. The value should be >= 1.
<Tip> <Tip>
...@@ -564,17 +598,37 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): ...@@ -564,17 +598,37 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
word_offsets = None word_offsets = None
if output_word_offsets: if output_word_offsets:
word_offsets = [ word_offsets = [
{"word": word, "start_offset": start_offset, "end_offset": end_offset} [
for word, (start_offset, end_offset) in decoded_beams[0][2] {"word": word, "start_offset": start_offset, "end_offset": end_offset}
for word, (start_offset, end_offset) in beam[2]
]
for beam in decoded_beams
] ]
logit_scores = [beam[-2] for beam in decoded_beams]
# more output features will be added in the future lm_scores = [beam[-1] for beam in decoded_beams]
return Wav2Vec2DecoderWithLMOutput(
text=decoded_beams[0][0], hypotheses = [beam[0] for beam in decoded_beams]
logit_score=decoded_beams[0][-2],
lm_score=decoded_beams[0][-1], if n_best > len(decoded_beams):
word_offsets=word_offsets, logger.info(
) "N-best size is larger than the number of generated hypotheses, all hypotheses will be returned."
)
if n_best == 1:
return Wav2Vec2DecoderWithLMOutput(
text=hypotheses[0],
logit_score=logit_scores[0],
lm_score=lm_scores[0],
word_offsets=word_offsets[0] if word_offsets is not None else None,
)
else:
return Wav2Vec2DecoderWithLMOutput(
text=hypotheses[:n_best],
logit_score=logit_scores[:n_best],
lm_score=lm_scores[:n_best],
word_offsets=word_offsets[:n_best] if word_offsets is not None else None,
)
@contextmanager @contextmanager
def as_target_processor(self): def as_target_processor(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment