"""Evaluate a fine-tuned HuBERT model on LibriSpeech and report the word error rate (WER)."""
import argparse
import logging
from typing import Dict, List

import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.models.decoder import ctc_decoder, CTCDecoder, download_pretrained_files

from utils import _get_id2label

logger = logging.getLogger(__name__)


def _load_checkpoint(checkpoint_path: str) -> torch.nn.Module:
    """Load fine-tuned HuBERTPretrainModel weights into a ``hubert_base`` model."""
    model = torchaudio.models.hubert_base(aux_num_out=29)
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]
    new_state_dict = {}
    for k in state_dict:
        # Strip the training-module prefix from the wav2vec2 backbone weights,
        # and keep the auxiliary CTC head ("aux") weights as-is.
        if "model.wav2vec2" in k:
            new_state_dict[k.replace("model.wav2vec2.", "")] = state_dict[k]
        elif "aux" in k:
            new_state_dict[k] = state_dict[k]
    model.load_state_dict(new_state_dict)
    return model


def _viterbi_decode(emission: torch.Tensor, id2token: Dict, blank_idx: int = 0) -> List[str]:
    """Run greedy decoding on CTC outputs.

    Args:
        emission (torch.Tensor): Output of the CTC layer. Tensor with dimensions
            `(..., time, num_tokens)`.
        id2token (Dict): The dictionary that maps indices of the emission's last
            dimension to the corresponding tokens.
        blank_idx (int, optional): The index of the CTC blank token. (Default: 0)

    Returns:
        (List of str): The decoding result. List of strings in lower case.
    """
    hypothesis = emission.argmax(-1).unique_consecutive()
    hypothesis = hypothesis[hypothesis != blank_idx]
    hypothesis = "".join(id2token[int(i)] for i in hypothesis).replace("|", " ").strip()
    return hypothesis.split()


def _ctc_decode(emission: torch.Tensor, decoder: CTCDecoder) -> List[str]:
    """Run beam-search CTC decoding with a KenLM language model.

    Args:
        emission (torch.Tensor): Output of the CTC layer. Tensor with dimensions
            `(..., time, num_tokens)`.
        decoder (CTCDecoder): The initialized CTCDecoder.

    Returns:
        (List of str): The decoding result. List of strings in lower case.
    """
    hypothesis = decoder(emission)
    hypothesis = hypothesis[0][0].words
    hypothesis = [word for word in hypothesis if word != " "]
    return hypothesis
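
# A minimal illustration (not part of the original script) of the greedy
# collapse performed by _viterbi_decode, using a hypothetical 3-frame emission
# with two token classes, where index 0 is the CTC blank:
#
#     >>> emission = torch.tensor([[0.1, 0.9], [0.1, 0.9], [0.9, 0.1]])
#     >>> emission.argmax(-1)                       # per-frame best token
#     tensor([1, 1, 0])
#     >>> emission.argmax(-1).unique_consecutive()  # merge consecutive repeats
#     tensor([1, 0])
#
# Dropping the blank (index 0) leaves a single emission of token 1, which is
# then mapped through id2token; in the real token set, "|" marks word
# boundaries and is replaced by a space before splitting into words.
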
def run_inference(args):
    device = torch.device("cuda") if args.use_gpu else torch.device("cpu")

    # Load the fine-tuned HuBERTPretrainModel from the checkpoint.
    model = _load_checkpoint(args.checkpoint)
    model.eval().to(device)

    if args.use_lm:
        # Download the lexicon, token list, and KenLM files for the decoder.
        files = download_pretrained_files("librispeech-4-gram")
        decoder = ctc_decoder(
            lexicon=files.lexicon,
            tokens=files.tokens,
            lm=files.lm,
            nbest=args.nbest,
            beam_size=args.beam_size,
            beam_size_token=args.beam_size_token,
            beam_threshold=args.beam_threshold,
            lm_weight=args.lm_weight,
            word_score=args.word_score,
            unk_score=args.unk_score,
            sil_score=args.sil_score,
            log_add=False,
        )
    else:
        id2token = _get_id2label()

    dataset = torchaudio.datasets.LIBRISPEECH(args.librispeech_path, url=args.split)
    total_edit_distance = 0
    total_length = 0
    for idx, sample in enumerate(dataset):
        waveform, _, transcript, _, _, _ = sample
        transcript = transcript.strip().lower().replace("\n", "")
        with torch.inference_mode():
            emission, _ = model(waveform.to(device))
            emission = F.log_softmax(emission, dim=-1)
        if args.use_lm:
            hypothesis = _ctc_decode(emission.cpu(), decoder)
        else:
            hypothesis = _viterbi_decode(emission, id2token)
        # Accumulate word-level edit distance for the running WER.
        total_edit_distance += torchaudio.functional.edit_distance(hypothesis, transcript.split())
        total_length += len(transcript.split())
        if idx % 100 == 0:
            logger.info(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
    logger.info(f"Final WER: {total_edit_distance / total_length}")


def _parse_args():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--librispeech-path",
        type=str,
        help="Folder where the LibriSpeech dataset is stored.",
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["dev-clean", "dev-other", "test-clean", "test-other"],
        default="test-clean",
        help="LibriSpeech dataset split. (Default: 'test-clean')",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="The checkpoint path of the fine-tuned HuBERTPretrainModel.",
    )
    parser.add_argument("--use-lm", action="store_true", help="Whether to use a language model for decoding.")
    parser.add_argument("--nbest", type=int, default=1, help="Number of best hypotheses to return. (Default: 1)")
    parser.add_argument(
        "--beam-size",
        type=int,
        default=1500,
        help="Beam size for determining the number of hypotheses to store. (Default: 1500)",
    )
    parser.add_argument(
        "--beam-size-token",
        type=int,
        default=29,
        help="Number of tokens to consider at each beam search step. (Default: 29)",
    )
    parser.add_argument(
        "--beam-threshold", type=int, default=100, help="Beam threshold for pruning hypotheses. (Default: 100)"
    )
    parser.add_argument(
        "--lm-weight",
        type=float,
        default=2.46,
        help="Language model weight in decoding. (Default: 2.46)",
    )
    parser.add_argument(
        "--word-score",
        type=float,
        default=-0.59,
        help="Word insertion score in decoding. (Default: -0.59)",
    )
    parser.add_argument(
        "--unk-score", type=float, default=float("-inf"), help="Unknown word insertion score. (Default: -inf)"
    )
    parser.add_argument("--sil-score", type=float, default=0, help="Silence insertion score. (Default: 0)")
    parser.add_argument("--use-gpu", action="store_true", help="Whether to use GPU for decoding.")
    parser.add_argument("--debug", action="store_true", help="Whether to use debug level for logging.")
    return parser.parse_args()


def _init_logger(debug):
    fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")


def _main():
    args = _parse_args()
    _init_logger(args.debug)
    run_inference(args)


if __name__ == "__main__":
    _main()
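
# Example invocation (a sketch; the script name and both paths below are
# placeholders, not taken from the source):
#
#     python inference.py \
#         --librispeech-path /path/to/LibriSpeech-root \
#         --checkpoint /path/to/hubert_finetuned.ckpt \
#         --split test-clean \
#         --use-lm \
#         --use-gpu
#
# Without --use-lm, the script falls back to greedy (argmax) CTC decoding,
# which is faster but typically yields a higher WER than beam search with the
# "librispeech-4-gram" KenLM.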