Unverified Commit ad0d7d17 authored by Nicolas Patry, committed by GitHub

Adding the option to return_timestamps on pure CTC ASR models. (#15792)



* Adding the option to return_timestamps on pure CTC ASR models.

* Remove `math.prod` which was introduced in Python 3.8

* int are not floats.

* Reworking the PR to support "char" vs "word" output.

* Fixup!

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Quality.
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7566734d
@@ -14,7 +14,8 @@
# limitations under the License.
""" Hubert model configuration"""
import math
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -253,4 +254,4 @@ class HubertConfig(PretrainedConfig):
@property
def inputs_to_logits_ratio(self):
return math.prod(self.conv_stride)
return functools.reduce(operator.mul, self.conv_stride, 1)
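
Note on the change above: `math.prod` was only introduced in Python 3.8, so the product over `conv_stride` is now computed with `functools.reduce(operator.mul, ...)`, which behaves identically. A minimal sketch of the equivalence, using illustrative stride values:

import functools
import math
import operator

conv_stride = (5, 2, 2, 2, 2, 2, 2)  # illustrative stride values
# math.prod requires Python >= 3.8; the reduce form also works on older versions
assert functools.reduce(operator.mul, conv_stride, 1) == math.prod(conv_stride) == 320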
@@ -14,7 +14,8 @@
# limitations under the License.
""" SEW model configuration"""
import math
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -248,4 +249,4 @@ class SEWConfig(PretrainedConfig):
@property
def inputs_to_logits_ratio(self):
return math.prod(self.conv_stride)
return functools.reduce(operator.mul, self.conv_stride, 1)
@@ -14,7 +14,8 @@
# limitations under the License.
""" SEW-D model configuration"""
import math
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -284,4 +285,4 @@ class SEWDConfig(PretrainedConfig):
@property
def inputs_to_logits_ratio(self):
return math.prod(self.conv_stride)
return functools.reduce(operator.mul, self.conv_stride, 1)
@@ -14,7 +14,8 @@
# limitations under the License.
""" UniSpeech model configuration"""
import math
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -294,4 +295,4 @@ class UniSpeechConfig(PretrainedConfig):
@property
def inputs_to_logits_ratio(self):
return math.prod(self.conv_stride)
return functools.reduce(operator.mul, self.conv_stride, 1)
@@ -14,7 +14,8 @@
# limitations under the License.
""" UniSpeechSat model configuration"""
import math
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -311,4 +312,4 @@ class UniSpeechSatConfig(PretrainedConfig):
@property
def inputs_to_logits_ratio(self):
return math.prod(self.conv_stride)
return functools.reduce(operator.mul, self.conv_stride, 1)
@@ -14,7 +14,8 @@
# limitations under the License.
""" Wav2Vec2 model configuration"""
import math
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -334,4 +335,4 @@ class Wav2Vec2Config(PretrainedConfig):
@property
def inputs_to_logits_ratio(self):
return math.prod(self.conv_stride)
return functools.reduce(operator.mul, self.conv_stride, 1)
@@ -14,7 +14,8 @@
# limitations under the License.
""" WavLM model configuration"""
import math
import functools
import operator
from ...configuration_utils import PretrainedConfig
from ...utils import logging
@@ -335,4 +336,4 @@ class WavLMConfig(PretrainedConfig):
@property
def inputs_to_logits_ratio(self):
return math.prod(self.conv_stride)
return functools.reduce(operator.mul, self.conv_stride, 1)
@@ -165,10 +165,23 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
np.array}` with optionally a `"stride": (left: int, right: int)` that asks the pipeline to ignore the
first `left` samples and the last `right` samples during decoding (but to use them at inference to
provide more context to the model). Only use `stride` with CTC models.
return_timestamps (`str`, *optional*):
Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along with the
text for every character in the text. For instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)},
{"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return timestamps
along with the text for every word in the text. For instance, if you get `[{"text": "hi ",
"timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
Return:
`Dict`: A dictionary with the following keys:
- **text** (`str`) -- The recognized text.
- **text** (`str` ) -- The recognized text.
- **chunks** (*optional*, `List[Dict]`)
When using `return_timestamps`, the `chunks` will become a list containing all the various text
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
"there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
`"".join(chunk["text"] for chunk in output["chunks"])`.
"""
return super().__call__(inputs, **kwargs)
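
For reference, a minimal usage sketch of the new argument, matching the docstring above (the checkpoint is a real pure-CTC model; the audio path is only a placeholder):

from transformers import pipeline

asr = pipeline(task="automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
# "sample.flac" is a placeholder path to any audio file
output = asr("sample.flac", return_timestamps="word")
print(output["text"])
for chunk in output["chunks"]:
    # each chunk looks like {"text": "HI", "timestamp": (0.5, 0.9)}
    print(chunk["text"], chunk["timestamp"])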
@@ -183,6 +196,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
postprocess_params = {}
if "decoder_kwargs" in kwargs:
postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
if "return_timestamps" in kwargs:
postprocess_params["return_timestamps"] = kwargs["return_timestamps"]
return preprocess_params, {}, postprocess_params
@@ -323,7 +338,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
extra = model_inputs
return {"is_last": is_last, **out, **extra}
def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None):
def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None):
# Optional return types
optional = {}
if return_timestamps and self.type != "ctc":
raise ValueError("We cannot return_timestamps yet on non-ctc models !")
if self.type == "ctc_with_lm":
final_logits = []
for outputs in model_outputs:
@@ -349,6 +370,30 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
tokens = tokens.squeeze(0)
text = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
if return_timestamps:
if return_timestamps == "char":
decoded = self.tokenizer.decode(
tokens, skip_special_tokens=skip_special_tokens, output_char_offsets=True
)
elif return_timestamps == "word":
decoded = self.tokenizer.decode(
tokens, skip_special_tokens=skip_special_tokens, output_word_offsets=True
)
chunks = []
for item in decoded[f"{return_timestamps}_offsets"]:
start = (
item["start_offset"]
* self.model.config.inputs_to_logits_ratio
/ self.feature_extractor.sampling_rate
)
stop = (
item["end_offset"]
* self.model.config.inputs_to_logits_ratio
/ self.feature_extractor.sampling_rate
)
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
optional["chunks"] = chunks
extra = defaultdict(list)
for output in model_outputs:
output.pop("tokens", None)
@@ -357,4 +402,4 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if k == "is_last":
continue
extra[k].append(v)
return {"text": text, **extra}
return {"text": text, **optional, **extra}
@@ -82,15 +82,46 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
# Striding
audio = {"raw": audio, "stride": (0, 4000), "sampling_rate": speech_recognizer.feature_extractor.sampling_rate}
if speech_recognizer.type == "ctc":
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
else:
# Non CTC models cannot use striding.
with self.assertRaises(ValueError):
outputs = speech_recognizer(audio)
# Timestamps
audio = np.zeros((34000,))
if speech_recognizer.type == "ctc":
outputs = speech_recognizer(audio, return_timestamps="char")
self.assertIsInstance(outputs["chunks"], list)
n = len(outputs["chunks"])
self.assertEqual(
outputs,
{
"text": ANY(str),
"chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)],
},
)
outputs = speech_recognizer(audio, return_timestamps="word")
self.assertIsInstance(outputs["chunks"], list)
n = len(outputs["chunks"])
self.assertEqual(
outputs,
{
"text": ANY(str),
"chunks": [{"text": ANY(str), "timestamp": (ANY(float), ANY(float))} for i in range(n)],
},
)
else:
# Non CTC models cannot use return_timestamps
with self.assertRaises(ValueError):
outputs = speech_recognizer(audio, return_timestamps="char")
@require_torch
@slow
def test_pt_defaults(self):
@@ -302,6 +333,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
@@ -322,6 +354,49 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "ZBT ZC")
@require_torch
def test_return_timestamps_ctc_fast(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="hf-internal-testing/tiny-random-wav2vec2",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
# Take short audio to keep the test readable
audio = ds[40]["audio"]["array"][:800]
output = speech_recognizer(audio, return_timestamps="char")
self.assertEqual(
output,
{
"text": "ZBT ZX G",
"chunks": [
{"text": " ", "timestamp": (0.0, 0.012)},
{"text": "Z", "timestamp": (0.012, 0.016)},
{"text": "B", "timestamp": (0.016, 0.02)},
{"text": "T", "timestamp": (0.02, 0.024)},
{"text": " ", "timestamp": (0.024, 0.028)},
{"text": "Z", "timestamp": (0.028, 0.032)},
{"text": "X", "timestamp": (0.032, 0.036)},
{"text": " ", "timestamp": (0.036, 0.04)},
{"text": "G", "timestamp": (0.04, 0.044)},
],
},
)
output = speech_recognizer(audio, return_timestamps="word")
self.assertEqual(
output,
{
"text": "ZBT ZX G",
"chunks": [
{"text": "ZBT", "timestamp": (0.012, 0.024)},
{"text": "ZX", "timestamp": (0.028, 0.036)},
{"text": "G", "timestamp": (0.04, 0.044)},
],
},
)
@require_torch
@require_pyctcdecode
def test_chunking_fast_with_lm(self):
@@ -399,7 +474,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
@require_torch
@slow
def test_chunking(self):
def test_chunking_and_timestamps(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
@@ -416,11 +491,79 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
audio = ds[40]["audio"]["array"]
n_repeats = 10
audio = np.tile(audio, n_repeats)
output = speech_recognizer([audio], batch_size=2)
expected_text = "A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats
expected = [{"text": expected_text.strip()}]
self.assertEqual(output, expected)
audio_tiled = np.tile(audio, n_repeats)
output = speech_recognizer([audio_tiled], batch_size=2)
self.assertEqual(output, [{"text": ("A MAN SAID TO THE UNIVERSE SIR I EXIST " * n_repeats).strip()}])
output = speech_recognizer(audio, return_timestamps="char")
self.assertEqual(audio.shape, (74_400,))
self.assertEqual(speech_recognizer.feature_extractor.sampling_rate, 16_000)
# The audio is 74_400 / 16_000 = 4.65s long.
self.assertEqual(
output,
{
"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST",
"chunks": [
{"text": "A", "timestamp": (0.6, 0.62)},
{"text": " ", "timestamp": (0.62, 0.66)},
{"text": "M", "timestamp": (0.68, 0.7)},
{"text": "A", "timestamp": (0.78, 0.8)},
{"text": "N", "timestamp": (0.84, 0.86)},
{"text": " ", "timestamp": (0.92, 0.98)},
{"text": "S", "timestamp": (1.06, 1.08)},
{"text": "A", "timestamp": (1.14, 1.16)},
{"text": "I", "timestamp": (1.16, 1.18)},
{"text": "D", "timestamp": (1.2, 1.24)},
{"text": " ", "timestamp": (1.24, 1.28)},
{"text": "T", "timestamp": (1.28, 1.32)},
{"text": "O", "timestamp": (1.34, 1.36)},
{"text": " ", "timestamp": (1.38, 1.42)},
{"text": "T", "timestamp": (1.42, 1.44)},
{"text": "H", "timestamp": (1.44, 1.46)},
{"text": "E", "timestamp": (1.46, 1.5)},
{"text": " ", "timestamp": (1.5, 1.56)},
{"text": "U", "timestamp": (1.58, 1.62)},
{"text": "N", "timestamp": (1.64, 1.68)},
{"text": "I", "timestamp": (1.7, 1.72)},
{"text": "V", "timestamp": (1.76, 1.78)},
{"text": "E", "timestamp": (1.84, 1.86)},
{"text": "R", "timestamp": (1.86, 1.9)},
{"text": "S", "timestamp": (1.96, 1.98)},
{"text": "E", "timestamp": (1.98, 2.02)},
{"text": " ", "timestamp": (2.02, 2.06)},
{"text": "S", "timestamp": (2.82, 2.86)},
{"text": "I", "timestamp": (2.94, 2.96)},
{"text": "R", "timestamp": (2.98, 3.02)},
{"text": " ", "timestamp": (3.06, 3.12)},
{"text": "I", "timestamp": (3.5, 3.52)},
{"text": " ", "timestamp": (3.58, 3.6)},
{"text": "E", "timestamp": (3.66, 3.68)},
{"text": "X", "timestamp": (3.68, 3.7)},
{"text": "I", "timestamp": (3.9, 3.92)},
{"text": "S", "timestamp": (3.94, 3.96)},
{"text": "T", "timestamp": (4.0, 4.02)},
{"text": " ", "timestamp": (4.06, 4.1)},
],
},
)
output = speech_recognizer(audio, return_timestamps="word")
self.assertEqual(
output,
{
"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST",
"chunks": [
{"text": "A", "timestamp": (0.6, 0.62)},
{"text": "MAN", "timestamp": (0.68, 0.86)},
{"text": "SAID", "timestamp": (1.06, 1.24)},
{"text": "TO", "timestamp": (1.28, 1.36)},
{"text": "THE", "timestamp": (1.42, 1.5)},
{"text": "UNIVERSE", "timestamp": (1.58, 2.02)},
{"text": "SIR", "timestamp": (2.82, 3.02)},
{"text": "I", "timestamp": (3.5, 3.52)},
{"text": "EXIST", "timestamp": (3.66, 4.02)},
],
},
)
@require_torch
@slow