"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d33d9f6715e0c2430c7b7d871b0e20a11601a05d"
Commit 34c0d115 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Add Pretrained LM Support for Decoder (#2275)

Summary:
add function to download pretrained files for LibriSpeech 3-gram/4-gram KenLM, tests, and updated tutorial

Pull Request resolved: https://github.com/pytorch/audio/pull/2275

Reviewed By: mthrok

Differential Revision: D35115418

Pulled By: carolineechen

fbshipit-source-id: 83ff22380fce9c753bb4a7b7e3d89aa66c2831c0
parent 05592dff
...@@ -120,10 +120,9 @@ if sample_rate != bundle.sample_rate: ...@@ -120,10 +120,9 @@ if sample_rate != bundle.sample_rate:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
# Next, we load in our token, lexicon, and KenLM data, which are used by # Next, we load in our token, lexicon, and KenLM data, which are used by
# the decoder to predict words from the acoustic model output. # the decoder to predict words from the acoustic model output. Pretrained
# # files for the LibriSpeech dataset can be downloaded through torchaudio,
# Note: this cell may take a couple of minutes to run, as the language # or the user can provide their own files.
# model can be large
# #
...@@ -169,10 +168,6 @@ print(tokens) ...@@ -169,10 +168,6 @@ print(tokens)
# ... # ...
# #
lexicon_url = "https://download.pytorch.org/torchaudio/tutorial-assets/ctc-decoding/lexicon-librispeech.txt"
lexicon_file = f"{hub_dir}/lexicon.txt"
torch.hub.download_url_to_file(lexicon_url, lexicon_file)
###################################################################### ######################################################################
# KenLM # KenLM
...@@ -187,9 +182,23 @@ torch.hub.download_url_to_file(lexicon_url, lexicon_file) ...@@ -187,9 +182,23 @@ torch.hub.download_url_to_file(lexicon_url, lexicon_file)
# `LibriSpeech <http://www.openslr.org/11>`__. # `LibriSpeech <http://www.openslr.org/11>`__.
# #
kenlm_url = "https://download.pytorch.org/torchaudio/tutorial-assets/ctc-decoding/4-gram-librispeech.bin"
kenlm_file = f"{hub_dir}/kenlm.bin" ######################################################################
torch.hub.download_url_to_file(kenlm_url, kenlm_file) # Downloading Pretrained Files
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Pretrained files for the LibriSpeech dataset can be downloaded using
# :py:func:`download_pretrained_files <torchaudio.prototype.ctc_decoder.download_pretrained_files>`.
#
# Note: this cell may take a couple of minutes to run, as the language
# model can be large
#
from torchaudio.prototype.ctc_decoder import download_pretrained_files
files = download_pretrained_files("librispeech-4-gram")
print(files)
###################################################################### ######################################################################
...@@ -218,9 +227,9 @@ LM_WEIGHT = 3.23 ...@@ -218,9 +227,9 @@ LM_WEIGHT = 3.23
WORD_SCORE = -0.26 WORD_SCORE = -0.26
beam_search_decoder = lexicon_decoder( beam_search_decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=files.lexicon,
tokens=tokens, tokens=files.tokens,
lm=kenlm_file, lm=files.lm,
nbest=3, nbest=3,
beam_size=1500, beam_size=1500,
lm_weight=LM_WEIGHT, lm_weight=LM_WEIGHT,
...@@ -285,8 +294,8 @@ emission, _ = acoustic_model(waveform) ...@@ -285,8 +294,8 @@ emission, _ = acoustic_model(waveform)
# #
greedy_result = greedy_decoder(emission[0]) greedy_result = greedy_decoder(emission[0])
greedy_transcript = greedy_result greedy_transcript = " ".join(greedy_result)
greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_transcript) / len(actual_transcript) greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)
print(f"Transcript: {greedy_transcript}") print(f"Transcript: {greedy_transcript}")
print(f"WER: {greedy_wer}") print(f"WER: {greedy_wer}")
...@@ -422,9 +431,9 @@ beam_sizes = [1, 5, 50, 500] ...@@ -422,9 +431,9 @@ beam_sizes = [1, 5, 50, 500]
for beam_size in beam_sizes: for beam_size in beam_sizes:
beam_search_decoder = lexicon_decoder( beam_search_decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=files.lexicon,
tokens=tokens, tokens=files.tokens,
lm=kenlm_file, lm=files.lm,
beam_size=beam_size, beam_size=beam_size,
lm_weight=LM_WEIGHT, lm_weight=LM_WEIGHT,
word_score=WORD_SCORE, word_score=WORD_SCORE,
...@@ -448,9 +457,9 @@ beam_size_tokens = [1, 5, 10, num_tokens] ...@@ -448,9 +457,9 @@ beam_size_tokens = [1, 5, 10, num_tokens]
for beam_size_token in beam_size_tokens: for beam_size_token in beam_size_tokens:
beam_search_decoder = lexicon_decoder( beam_search_decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=files.lexicon,
tokens=tokens, tokens=files.tokens,
lm=kenlm_file, lm=files.lm,
beam_size_token=beam_size_token, beam_size_token=beam_size_token,
lm_weight=LM_WEIGHT, lm_weight=LM_WEIGHT,
word_score=WORD_SCORE, word_score=WORD_SCORE,
...@@ -475,9 +484,9 @@ beam_thresholds = [1, 5, 10, 25] ...@@ -475,9 +484,9 @@ beam_thresholds = [1, 5, 10, 25]
for beam_threshold in beam_thresholds: for beam_threshold in beam_thresholds:
beam_search_decoder = lexicon_decoder( beam_search_decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=files.lexicon,
tokens=tokens, tokens=files.tokens,
lm=kenlm_file, lm=files.lm,
beam_threshold=beam_threshold, beam_threshold=beam_threshold,
lm_weight=LM_WEIGHT, lm_weight=LM_WEIGHT,
word_score=WORD_SCORE, word_score=WORD_SCORE,
...@@ -501,9 +510,9 @@ lm_weights = [0, LM_WEIGHT, 15] ...@@ -501,9 +510,9 @@ lm_weights = [0, LM_WEIGHT, 15]
for lm_weight in lm_weights: for lm_weight in lm_weights:
beam_search_decoder = lexicon_decoder( beam_search_decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=files.lexicon,
tokens=tokens, tokens=files.tokens,
lm=kenlm_file, lm=files.lm,
lm_weight=lm_weight, lm_weight=lm_weight,
word_score=WORD_SCORE, word_score=WORD_SCORE,
) )
......
...@@ -73,3 +73,9 @@ def temp_hub_dir(tmpdir, pytestconfig): ...@@ -73,3 +73,9 @@ def temp_hub_dir(tmpdir, pytestconfig):
torch.hub.set_dir(tmpdir) torch.hub.set_dir(tmpdir)
yield yield
torch.hub.set_dir(org_dir) torch.hub.set_dir(org_dir)
@pytest.fixture()
def emissions():
    """Provide the pre-computed emissions tensor shared by the decoder tests.

    Downloads the serialized tensor asset once (cached by torchaudio) and
    deserializes it with ``torch.load``.
    """
    asset_path = torchaudio.utils.download_asset("test-assets/emissions-8555-28447-0012.pt")
    emissions_tensor = torch.load(asset_path)
    return emissions_tensor
import pytest
@pytest.mark.parametrize(
    "model,expected",
    [
        ("librispeech", ["the", "captain", "shook", "his", "head"]),
        ("librispeech-3-gram", ["the", "captain", "shook", "his", "head"]),
    ],
)
def test_decoder_from_pretrained(model, expected, emissions):
    """Decoding the reference emissions with downloaded pretrained files yields the expected transcript."""
    from torchaudio.prototype.ctc_decoder import lexicon_decoder, download_pretrained_files

    # Build the decoder entirely from the pretrained asset bundle.
    pretrained_files = download_pretrained_files(model)
    decoder = lexicon_decoder(
        lexicon=pretrained_files.lexicon,
        tokens=pretrained_files.tokens,
        lm=pretrained_files.lm,
    )

    # Check the words of the best (first) hypothesis for the first utterance.
    hypotheses = decoder(emissions)
    best_hypothesis = hypotheses[0][0]
    assert best_hypothesis.words == expected
...@@ -2,7 +2,7 @@ import torchaudio ...@@ -2,7 +2,7 @@ import torchaudio
try: try:
torchaudio._extension._load_lib("libtorchaudio_decoder") torchaudio._extension._load_lib("libtorchaudio_decoder")
from .ctc_decoder import Hypothesis, LexiconDecoder, lexicon_decoder from .ctc_decoder import Hypothesis, LexiconDecoder, lexicon_decoder, download_pretrained_files
except ImportError as err: except ImportError as err:
raise ImportError( raise ImportError(
"flashlight decoder bindings are required to use this functionality. " "flashlight decoder bindings are required to use this functionality. "
...@@ -14,4 +14,5 @@ __all__ = [ ...@@ -14,4 +14,5 @@ __all__ = [
"Hypothesis", "Hypothesis",
"LexiconDecoder", "LexiconDecoder",
"lexicon_decoder", "lexicon_decoder",
"download_pretrained_files",
] ]
import itertools as it import itertools as it
from collections import namedtuple
from typing import Dict, List, Optional, Union, NamedTuple from typing import Dict, List, Optional, Union, NamedTuple
import torch import torch
...@@ -15,11 +16,14 @@ from torchaudio._torchaudio_decoder import ( ...@@ -15,11 +16,14 @@ from torchaudio._torchaudio_decoder import (
_load_words, _load_words,
_ZeroLM, _ZeroLM,
) )
from torchaudio.utils import download_asset
__all__ = ["Hypothesis", "LexiconDecoder", "lexicon_decoder"] __all__ = ["Hypothesis", "LexiconDecoder", "lexicon_decoder"]
_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
class Hypothesis(NamedTuple): class Hypothesis(NamedTuple):
r"""Represents hypothesis generated by CTC beam search decoder :py:func`LexiconDecoder`. r"""Represents hypothesis generated by CTC beam search decoder :py:func`LexiconDecoder`.
...@@ -42,7 +46,8 @@ class LexiconDecoder: ...@@ -42,7 +46,8 @@ class LexiconDecoder:
Lexically constrained CTC beam search decoder from *Flashlight* [:footcite:`kahn2022flashlight`]. Lexically constrained CTC beam search decoder from *Flashlight* [:footcite:`kahn2022flashlight`].
Note: Note:
To build the decoder, please use the factory function :py:func:`lexicon_decoder`. To build the decoder, please use factory function
:py:func:`lexicon_decoder`.
Args: Args:
nbest (int): number of best decodings to return nbest (int): number of best decodings to return
...@@ -251,3 +256,48 @@ def lexicon_decoder( ...@@ -251,3 +256,48 @@ def lexicon_decoder(
sil_token=sil_token, sil_token=sil_token,
unk_word=unk_word, unk_word=unk_word,
) )
def _get_filenames(model: str) -> _PretrainedFiles:
    """Map a pretrained model name to its remote asset paths (no download).

    Args:
        model (str): one of "librispeech", "librispeech-3-gram",
            "librispeech-4-gram"

    Raises:
        ValueError: if ``model`` is not a supported pretrained model name.
    """
    supported = ("librispeech", "librispeech-3-gram", "librispeech-4-gram")
    if model not in supported:
        raise ValueError(
            f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
        )

    base = f"decoder-assets/{model}"
    # The plain "librispeech" bundle ships no language model file.
    lm_path = None if model == "librispeech" else f"{base}/lm.bin"
    return _PretrainedFiles(
        lexicon=f"{base}/lexicon.txt",
        tokens=f"{base}/tokens.txt",
        lm=lm_path,
    )
def download_pretrained_files(model: str) -> _PretrainedFiles:
    """
    Retrieves pretrained data files used for CTC decoder.

    Args:
        model (str): pretrained language model to download
            Options: ["librispeech-3-gram", "librispeech-4-gram", "librispeech"]

    Returns:
        Object with the following attributes:
            lm: path corresponding to downloaded language model, or None if model is not
                associated with an lm
            lexicon: path corresponding to downloaded lexicon file
            tokens: path corresponding to downloaded tokens file
    """
    remote = _get_filenames(model)

    lexicon_path = download_asset(remote.lexicon)
    tokens_path = download_asset(remote.tokens)
    # Some bundles (e.g. "librispeech") carry no language model; skip the download then.
    lm_path = download_asset(remote.lm) if remote.lm is not None else None

    return _PretrainedFiles(
        lexicon=lexicon_path,
        tokens=tokens_path,
        lm=lm_path,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment