Commit 66185e00 authored by Caroline Chen, committed by Facebook GitHub Bot

Use pretrained LM API for decoder example (#2317)

Summary:
Update the example ASR pipeline to use the recently added pretrained LM API for decoding.

Pull Request resolved: https://github.com/pytorch/audio/pull/2317

Reviewed By: mthrok

Differential Revision: D35361354

Pulled By: carolineechen

fbshipit-source-id: cac7cf55bd9f86417f319191c1405819fe2a7b46
parent 4a749e2d
@@ -4,37 +4,24 @@ from typing import Optional
 
 import torch
 import torchaudio
-from torchaudio.prototype.ctc_decoder import lexicon_decoder
+from torchaudio.prototype.ctc_decoder import lexicon_decoder, download_pretrained_files
 
 logger = logging.getLogger(__name__)
 
 
-def _download_files(lexicon_file, kenlm_file):
-    torch.hub.download_url_to_file(
-        "https://pytorch.s3.amazonaws.com/torchaudio/tutorial-assets/ctc-decoding/lexicon-librispeech.txt", lexicon_file
-    )
-    torch.hub.download_url_to_file(
-        "https://pytorch.s3.amazonaws.com/torchaudio/tutorial-assets/ctc-decoding/4-gram-librispeech.bin", kenlm_file
-    )
-
-
 def run_inference(args):
     # get pretrained wav2vec2.0 model
     bundle = getattr(torchaudio.pipelines, args.model)
     model = bundle.get_model()
-    tokens = [label.lower() for label in bundle.get_labels()]
 
     # get decoder files
-    hub_dir = torch.hub.get_dir()
-    lexicon_file = f"{hub_dir}/lexicon.txt"
-    kenlm_file = f"{hub_dir}/kenlm.bin"
-    _download_files(lexicon_file, kenlm_file)
+    files = download_pretrained_files("librispeech-4-gram")
 
     decoder = lexicon_decoder(
-        lexicon=lexicon_file,
-        tokens=tokens,
-        lm=kenlm_file,
+        lexicon=files.lexicon,
+        tokens=files.tokens,
+        lm=files.lm,
         nbest=args.nbest,
         beam_size=args.beam_size,
         beam_size_token=args.beam_size_token,
...
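For context, a minimal end-to-end sketch of the pretrained LM API as this example now uses it. It assumes the torchaudio prototype CTC decoder available at the time of this commit; the bundle choice (WAV2VEC2_ASR_BASE_960H), the beam-search parameter values, and the input file "sample.wav" stand in for the script's command-line arguments and are not part of this commit.

import torch
import torchaudio
from torchaudio.prototype.ctc_decoder import lexicon_decoder, download_pretrained_files

# Pretrained wav2vec 2.0 acoustic model (the example script selects the bundle via --model).
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
acoustic_model = bundle.get_model()

# Fetch the LibriSpeech lexicon, token list, and 4-gram KenLM through the new API.
files = download_pretrained_files("librispeech-4-gram")

# Lexicon-constrained CTC beam-search decoder; parameter values here are illustrative.
decoder = lexicon_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    nbest=1,
    beam_size=50,
)

# Decode a waveform: run the acoustic model, then beam-search over its emissions.
waveform, sample_rate = torchaudio.load("sample.wav")  # hypothetical input file
with torch.inference_mode():
    emissions, _ = acoustic_model(waveform)
hypotheses = decoder(emissions)
print(" ".join(hypotheses[0][0].words))

Compared with the previous version of the script, this drops the manual torch.hub.download_url_to_file calls with hard-coded S3 URLs and the token list built from bundle.get_labels(); the lexicon, tokens, and LM files all come from download_pretrained_files.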