Commit 34c0d115 authored by Caroline Chen, committed by Facebook GitHub Bot

Add Pretrained LM Support for Decoder (#2275)

Summary:
Add a function to download pretrained files (lexicon, tokens, and LibriSpeech 3-gram/4-gram KenLM) for the decoder, along with tests and an updated tutorial.

Pull Request resolved: https://github.com/pytorch/audio/pull/2275

Reviewed By: mthrok

Differential Revision: D35115418

Pulled By: carolineechen

fbshipit-source-id: 83ff22380fce9c753bb4a7b7e3d89aa66c2831c0
parent 05592dff
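For orientation, here is a minimal sketch of the API this commit introduces, pieced together from the tutorial and test changes in the diff below; the "librispeech-4-gram" name and the lexicon/tokens/lm attributes come from this PR, everything else is an assumption.

from torchaudio.prototype.ctc_decoder import download_pretrained_files, lexicon_decoder

# Downloads the pretrained lexicon, tokens, and 4-gram KenLM files for
# LibriSpeech and returns a namedtuple of local file paths.
files = download_pretrained_files("librispeech-4-gram")

# The paths plug directly into the existing lexicon_decoder factory.
beam_search_decoder = lexicon_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
)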
@@ -120,10 +120,9 @@ if sample_rate != bundle.sample_rate:
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # Next, we load in our token, lexicon, and KenLM data, which are used by
-# the decoder to predict words from the acoustic model output.
-#
-# Note: this cell may take a couple of minutes to run, as the language
-# model can be large
+# the decoder to predict words from the acoustic model output. Pretrained
+# files for the LibriSpeech dataset can be downloaded through torchaudio,
+# or the user can provide their own files.
 #
@@ -169,10 +168,6 @@ print(tokens)
 # ...
 #
-lexicon_url = "https://download.pytorch.org/torchaudio/tutorial-assets/ctc-decoding/lexicon-librispeech.txt"
-lexicon_file = f"{hub_dir}/lexicon.txt"
-torch.hub.download_url_to_file(lexicon_url, lexicon_file)
 ######################################################################
 # KenLM
@@ -187,9 +182,23 @@ torch.hub.download_url_to_file(lexicon_url, lexicon_file)
 # `LibriSpeech <http://www.openslr.org/11>`__.
 #
-kenlm_url = "https://download.pytorch.org/torchaudio/tutorial-assets/ctc-decoding/4-gram-librispeech.bin"
-kenlm_file = f"{hub_dir}/kenlm.bin"
-torch.hub.download_url_to_file(kenlm_url, kenlm_file)
+######################################################################
+# Downloading Pretrained Files
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Pretrained files for the LibriSpeech dataset can be downloaded using
+# :py:func:`download_pretrained_files <torchaudio.prototype.ctc_decoder.download_pretrained_files>`.
+#
+# Note: this cell may take a couple of minutes to run, as the language
+# model can be large
+#
+from torchaudio.prototype.ctc_decoder import download_pretrained_files
+files = download_pretrained_files("librispeech-4-gram")
+print(files)
 ######################################################################
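The tutorial text above also says users can provide their own files instead of the pretrained bundle. As a hedged sketch of that path, the download pattern from the removed cells still applies; hub_dir is assumed here to come from torch.hub.get_dir(), as in the old tutorial.

import torch

# Fetch a lexicon and a KenLM binary into the hub directory; any local
# lexicon/tokens/KenLM files of your own can be substituted for these URLs.
hub_dir = torch.hub.get_dir()

lexicon_file = f"{hub_dir}/lexicon.txt"
torch.hub.download_url_to_file(
    "https://download.pytorch.org/torchaudio/tutorial-assets/ctc-decoding/lexicon-librispeech.txt",
    lexicon_file,
)

kenlm_file = f"{hub_dir}/kenlm.bin"
torch.hub.download_url_to_file(
    "https://download.pytorch.org/torchaudio/tutorial-assets/ctc-decoding/4-gram-librispeech.bin",
    kenlm_file,
)

These paths can then be passed as the lexicon and lm arguments of lexicon_decoder, exactly as files.lexicon and files.lm are used in the hunks below.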
@@ -218,9 +227,9 @@ LM_WEIGHT = 3.23
 WORD_SCORE = -0.26
 beam_search_decoder = lexicon_decoder(
-    lexicon=lexicon_file,
-    tokens=tokens,
-    lm=kenlm_file,
+    lexicon=files.lexicon,
+    tokens=files.tokens,
+    lm=files.lm,
     nbest=3,
     beam_size=1500,
     lm_weight=LM_WEIGHT,
@@ -285,8 +294,8 @@ emission, _ = acoustic_model(waveform)
 #
 greedy_result = greedy_decoder(emission[0])
-greedy_transcript = greedy_result
-greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_transcript) / len(actual_transcript)
+greedy_transcript = " ".join(greedy_result)
+greedy_wer = torchaudio.functional.edit_distance(actual_transcript, greedy_result) / len(actual_transcript)
 print(f"Transcript: {greedy_transcript}")
 print(f"WER: {greedy_wer}")
@@ -422,9 +431,9 @@ beam_sizes = [1, 5, 50, 500]
 for beam_size in beam_sizes:
     beam_search_decoder = lexicon_decoder(
-        lexicon=lexicon_file,
-        tokens=tokens,
-        lm=kenlm_file,
+        lexicon=files.lexicon,
+        tokens=files.tokens,
+        lm=files.lm,
         beam_size=beam_size,
         lm_weight=LM_WEIGHT,
         word_score=WORD_SCORE,
@@ -448,9 +457,9 @@ beam_size_tokens = [1, 5, 10, num_tokens]
 for beam_size_token in beam_size_tokens:
     beam_search_decoder = lexicon_decoder(
-        lexicon=lexicon_file,
-        tokens=tokens,
-        lm=kenlm_file,
+        lexicon=files.lexicon,
+        tokens=files.tokens,
+        lm=files.lm,
         beam_size_token=beam_size_token,
         lm_weight=LM_WEIGHT,
         word_score=WORD_SCORE,
@@ -475,9 +484,9 @@ beam_thresholds = [1, 5, 10, 25]
 for beam_threshold in beam_thresholds:
     beam_search_decoder = lexicon_decoder(
-        lexicon=lexicon_file,
-        tokens=tokens,
-        lm=kenlm_file,
+        lexicon=files.lexicon,
+        tokens=files.tokens,
+        lm=files.lm,
         beam_threshold=beam_threshold,
         lm_weight=LM_WEIGHT,
         word_score=WORD_SCORE,
@@ -501,9 +510,9 @@ lm_weights = [0, LM_WEIGHT, 15]
 for lm_weight in lm_weights:
     beam_search_decoder = lexicon_decoder(
-        lexicon=lexicon_file,
-        tokens=tokens,
-        lm=kenlm_file,
+        lexicon=files.lexicon,
+        tokens=files.tokens,
+        lm=files.lm,
         lm_weight=lm_weight,
         word_score=WORD_SCORE,
     )
@@ -73,3 +73,9 @@ def temp_hub_dir(tmpdir, pytestconfig):
     torch.hub.set_dir(tmpdir)
     yield
     torch.hub.set_dir(org_dir)
+
+
+@pytest.fixture()
+def emissions():
+    path = torchaudio.utils.download_asset("test-assets/emissions-8555-28447-0012.pt")
+    return torch.load(path)
import pytest
@pytest.mark.parametrize(
    "model,expected",
    [
        ("librispeech", ["the", "captain", "shook", "his", "head"]),
        ("librispeech-3-gram", ["the", "captain", "shook", "his", "head"]),
    ],
)
def test_decoder_from_pretrained(model, expected, emissions):
    from torchaudio.prototype.ctc_decoder import lexicon_decoder, download_pretrained_files

    pretrained_files = download_pretrained_files(model)
    decoder = lexicon_decoder(
        lexicon=pretrained_files.lexicon,
        tokens=pretrained_files.tokens,
        lm=pretrained_files.lm,
    )
    result = decoder(emissions)
    assert result[0][0].words == expected
@@ -2,7 +2,7 @@ import torchaudio
 try:
     torchaudio._extension._load_lib("libtorchaudio_decoder")
-    from .ctc_decoder import Hypothesis, LexiconDecoder, lexicon_decoder
+    from .ctc_decoder import Hypothesis, LexiconDecoder, lexicon_decoder, download_pretrained_files
 except ImportError as err:
     raise ImportError(
         "flashlight decoder bindings are required to use this functionality. "
@@ -14,4 +14,5 @@ __all__ = [
     "Hypothesis",
     "LexiconDecoder",
     "lexicon_decoder",
+    "download_pretrained_files",
 ]
import itertools as it
from collections import namedtuple
from typing import Dict, List, Optional, Union, NamedTuple
import torch
@@ -15,11 +16,14 @@ from torchaudio._torchaudio_decoder import (
     _load_words,
     _ZeroLM,
 )
+from torchaudio.utils import download_asset
 __all__ = ["Hypothesis", "LexiconDecoder", "lexicon_decoder"]
+_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
 class Hypothesis(NamedTuple):
     r"""Represents hypothesis generated by CTC beam search decoder :py:func`LexiconDecoder`.
@@ -42,7 +46,8 @@ class LexiconDecoder:
     Lexically contrained CTC beam search decoder from *Flashlight* [:footcite:`kahn2022flashlight`].
     Note:
-        To build the decoder, please use the factory function :py:func:`lexicon_decoder`.
+        To build the decoder, please use factory function
+        :py:func:`lexicon_decoder`.
     Args:
         nbest (int): number of best decodings to return
@@ -251,3 +256,48 @@ def lexicon_decoder(
         sil_token=sil_token,
         unk_word=unk_word,
     )
+
+
+def _get_filenames(model: str) -> _PretrainedFiles:
+    if model not in ["librispeech", "librispeech-3-gram", "librispeech-4-gram"]:
+        raise ValueError(
+            f"{model} not supported. Must be one of ['librispeech-3-gram', 'librispeech-4-gram', 'librispeech']"
+        )
+
+    prefix = f"decoder-assets/{model}"
+    return _PretrainedFiles(
+        lexicon=f"{prefix}/lexicon.txt",
+        tokens=f"{prefix}/tokens.txt",
+        lm=f"{prefix}/lm.bin" if model != "librispeech" else None,
+    )
+
+
+def download_pretrained_files(model: str) -> _PretrainedFiles:
+    """
+    Retrieves pretrained data files used for CTC decoder.
+
+    Args:
+        model (str): pretrained language model to download.
+            Options: ["librispeech-3-gram", "librispeech-4-gram", "librispeech"]
+
+    Returns:
+        Object with the following attributes:
+            lm: path corresponding to downloaded language model, or None if model is not
+                associated with an lm
+            lexicon: path corresponding to downloaded lexicon file
+            tokens: path corresponding to downloaded tokens file
+    """
+    files = _get_filenames(model)
+    lexicon_file = download_asset(files.lexicon)
+    tokens_file = download_asset(files.tokens)
+    if files.lm is not None:
+        lm_file = download_asset(files.lm)
+    else:
+        lm_file = None
+
+    return _PretrainedFiles(
+        lexicon=lexicon_file,
+        tokens=tokens_file,
+        lm=lm_file,
+    )
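As a short usage sketch of the helper added above: per _get_filenames, the "librispeech" option ships no language model, so the returned lm path is None, and the new test passes it straight to lexicon_decoder; this snippet mirrors that behavior and is otherwise an assumption.

from torchaudio.prototype.ctc_decoder import download_pretrained_files, lexicon_decoder

# Only the lexicon and tokens assets exist for "librispeech"; files.lm is None.
files = download_pretrained_files("librispeech")
assert files.lm is None

# Mirrors test_decoder_from_pretrained above: lm=None is accepted and builds
# a decoder without a language model.
decoder = lexicon_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
)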