import json
import math
from collections import namedtuple
from typing import List, Tuple

import sentencepiece as spm
import torch
import torchaudio
from torchaudio.models import Hypothesis

MODEL_TYPE_LIBRISPEECH = "librispeech"
MODEL_TYPE_TEDLIUM3 = "tedlium3"

# Gain applied to waveforms before the piecewise-linear log, derived from the
# dynamic range of 16-bit PCM audio.
DECIBEL = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
GAIN = pow(10, 0.05 * DECIBEL)

spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)

Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])


def piecewise_linear_log(x):
    """Scale by GAIN, then apply a piecewise function: linear below e, logarithmic above."""
    x = x * GAIN
    x[x > math.e] = torch.log(x[x > math.e])
    x[x <= math.e] = x[x <= math.e] / math.e
    return x


def batch_by_token_count(idx_target_lengths, token_limit):
    """Greedily group sample indices into batches whose total target length stays within token_limit."""
    batches = []
    current_batch = []
    current_token_count = 0
    for idx, target_length in idx_target_lengths:
        if current_token_count + target_length > token_limit:
            batches.append(current_batch)
            current_batch = [idx]
            current_token_count = target_length
        else:
            current_batch.append(idx)
            current_token_count += target_length

    if current_batch:
        batches.append(current_batch)

    return batches


def post_process_hypos(
    hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, List[float], List[int], List[int]]]:
    """Convert beam-search hypotheses into (transcript, score, alignment, token ids) tuples,
    dropping the leading token and filtering out unk/eos/pad tokens before decoding."""
    post_process_remove_list = [
        sp_model.unk_id(),
        sp_model.eos_id(),
        sp_model.pad_id(),
    ]
    filtered_hypo_tokens = [
        [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
    ]
    hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
    hypos_ali = [h.alignment[1:] for h in hypos]
    hypos_ids = [h.tokens[1:] for h in hypos]
    hypos_score = [[math.exp(h.score)] for h in hypos]

    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))

    return nbest_batch


class FunctionalModule(torch.nn.Module):
    """Wrap a plain callable so it can be composed inside ``torch.nn.Sequential``."""

    def __init__(self, functional):
        super().__init__()
        self.functional = functional

    def forward(self, input):
        return self.functional(input)


class GlobalStatsNormalization(torch.nn.Module):
    """Normalize features with precomputed per-bin mean and inverse standard deviation
    loaded from a JSON file."""

    def __init__(self, global_stats_path):
        super().__init__()

        with open(global_stats_path) as f:
            blob = json.loads(f.read())

        self.mean = torch.tensor(blob["mean"])
        self.invstddev = torch.tensor(blob["invstddev"])

    def forward(self, input):
        return (input - self.mean) * self.invstddev


class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
    """Linearly scale the learning rate from zero to its base value over ``warmup_updates`` steps,
    then hold it constant."""

    def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False):
        self.warmup_updates = warmup_updates
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs]
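

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how the helpers above are assumed to compose into a
# feature pipeline: waveform -> log-mel spectrogram -> piecewise-linear log ->
# global mean/variance normalization. The "example_global_stats.json" file and
# the random waveform are placeholder assumptions made so the sketch runs end
# to end.
if __name__ == "__main__":
    # Write dummy per-mel-bin statistics; a real run would load precomputed stats.
    with open("example_global_stats.json", "w") as f:
        json.dump({"mean": [0.0] * 80, "invstddev": [1.0] * 80}, f)

    pipeline = torch.nn.Sequential(
        FunctionalModule(piecewise_linear_log),
        GlobalStatsNormalization("example_global_stats.json"),
    )

    waveform = torch.randn(1, 16000)  # one second of synthetic 16 kHz audio
    mel = spectrogram_transform(waveform)  # (channel, n_mels, time)
    features = pipeline(mel.transpose(1, 2))  # (channel, time, n_mels)
    print(features.shape)  # torch.Size([1, 101, 80])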