Unverified Commit 67047b86 authored by arampacha's avatar arampacha Committed by GitHub
Browse files

add scores to Wav2Vec2WithLMOutput (#15413)

* add scores to Wav2Vec2WithLMOutput

* style fixup
parent 45f56580
......@@ -42,9 +42,15 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput):
Args:
text (list of `str`):
Decoded logits in text from. Usually the speech transcription.
logit_score (list of `float`):
Total logit score of the beam associated with produced text.
lm_score (list of `float`):
Fused lm_score of the beam associated with produced text.
"""
text: Union[List[str], str]
logit_score: Union[List[float], float] = None
lm_score: Union[List[float], float] = None
class Wav2Vec2ProcessorWithLM(ProcessorMixin):
......@@ -283,7 +289,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
)
# create multiprocessing pool and list numpy arrays
logits_list = [array for array in logits]
# filter out logits padding
logits_list = [array[(array != -100.0).all(axis=-1)] for array in logits]
pool = get_context("fork").Pool(num_processes)
# pyctcdecode
......@@ -300,11 +307,14 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
# clone multi-processing pool
pool.close()
# extract text
batch_texts = [d[0][0] for d in decoded_beams]
# extract text and scores
batch_texts, logit_scores, lm_scores = [], [], []
for d in decoded_beams:
batch_texts.append(d[0][0])
logit_scores.append(d[0][-2])
lm_scores.append(d[0][-1])
# more output features will be added in the future
return Wav2Vec2DecoderWithLMOutput(text=batch_texts)
return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores)
def decode(
self,
......@@ -379,7 +389,9 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
)
# more output features will be added in the future
return Wav2Vec2DecoderWithLMOutput(text=decoded_beams[0][0])
return Wav2Vec2DecoderWithLMOutput(
text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1]
)
@contextmanager
def as_target_processor(self):
......
......@@ -178,12 +178,14 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits(shape=(10, 16), seed=13)
decoded_processor = processor.decode(logits).text
decoded_processor = processor.decode(logits)
decoded_decoder = decoder.decode_beams(logits)[0][0]
decoded_decoder = decoder.decode_beams(logits)[0]
self.assertEqual(decoded_decoder, decoded_processor)
self.assertEqual("</s> <s> </s>", decoded_processor)
self.assertEqual(decoded_decoder[0], decoded_processor.text)
self.assertEqual("</s> <s> </s>", decoded_processor.text)
self.assertEqual(decoded_decoder[-2], decoded_processor.logit_score)
self.assertEqual(decoded_decoder[-1], decoded_processor.lm_score)
def test_decoder_batch(self):
feature_extractor = self.get_feature_extractor()
......@@ -194,15 +196,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits = self._get_dummy_logits()
decoded_processor = processor.batch_decode(logits).text
decoded_processor = processor.batch_decode(logits)
logits_list = [array for array in logits]
pool = get_context("fork").Pool()
decoded_decoder = [d[0][0] for d in decoder.decode_beams_batch(pool, logits_list)]
decoded_beams = decoder.decode_beams_batch(pool, logits_list)
texts_decoder, logit_scores_decoder, lm_scores_decoder = [], [], []
for beams in decoded_beams:
texts_decoder.append(beams[0][0])
logit_scores_decoder.append(beams[0][-2])
lm_scores_decoder.append(beams[0][-1])
pool.close()
self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor)
self.assertListEqual(texts_decoder, decoded_processor.text)
self.assertListEqual(["<s> <s> </s>", "<s> <s> <s>"], decoded_processor.text)
self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)
def test_decoder_with_params(self):
feature_extractor = self.get_feature_extractor()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment