"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "8d602649d008fc20e793a26a66a5aac8a79fb02a"
Unverified Commit 6e57a569 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding timestamps for CTC with LM in ASR pipeline. (#15863)

* Adding timestamps for CTC with LM in ASR pipeline.

* Remove print.

* Nit change.
parent 8a133490
...@@ -353,7 +353,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): ...@@ -353,7 +353,7 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
word = char word = char
last_state = state last_state = state
if state == "WORD": if last_state == "WORD":
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset}) word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
return word_offsets return word_offsets
......
...@@ -313,8 +313,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -313,8 +313,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# Optional return types # Optional return types
optional = {} optional = {}
if return_timestamps and self.type != "ctc": if return_timestamps and self.type == "seq2seq":
raise ValueError("We cannot return_timestamps yet on non-ctc models !") raise ValueError("We cannot return_timestamps yet on non-ctc models !")
if return_timestamps == "char" and self.type == "ctc_with_lm":
raise ValueError("CTC with LM cannot return `char` timestamps, only `words`")
final_items = [] final_items = []
key = "logits" if self.type == "ctc_with_lm" else "tokens" key = "logits" if self.type == "ctc_with_lm" else "tokens"
...@@ -335,34 +337,43 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ...@@ -335,34 +337,43 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if self.type == "ctc_with_lm": if self.type == "ctc_with_lm":
if decoder_kwargs is None: if decoder_kwargs is None:
decoder_kwargs = {} decoder_kwargs = {}
text = self.decoder.decode_beams(items, **decoder_kwargs)[0][0] beams = self.decoder.decode_beams(items, **decoder_kwargs)
text = beams[0][0]
if return_timestamps:
# Simply cast from pyctcdecode format to wav2vec2 format to leverage
# pre-existing code later
chunk_offset = beams[0][2]
word_offsets = []
for word, (start_offset, end_offset) in chunk_offset:
word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
else: else:
skip_special_tokens = self.type != "ctc" skip_special_tokens = self.type != "ctc"
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens) text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
if return_timestamps: if return_timestamps:
if return_timestamps == "char": char_offsets = self.tokenizer.decode(
decoded = self.tokenizer.decode( items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True )["char_offsets"]
) if return_timestamps == "word":
elif return_timestamps == "word": word_offsets = self.tokenizer._get_word_offsets(
decoded = self.tokenizer.decode( char_offsets, self.tokenizer.replace_word_delimiter_char
items, skip_special_tokens=skip_special_tokens, output_word_offsets=True
)
chunks = []
for item in decoded[f"{return_timestamps}_offsets"]:
start = (
item["start_offset"]
* self.model.config.inputs_to_logits_ratio
/ self.feature_extractor.sampling_rate
)
stop = (
item["end_offset"]
* self.model.config.inputs_to_logits_ratio
/ self.feature_extractor.sampling_rate
) )
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
optional["chunks"] = chunks if return_timestamps:
if return_timestamps == "word":
offsets = word_offsets
else:
offsets = char_offsets
chunks = []
for item in offsets:
start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
start /= self.feature_extractor.sampling_rate
stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
stop /= self.feature_extractor.sampling_rate
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
optional["chunks"] = chunks
extra = defaultdict(list) extra = defaultdict(list)
for output in model_outputs: for output in model_outputs:
......
...@@ -188,6 +188,32 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel ...@@ -188,6 +188,32 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
}, },
) )
speech_recognizer.type = "ctc_with_lm"
# Simple test with CTC with LM, chunking + timestamps
output = speech_recognizer(filename, chunk_length_s=2.0, return_timestamps="word")
self.assertEqual(
output,
{
"text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumajcri",
"chunks": [
{"text": "y", "timestamp": (0.52, 0.54)},
{"text": "en", "timestamp": (0.6, 0.68)},
{"text": "las", "timestamp": (0.74, 0.84)},
{"text": "ramas", "timestamp": (0.94, 1.24)},
{"text": "medio", "timestamp": (1.32, 1.52)},
{"text": "sumergidas", "timestamp": (1.56, 2.22)},
{"text": "revoloteaban", "timestamp": (2.36, 3.0)},
{"text": "algunos", "timestamp": (3.06, 3.38)},
{"text": "pájaros", "timestamp": (3.46, 3.86)},
{"text": "de", "timestamp": (3.92, 4.0)},
{"text": "quimérico", "timestamp": (4.08, 4.6)},
{"text": "y", "timestamp": (4.66, 4.68)},
{"text": "legendario", "timestamp": (4.74, 5.26)},
{"text": "plumajcri", "timestamp": (5.34, 5.74)},
],
},
)
@require_tf @require_tf
def test_small_model_tf(self): def test_small_model_tf(self):
self.skipTest("Tensorflow not supported yet.") self.skipTest("Tensorflow not supported yet.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment