Unverified Commit e064f081 authored by Patrick von Platen, committed by GitHub

Add time stamps for wav2vec2 with lm (#15854)



* [Wav2Vec2 With LM] add timestamps

* correct

* correct

* Apply suggestions from code review

* correct

* Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py

* make style

* Update src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* make style

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 3f2e6368
@@ -97,6 +97,8 @@ WAV2VEC2_KWARGS_DOCSTRING = r"""
             Whether or not to print more information and warnings.
 """
 
+ListOfDict = List[Dict[str, Union[int, str]]]
+
 
 @dataclass
 class Wav2Vec2CTCTokenizerOutput(ModelOutput):
@@ -106,18 +108,18 @@ class Wav2Vec2CTCTokenizerOutput(ModelOutput):
     Args:
         text (list of `str` or `str`):
             Decoded logits in text form. Usually the speech transcription.
-        char_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
+        char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
             Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
             offsets can be used to compute time stamps for each character.
-        word_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
+        word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
             Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
             can be used to compute time stamps for each word.
     """
 
     text: Union[List[str], str]
-    char_offsets: List[Dict[str, Union[float, str]]] = None
-    word_offsets: List[Dict[str, Union[float, str]]] = None
+    char_offsets: Union[List[ListOfDict], ListOfDict] = None
+    word_offsets: Union[List[ListOfDict], ListOfDict] = None
 
 
 class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
...
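The offsets stored in these fields are logit-frame indices, not seconds. A minimal sketch of the conversion the docstring describes, assuming a standard wav2vec2 CTC checkpoint; the model name and the random waveform are placeholders for illustration:

```python
# Sketch: converting char offsets (logit frames) to seconds.
# The checkpoint name and the random waveform are placeholders.
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

speech = torch.randn(16_000).numpy()  # stand-in for one second of 16 kHz audio
inputs = processor(speech, sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    pred_ids = model(inputs.input_values).logits.argmax(-1)[0]

outputs = processor.tokenizer.decode(pred_ids, output_char_offsets=True)

# one logit frame covers `inputs_to_logits_ratio` input samples (320 for
# wav2vec2-base), so seconds = frame_offset * ratio / sampling_rate
time_per_frame = model.config.inputs_to_logits_ratio / 16_000
char_times = [
    (d["char"], d["start_offset"] * time_per_frame, d["end_offset"] * time_per_frame)
    for d in outputs.char_offsets
]
```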
@@ -66,6 +66,9 @@ PRETRAINED_VOCAB_FILES_MAP = {
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"facebook/wav2vec2-lv-60-espeak-cv-ft": sys.maxsize}
 
+ListOfDict = List[Dict[str, Union[int, str]]]
+
 
 @dataclass
 class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
     """
@@ -74,14 +77,14 @@ class Wav2Vec2PhonemeCTCTokenizerOutput(ModelOutput):
     Args:
         text (list of `str` or `str`):
             Decoded logits in text form. Usually the speech transcription.
-        char_offsets (`Dict[str, Union[int, str]]` or `Dict[str, Union[int, str]]`):
+        char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
             Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
             offsets can be used to compute time stamps for each character.
     """
 
     text: Union[List[str], str]
-    char_offsets: List[Dict[str, Union[float, str]]] = None
+    char_offsets: Union[List[ListOfDict], ListOfDict] = None
 
 
 class Wav2Vec2PhonemeCTCTokenizer(PreTrainedTokenizer):
...
@@ -19,7 +19,7 @@ import os
 from contextlib import contextmanager
 from dataclasses import dataclass
 from multiprocessing import get_context
-from typing import TYPE_CHECKING, Iterable, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
 
 import numpy as np
 
@@ -34,23 +34,30 @@ if TYPE_CHECKING:
     from ...tokenization_utils import PreTrainedTokenizerBase
 
 
+ListOfDict = List[Dict[str, Union[int, str]]]
+
 
 @dataclass
 class Wav2Vec2DecoderWithLMOutput(ModelOutput):
     """
     Output type of [`Wav2Vec2DecoderWithLM`], with transcription.
 
     Args:
-        text (list of `str`):
+        text (list of `str` or `str`):
             Decoded logits in text form. Usually the speech transcription.
-        logit_score (list of `float`):
+        logit_score (list of `float` or `float`):
             Total logit score of the beam associated with produced text.
         lm_score (list of `float`):
             Fused lm_score of the beam associated with produced text.
+        word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
+            Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
+            can be used to compute time stamps for each word.
     """
 
     text: Union[List[str], str]
     logit_score: Union[List[float], float] = None
     lm_score: Union[List[float], float] = None
+    word_offsets: Union[List[ListOfDict], ListOfDict] = None
 
 
 class Wav2Vec2ProcessorWithLM(ProcessorMixin):
@@ -232,6 +239,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
         beta: Optional[float] = None,
         unk_score_offset: Optional[float] = None,
         lm_score_boundary: Optional[bool] = None,
+        output_word_offsets: bool = False,
     ):
         """
         Batch decode output logits to audio transcription with language model support.
@@ -267,6 +275,18 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
                 Amount of log score offset for unknown tokens
             lm_score_boundary (`bool`, *optional*):
                 Whether to have kenlm respect boundaries when scoring
+            output_word_offsets (`bool`, *optional*, defaults to `False`):
+                Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
+                and model downsampling rate to compute the time-stamps of transcribed words.
+
+                <Tip>
+
+                Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`]
+                to better understand how to make use of `output_word_offsets`.
+                [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.batch_decode`] works the same way with batched
+                output.
+
+                </Tip>
 
         Returns:
             [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
@@ -310,13 +330,18 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
         pool.close()
 
         # extract text and scores
-        batch_texts, logit_scores, lm_scores = [], [], []
+        batch_texts, logit_scores, lm_scores, word_offsets = [], [], [], []
         for d in decoded_beams:
             batch_texts.append(d[0][0])
             logit_scores.append(d[0][-2])
             lm_scores.append(d[0][-1])
-        # more output features will be added in the future
-        return Wav2Vec2DecoderWithLMOutput(text=batch_texts, logit_score=logit_scores, lm_score=lm_scores)
+            word_offsets.append([{"word": t[0], "start_offset": t[1][0], "end_offset": t[1][1]} for t in d[0][1]])
+
+        word_offsets = word_offsets if output_word_offsets else None
+
+        return Wav2Vec2DecoderWithLMOutput(
+            text=batch_texts, logit_score=logit_scores, lm_score=lm_scores, word_offsets=word_offsets
+        )
     def decode(
         self,
@@ -330,6 +355,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
         beta: Optional[float] = None,
         unk_score_offset: Optional[float] = None,
         lm_score_boundary: Optional[bool] = None,
+        output_word_offsets: bool = False,
     ):
         """
         Decode output logits to audio transcription with language model support.
@@ -357,11 +383,65 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
                 Amount of log score offset for unknown tokens
             lm_score_boundary (`bool`, *optional*):
                 Whether to have kenlm respect boundaries when scoring
+            output_word_offsets (`bool`, *optional*, defaults to `False`):
+                Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
+                and model downsampling rate to compute the time-stamps of transcribed words.
+
+                <Tip>
+
+                Please take a look at the example of [`~models.wav2vec2_with_lm.processing_wav2vec2_with_lm.decode`]
+                to better understand how to make use of `output_word_offsets`.
+
+                </Tip>
 
         Returns:
             [`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
""" Example:
```python
>>> # Let's see how to retrieve time steps for a model
>>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
>>> from datasets import load_dataset
>>> import datasets
>>> import torch
>>> # import model, feature extractor, tokenizer
>>> model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
>>> processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
>>> # load first sample of English common_voice
>>> dataset = load_dataset("common_voice", "en", split="train", streaming=True)
>>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
>>> dataset_iter = iter(dataset)
>>> sample = next(dataset_iter)
>>> # forward sample through model to get greedily predicted transcription ids
>>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
>>> with torch.no_grad():
... logits = model(input_values).logits[0].cpu().numpy()
>>> # retrieve word stamps (analogous commands for `output_char_offsets`)
>>> outputs = tokenizer.decode(logits, output_word_offsets=True)
>>> # compute `time_offset` in seconds as product of downsampling ratio and sampling_rate
>>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate
>>> word_offsets = [
... {
... "word": d["word"],
... "start_time": d["start_offset"] * time_offset,
... "end_time": d["end_offset"] * time_offset,
... }
... for d in outputs.word_offsets
... ]
>>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer:
>>> # https://huggingface.co/datasets/common_voice/viewer/en/train
>>> word_offset
>>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES',
>>> # 'start_time': 1.64, 'end_time': 1.88}, {'word': 'A',
>>> # 'start_time': 2.12, 'end_time': 2.14}, {'word': 'MILE', 'start_time': 2.26, 'end_time': 2.46}, ...
```"""
         from pyctcdecode.constants import (
             DEFAULT_BEAM_WIDTH,
             DEFAULT_HOTWORD_WEIGHT,
@@ -390,9 +470,19 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
             hotword_weight=hotword_weight,
         )
 
+        word_offsets = None
+        if output_word_offsets:
+            word_offsets = [
+                {"word": word, "start_offset": start_offset, "end_offset": end_offset}
+                for word, (start_offset, end_offset) in decoded_beams[0][2]
+            ]
+
         # more output features will be added in the future
         return Wav2Vec2DecoderWithLMOutput(
-            text=decoded_beams[0][0], logit_score=decoded_beams[0][-2], lm_score=decoded_beams[0][-1]
+            text=decoded_beams[0][0],
+            logit_score=decoded_beams[0][-2],
+            lm_score=decoded_beams[0][-1],
+            word_offsets=word_offsets,
         )
 
     @contextmanager
...
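The tuple indexing in the hunks above (`d[0][0]`, `d[0][1]`, `decoded_beams[0][2]`, `[-2]`, `[-1]`) follows pyctcdecode's beam layout. A hedged summary of that layout; the field order is an assumption inferred from this diff and pyctcdecode of the era, so verify against the installed version:

```python
# Assumed pyctcdecode beam layout (verify against your pyctcdecode version):
#
# decoder.decode_beams(logits) returns beams of the form
#     (text, last_lm_state, text_frames, logit_score, lm_score)
# so in `decode` above: decoded_beams[0][2] is text_frames and
# [-2]/[-1] are the logit and LM scores.
#
# decode_beams_batch(pool, logits_list) drops the unpicklable LM state:
#     (text, text_frames, logit_score, lm_score)
# so in `batch_decode` above: d[0][1] is text_frames.
#
# text_frames is a list of (word, (start_frame, end_frame)) pairs, which both
# methods repackage as {"word": ..., "start_offset": ..., "end_offset": ...}:
frame_entry = ("HELLO", (12, 18))  # illustrative text_frames entry
offset_dict = {"word": frame_entry[0], "start_offset": frame_entry[1][0], "end_offset": frame_entry[1][1]}
assert offset_dict == {"word": "HELLO", "start_offset": 12, "end_offset": 18}
```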
@@ -20,13 +20,15 @@ import unittest
 from multiprocessing import get_context
 from pathlib import Path
 
+import datasets
 import numpy as np
+from datasets import load_dataset
 
 from transformers import AutoProcessor
-from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
+from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available, is_torch_available
 from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
 from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
-from transformers.testing_utils import require_pyctcdecode
+from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow
 
 from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
 
@@ -35,6 +37,10 @@ if is_pyctcdecode_available():
     from huggingface_hub import snapshot_download
     from pyctcdecode import BeamSearchDecoderCTC
 
     from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
+    from transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm import Wav2Vec2DecoderWithLMOutput
+
+if is_torch_available():
+    from transformers import Wav2Vec2ForCTC
 
 
 @require_pyctcdecode
@@ -350,3 +356,101 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
         decoded_auto = processor_auto.batch_decode(logits)
 
         self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
+    @staticmethod
+    def get_from_offsets(offsets, key):
+        retrieved_list = [d[key] for d in offsets]
+        return retrieved_list
+
+    def test_offsets_integration_fast(self):
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
+        logits = self._get_dummy_logits()[0]
+
+        outputs = processor.decode(logits, output_word_offsets=True)
+
+        # check Wav2Vec2DecoderWithLMOutput keys
+        self.assertEqual(len(outputs.keys()), 4)
+        self.assertTrue("text" in outputs)
+        self.assertTrue("word_offsets" in outputs)
+        self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
+
+        self.assertEqual(" ".join(self.get_from_offsets(outputs["word_offsets"], "word")), outputs.text)
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "word"), ["<s>", "<s>", "</s>"])
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "start_offset"), [0, 2, 4])
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"], "end_offset"), [1, 3, 5])
+
+    def test_offsets_integration_fast_batch(self):
+        processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
+        logits = self._get_dummy_logits()
+
+        outputs = processor.batch_decode(logits, output_word_offsets=True)
+
+        # check Wav2Vec2DecoderWithLMOutput keys
+        self.assertEqual(len(outputs.keys()), 4)
+        self.assertTrue("text" in outputs)
+        self.assertTrue("word_offsets" in outputs)
+        self.assertTrue(isinstance(outputs, Wav2Vec2DecoderWithLMOutput))
+
+        self.assertListEqual(
+            [" ".join(self.get_from_offsets(o, "word")) for o in outputs["word_offsets"]], outputs.text
+        )
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "word"), ["<s>", "<s>", "</s>"])
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "start_offset"), [0, 2, 4])
+        self.assertListEqual(self.get_from_offsets(outputs["word_offsets"][0], "end_offset"), [1, 3, 5])
+    @slow
+    @require_torch
+    @require_torchaudio
+    def test_word_time_stamp_integration(self):
+        import torch
+
+        ds = load_dataset("common_voice", "en", split="train", streaming=True)
+        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
+        ds_iter = iter(ds)
+        sample = next(ds_iter)
+
+        processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
+        model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
+
+        # compare to filename `common_voice_en_100038.mp3` of dataset viewer on https://huggingface.co/datasets/common_voice/viewer/en/train
+        input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
+
+        with torch.no_grad():
+            logits = model(input_values).logits.cpu().numpy()
+
+        output = processor.decode(logits[0], output_word_offsets=True)
+
+        time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
+        word_time_stamps = [
+            {
+                "start_time": d["start_offset"] * time_offset,
+                "end_time": d["end_offset"] * time_offset,
+                "word": d["word"],
+            }
+            for d in output["word_offsets"]
+        ]
+
+        EXPECTED_TEXT = "WHY DOES A MILE SANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL"
+
+        # output words
+        self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT)
+        self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), output.text)
+
+        # output times
+        start_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "start_time")]
+        end_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "end_time")]
+
+        # fmt: off
+        self.assertListEqual(
+            start_times,
+            [
+                1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66,
+            ],
+        )
+        self.assertListEqual(
+            end_times,
+            [
+                1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94,
+            ],
+        )
+        # fmt: on
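As the Tip in the docstring notes, `batch_decode` mirrors `decode` for batched output. A sketch of the batched path the tests above exercise, using the same checkpoint as the slow test; the dummy waveforms stand in for real audio:

```python
# Sketch: word offsets from batch_decode; dummy audio stands in for real clips.
import torch
from transformers import AutoProcessor, Wav2Vec2ForCTC

processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")

# two dummy clips at 16 kHz; padding aligns them to the same length
batch = [torch.randn(16_000).numpy(), torch.randn(24_000).numpy()]
inputs = processor(batch, sampling_rate=16_000, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(inputs.input_values).logits.cpu().numpy()

# word_offsets comes back as one list of offset dicts per batch element
outputs = processor.batch_decode(logits, output_word_offsets=True)

time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
for text, offsets in zip(outputs.text, outputs.word_offsets):
    stamps = [(d["word"], round(d["start_offset"] * time_offset, 2)) for d in offsets]
    print(text, stamps)
```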